use super::{
common::{
calculate_fold_chunk_start_rounds, calculate_fold_commit_rounds, calculate_folding_arities,
FinalMessage,
},
error::Error,
};
use crate::{
linear_code::{LinearCode, LinearCodeWithExtensionEncoding},
merkle_tree::VectorCommitScheme,
protocols::fri::common::{fold_chunk, QueryProof, QueryRoundProof},
reed_solomon::reed_solomon::ReedSolomonCode,
};
use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedFieldIndexable};
use binius_ntt::AdditiveNTT;
use binius_utils::bail;
use itertools::izip;
use rayon::prelude::*;
use tracing::instrument;
/// Folds `codeword` with a batch of folding challenges, one chunk at a time.
///
/// The codeword is split into chunks of `2^folding_challenges.len()` elements and each chunk
/// is reduced to one output element by [`fold_chunk`], so the result has one entry per chunk.
/// Chunks are processed in parallel; the scratch buffer handed to [`fold_chunk`] is created
/// through `map_init` so it can be reused across chunks instead of being allocated per chunk.
///
/// The start round passed to [`fold_chunk`] is `round + 1 - folding_challenges.len()`.
///
/// # Panics
///
/// Panics if the codeword length is not a multiple of the chunk size, if
/// `round + 1 < folding_challenges.len()`, if `round >= rs_code.log_dim()`, or if
/// `folding_challenges` is empty.
fn fold_codeword<F, FS>(
    rs_code: &ReedSolomonCode<FS>,
    codeword: &[F],
    round: usize,
    folding_challenges: &[F],
) -> Vec<F>
where
    F: BinaryField + ExtensionField<FS>,
    FS: BinaryField,
{
    assert!(codeword.len() % (1 << folding_challenges.len()) == 0);
    assert!(round + 1 >= folding_challenges.len());
    assert!(round < rs_code.log_dim());
    assert!(!folding_challenges.is_empty());

    let n_challenges = folding_challenges.len();
    let first_round = round + 1 - n_challenges;
    let fold_width = 1 << n_challenges;

    codeword
        .par_chunks(fold_width)
        .enumerate()
        .map_init(
            || vec![F::default(); fold_width],
            |scratch, (position, values)| {
                fold_chunk(rs_code, first_round, position, values, folding_challenges, scratch)
            },
        )
        .collect()
}
/// Output of [`commit_message`]: the vector commitment to an encoded message, plus the data
/// produced while committing.
#[derive(Debug)]
pub struct CommitOutput<P, VCSCommitment, VCSCommitted> {
/// Commitment returned by the vector commitment scheme's `commit_batch` call.
pub commitment: VCSCommitment,
/// Committed data returned by `commit_batch` alongside `commitment`.
pub committed: VCSCommitted,
/// The encoded message (the codeword that was committed), in packed form.
pub codeword: Vec<P>,
}
pub fn commit_message<F, FA, P, PA, VCS>(
rs_code: &ReedSolomonCode<PA>,
vcs: &VCS,
message: &[P],
) -> Result<CommitOutput<P, VCS::Commitment, VCS::Committed>, Error>
where
F: BinaryField + ExtensionField<FA>,
FA: BinaryField,
P: PackedFieldIndexable<Scalar = F> + PackedExtension<FA, PackedSubfield = PA>,
PA: PackedFieldIndexable<Scalar = FA>,
VCS: VectorCommitScheme<F>,
{
if message.len() * P::WIDTH != rs_code.dim() {
bail!(Error::InvalidArgs("message length does not match code dimension".to_string()));
}
if vcs.vector_len() != rs_code.len() {
bail!(Error::InvalidArgs("code length does not vector commitment length".to_string(),));
}
let mut encoded = vec![P::zero(); message.len() << rs_code.log_inv_rate()];
encoded[..message.len()].copy_from_slice(message);
rs_code.encode_extension_inplace(&mut encoded)?;
let (commitment, vcs_committed) = vcs
.commit_batch(&[P::unpack_scalars(&encoded)])
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
Ok(CommitOutput {
commitment,
committed: vcs_committed,
codeword: encoded,
})
}
/// Result of a single call to [`FRIFolder::execute_fold_round`].
pub enum FoldRoundOutput<VCSCommitment> {
/// The challenge was recorded, but no folded codeword was committed in this round.
NoCommitment,
/// The codeword was folded and committed; carries the resulting commitment.
Commitment(VCSCommitment),
}
/// Prover state for the folding phase: collects folding challenges round by round, folds and
/// commits the codeword on commitment rounds, and is finally converted into a
/// [`FRIQueryProver`] by `finalize`.
pub struct FRIFolder<'a, F, FA, VCS>
where
FA: BinaryField,
F: BinaryField,
VCS: VectorCommitScheme<F>,
{
// Reed–Solomon code of the initially committed codeword; its `log_dim()` is the number of
// fold rounds.
committed_rs_code: &'a ReedSolomonCode<FA>,
// Reed–Solomon code for the final codeword; its NTT is used in `finalize` to recover the
// final message.
final_rs_code: &'a ReedSolomonCode<F>,
// The initially committed codeword.
codeword: &'a [F],
// Vector commitment scheme for the initial codeword.
codeword_vcs: &'a VCS,
// Vector commitment schemes for the folded codewords, one per round commitment.
round_vcss: &'a [VCS],
// Committed data for the initial codeword.
codeword_committed: &'a VCS::Committed,
// Folded codewords committed so far, each paired with its committed data.
round_committed: Vec<(Vec<F>, VCS::Committed)>,
// Index of the fold round the next challenge belongs to.
curr_round: usize,
// Challenges received since the last fold; consumed together at the next fold.
unprocessed_challenges: Vec<F>,
// Round indices at which a folded codeword is committed, as computed by
// `calculate_fold_commit_rounds`.
commitment_fold_rounds: Vec<usize>,
}
impl<'a, F, FA, VCS> FRIFolder<'a, F, FA, VCS>
where
F: BinaryField + ExtensionField<FA>,
FA: BinaryField,
VCS: VectorCommitScheme<F> + Sync,
VCS::Committed: Send + Sync,
{
pub fn new(
committed_rs_code: &'a ReedSolomonCode<FA>,
final_rs_code: &'a ReedSolomonCode<F>,
committed_codeword: &'a [F],
committed_codeword_vcs: &'a VCS,
round_vcss: &'a [VCS],
committed: &'a VCS::Committed,
) -> Result<Self, Error> {
if committed_rs_code.len() != committed_codeword.len() {
bail!(Error::InvalidArgs(
"Reed–Solomon code length must match codeword length".to_string(),
));
}
let commitment_fold_rounds = calculate_fold_commit_rounds(
committed_rs_code,
final_rs_code,
committed_codeword_vcs,
round_vcss,
)?;
Ok(Self {
committed_rs_code,
codeword: committed_codeword,
codeword_vcs: committed_codeword_vcs,
round_vcss,
codeword_committed: committed,
round_committed: Vec::with_capacity(round_vcss.len()),
curr_round: 0,
unprocessed_challenges: Vec::with_capacity(committed_rs_code.log_dim()),
commitment_fold_rounds,
final_rs_code,
})
}
pub fn n_rounds(&self) -> usize {
self.committed_rs_code.log_dim()
}
pub fn curr_round(&self) -> usize {
self.curr_round
}
fn prev_codeword(&self) -> &[F] {
self.round_committed
.last()
.map(|(codeword, _)| codeword.as_slice())
.unwrap_or(self.codeword)
}
fn is_commitment_round(&self) -> bool {
let n_commitments = self.round_committed.len();
n_commitments < self.round_vcss.len()
&& self.commitment_fold_rounds[n_commitments] == self.curr_round
}
#[instrument(skip_all, name = "fri::FRIFolder::execute_fold_round")]
pub fn execute_fold_round(
&mut self,
challenge: F,
) -> Result<FoldRoundOutput<VCS::Commitment>, Error> {
self.unprocessed_challenges.push(challenge);
if !self.is_commitment_round() {
self.curr_round += 1;
return Ok(FoldRoundOutput::NoCommitment);
}
let folded_codeword = fold_codeword(
self.committed_rs_code,
self.prev_codeword(),
self.curr_round,
&self.unprocessed_challenges,
);
self.unprocessed_challenges.clear();
let round_vcs = self
.round_vcss
.get(self.round_committed.len())
.ok_or_else(|| Error::TooManyFoldExecutions {
max_folds: self.round_vcss.len() - 1,
})?;
let (commitment, committed) = round_vcs
.commit_batch(&[&folded_codeword])
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
self.round_committed.push((folded_codeword, committed));
self.curr_round += 1;
Ok(FoldRoundOutput::Commitment(commitment))
}
#[instrument(skip_all, name = "fri::FRIFolder::finalize")]
pub fn finalize(mut self) -> Result<(FinalMessage<F>, FRIQueryProver<'a, F, VCS>), Error> {
if self.curr_round != self.n_rounds() {
bail!(Error::EarlyProverFinish);
}
let mut final_codeword = if self.unprocessed_challenges.is_empty() {
self.prev_codeword()[..1 << self.final_rs_code.log_dim()].to_vec()
} else {
let unfolded_codeword_len =
1 << (self.unprocessed_challenges.len() + self.final_rs_code.log_dim());
let unfolded_codeword = &self.prev_codeword()[..unfolded_codeword_len];
fold_codeword(
self.committed_rs_code,
unfolded_codeword,
self.curr_round - 1,
&self.unprocessed_challenges,
)
};
self.final_rs_code
.get_ntt()
.inverse_transform(&mut final_codeword, 0, 0)?;
let final_message = final_codeword;
self.unprocessed_challenges.clear();
let Self {
codeword,
codeword_vcs,
round_vcss,
codeword_committed,
round_committed,
commitment_fold_rounds,
committed_rs_code,
..
} = self;
let query_prover = FRIQueryProver {
codeword,
codeword_vcs,
round_vcss,
codeword_committed,
round_committed,
commitment_fold_rounds,
n_fold_rounds: committed_rs_code.log_dim(),
};
Ok((final_message, query_prover))
}
}
/// Prover state for the query phase, produced by `FRIFolder::finalize`. Holds the initial
/// and folded codewords together with their committed data so that coset openings can be
/// proven for a query index.
pub struct FRIQueryProver<'a, F: BinaryField, VCS: VectorCommitScheme<F>> {
// The initially committed codeword.
codeword: &'a [F],
// Vector commitment scheme for the initial codeword.
codeword_vcs: &'a VCS,
// Vector commitment schemes for the folded codewords.
round_vcss: &'a [VCS],
// Committed data for the initial codeword.
codeword_committed: &'a VCS::Committed,
// Folded codewords, each paired with its committed data.
round_committed: Vec<(Vec<F>, VCS::Committed)>,
// Round indices at which folded codewords were committed.
commitment_fold_rounds: Vec<usize>,
// Total number of fold rounds (`log_dim()` of the committed code).
n_fold_rounds: usize,
}
impl<'a, F, VCS> FRIQueryProver<'a, F, VCS>
where
    F: BinaryField,
    VCS: VectorCommitScheme<F>,
{
    /// Number of openings in a query proof: one for the initial codeword plus one per round
    /// vector commitment scheme.
    pub fn n_rounds(&self) -> usize {
        1 + self.round_vcss.len()
    }

    /// Builds the proof for a single query at `index` into the initial codeword.
    ///
    /// The proof opens one coset of the initial codeword and then one coset of each
    /// committed folded codeword. The coset index is shifted right by the folding arity of
    /// each opening in turn.
    ///
    /// # Errors
    ///
    /// Returns [`Error::VectorCommit`] when any opening proof cannot be produced.
    #[instrument(skip_all, name = "fri::FRIQueryProver::prove_query")]
    pub fn prove_query(&self, index: usize) -> Result<QueryProof<F, VCS::Proof>, Error> {
        let fold_chunk_start_rounds =
            calculate_fold_chunk_start_rounds(&self.commitment_fold_rounds);
        let folding_arities =
            calculate_folding_arities(self.n_fold_rounds, &fold_chunk_start_rounds);

        let mut round_proofs = Vec::with_capacity(self.n_rounds());

        // Opening of the initial codeword.
        let first_arity = folding_arities[0];
        let mut coset_index = index >> first_arity;
        let first_opening = prove_coset_opening(
            self.codeword_vcs,
            self.codeword,
            self.codeword_committed,
            coset_index,
            first_arity,
        )?;
        round_proofs.push(first_opening);

        // Openings of the folded codewords, in the order they were committed.
        let later_rounds = izip!(
            (1..=self.round_vcss.len()),
            self.round_vcss.iter(),
            self.round_committed.iter()
        );
        for (query_rd, round_vcs, (folded, folded_committed)) in later_rounds {
            let arity = folding_arities[query_rd];
            coset_index >>= arity;
            let opening =
                prove_coset_opening(round_vcs, folded, folded_committed, coset_index, arity)?;
            round_proofs.push(opening);
        }

        Ok(round_proofs)
    }
}
fn prove_coset_opening<F: BinaryField, VCS: VectorCommitScheme<F>>(
vcs: &VCS,
codeword: &[F],
committed: &VCS::Committed,
coset_index: usize,
log_coset_size: usize,
) -> Result<QueryRoundProof<F, VCS::Proof>, Error> {
let start_index = coset_index << log_coset_size;
let range = start_index..start_index + (1 << log_coset_size);
let vcs_proof = vcs
.prove_range_batch_opening(committed, range.clone())
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
Ok(QueryRoundProof {
values: codeword[range].to_vec(),
vcs_proof,
})
}