use super::ring_switch::reduce_tensor_claim;
use crate::{
challenger::{CanObserve, CanSample, CanSampleBits},
composition::BivariateProduct,
merkle_tree_vcs::{MerkleTreeProver, MerkleTreeScheme},
poly_commit::PolyCommitScheme,
polynomial::{Error as PolynomialError, MultivariatePoly},
protocols::{
fri::{self, FRIFolder, FRIParams, FRIVerifier, FoldRoundOutput},
sumcheck::{
self, immediate_switchover_heuristic,
prove::{RegularSumcheckProver, SumcheckProver},
verify::interpolate_round_proof,
RoundCoeffs, RoundProof, SumcheckClaim,
},
},
reed_solomon::reed_solomon::ReedSolomonCode,
tensor_algebra::TensorAlgebra,
transcript::{AdviceReader, AdviceWriter, CanRead, CanWrite},
transparent::ring_switch::RingSwitchEqInd,
};
use binius_field::{
packed::iter_packed_slice, BinaryField, ExtensionField, Field, PackedExtension, PackedField,
PackedFieldIndexable, TowerField,
};
use binius_hal::{ComputationBackend, ComputationBackendExt};
use binius_math::{EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension};
use binius_ntt::NTTOptions;
use binius_utils::{bail, checked_arithmetics::checked_log_2};
use std::{fmt::Debug, iter, marker::PhantomData, mem, ops::Deref};
use tracing::instrument;
/// A polynomial commitment scheme combining ring-switching with a FRI low-degree test.
///
/// Multilinears over the small field `F` are packed into the extension field `PE::Scalar`,
/// Reed–Solomon encoded over `FEncode`, and committed with a Merkle tree. Evaluation proofs
/// interleave a sumcheck (for the ring-switching reduction) with FRI folding rounds.
#[derive(Debug)]
pub struct FRIPCS<F, FDomain, FEncode, PE, DomainFactory, MerkleProver, VCS>
where
    F: Field,
    FDomain: Field,
    FEncode: BinaryField,
    PE: PackedField + PackedExtension<FEncode>,
    PE::Scalar: BinaryField
        + ExtensionField<F>
        + PackedField<Scalar = PE::Scalar>
        + ExtensionField<FDomain>
        + ExtensionField<FEncode>,
{
    /// FRI protocol parameters: the Reed–Solomon code, fold arities and test-query count.
    fri_params: FRIParams<PE::Scalar, FEncode>,
    /// Merkle tree prover used to commit codewords and open queried positions.
    merkle_prover: MerkleProver,
    /// Encoder actually used to encode messages; built with the caller-supplied NTT options.
    rs_encoder: ReedSolomonCode<<PE as PackedExtension<FEncode>>::PackedSubfield>,
    /// Factory producing evaluation domains over `FDomain` for the sumcheck prover.
    domain_factory: DomainFactory,
    _marker: PhantomData<(F, FDomain, PE, VCS)>,
}
impl<F, FDomain, FEncode, FExt, PE, DomainFactory, MerkleProver, DigestType, VCS>
FRIPCS<F, FDomain, FEncode, PE, DomainFactory, MerkleProver, VCS>
where
F: Field,
FDomain: Field,
FEncode: BinaryField,
FExt: TowerField
+ PackedField<Scalar = FExt>
+ ExtensionField<F>
+ ExtensionField<FDomain>
+ ExtensionField<FEncode>,
PE: PackedFieldIndexable<Scalar = FExt> + PackedExtension<FEncode>,
MerkleProver: MerkleTreeProver<FExt, Scheme = VCS> + Sync,
DigestType: PackedField<Scalar: TowerField>,
VCS: MerkleTreeScheme<FExt, Digest = DigestType, Proof = Vec<DigestType>>,
{
pub fn new(
n_vars: usize,
log_inv_rate: usize,
fold_arities: Vec<usize>,
security_bits: usize,
merkle_prover: MerkleProver,
domain_factory: DomainFactory,
ntt_options: NTTOptions,
) -> Result<Self, Error> {
let kappa = checked_log_2(<FExt as ExtensionField<F>>::DEGREE);
let n_packed_vars = n_vars
.checked_sub(kappa)
.ok_or(Error::IncorrectPolynomialSize { expected: kappa })?;
if !fold_arities.is_empty() {
if fold_arities.iter().sum::<usize>() >= n_packed_vars {
bail!(fri::Error::InvalidFoldAritySequence);
}
for &arity in fold_arities.iter() {
if arity == 0 {
bail!(fri::Error::FoldArityIsZero { index: 0 });
}
}
}
let log_batch_size = fold_arities.first().copied().unwrap_or(0);
let log_dim = n_packed_vars - log_batch_size;
let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, NTTOptions::default())?;
let n_test_queries = fri::calculate_n_test_queries::<FExt, _>(security_bits, &rs_code)?;
let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?;
let rs_encoder = ReedSolomonCode::new(log_dim, log_inv_rate, ntt_options)?;
Ok(Self {
fri_params,
merkle_prover,
rs_encoder,
domain_factory,
_marker: PhantomData,
})
}
pub fn with_optimal_arity(
n_vars: usize,
log_inv_rate: usize,
security_bits: usize,
merkle_prover: MerkleProver,
domain_factory: DomainFactory,
ntt_options: NTTOptions,
) -> Result<Self, Error> {
let kappa = checked_log_2(<FExt as ExtensionField<F>>::DEGREE);
let n_packed_vars = n_vars
.checked_sub(kappa)
.ok_or(Error::IncorrectPolynomialSize { expected: kappa })?;
let arity = estimate_optimal_arity(
n_packed_vars + log_inv_rate,
size_of::<VCS::Digest>(),
size_of::<FExt>(),
);
assert!(arity > 0);
let fold_arities = iter::repeat(arity)
.take(n_packed_vars.saturating_sub(1) / arity)
.collect::<Vec<_>>();
Self::new(
n_vars,
log_inv_rate,
fold_arities,
security_bits,
merkle_prover,
domain_factory,
ntt_options,
)
}
pub const fn kappa() -> usize {
<TensorAlgebra<F, PE::Scalar>>::kappa()
}
fn prove_interleaved_fri_sumcheck<Prover, Transcript>(
&self,
codeword: &[PE],
committed: &MerkleProver::Committed,
mut sumcheck_prover: Prover,
advice: &mut AdviceWriter,
mut transcript: Transcript,
) -> Result<(), Error>
where
Prover: SumcheckProver<FExt>,
Transcript: CanObserve<FExt>
+ CanObserve<VCS::Digest>
+ CanSample<FExt>
+ CanSampleBits<usize>
+ CanWrite,
{
let n_rounds = sumcheck_prover.n_vars();
let mut fri_prover = FRIFolder::new(
&self.fri_params,
&self.merkle_prover,
PE::unpack_scalars(codeword),
committed,
)?;
let mut rounds = Vec::with_capacity(n_rounds);
let mut fri_commitments = Vec::with_capacity(self.fri_params.n_oracles());
for _ in 0..n_rounds {
let round_coeffs = sumcheck_prover.execute(FExt::ONE)?;
let round_proof = round_coeffs.truncate();
transcript.write_scalar_slice(round_proof.coeffs());
rounds.push(round_proof);
let challenge = transcript.sample();
match fri_prover.execute_fold_round(challenge)? {
FoldRoundOutput::NoCommitment => {}
FoldRoundOutput::Commitment(round_commitment) => {
transcript.write_packed(round_commitment);
fri_commitments.push(round_commitment);
}
}
sumcheck_prover.fold(challenge)?;
}
let _ = sumcheck_prover.finish()?;
fri_prover.finish_proof(advice, transcript)?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn verify_interleaved_fri_sumcheck<Transcript>(
&self,
claim: &SumcheckClaim<FExt, BivariateProduct>,
codeword_commitment: &VCS::Digest,
ring_switch_evaluator: impl FnOnce(&[FExt]) -> Result<FExt, PolynomialError>,
advice: &mut AdviceReader,
mut transcript: Transcript,
) -> Result<(), Error>
where
Transcript: CanObserve<FExt>
+ CanObserve<VCS::Digest>
+ CanSample<FExt>
+ CanSampleBits<usize>
+ CanRead,
{
let n_rounds = claim.n_vars();
let mut arities_iter = self.fri_params.fold_arities().iter();
let mut fri_commitments = Vec::with_capacity(self.fri_params.n_oracles());
let mut next_commit_round = arities_iter.next().copied();
assert_eq!(claim.composite_sums().len(), 1);
let mut sum = claim.composite_sums()[0].sum;
let mut challenges = Vec::with_capacity(n_rounds);
for round_no in 0..n_rounds {
let round_proof = transcript
.read_scalar_slice::<FExt>(claim.max_individual_degree())
.map_err(Error::TranscriptError)?;
let round_proof = RoundProof(RoundCoeffs(round_proof));
let challenge = transcript.sample();
challenges.push(challenge);
let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1);
if observe_fri_comm {
let comm = transcript
.read_packed::<VCS::Digest>()
.map_err(Error::TranscriptError)?;
fri_commitments.push(comm);
next_commit_round = arities_iter.next().map(|arity| round_no + 1 + arity);
}
sum = interpolate_round_proof(round_proof, sum, challenge);
}
let verifier = FRIVerifier::new(
&self.fri_params,
self.merkle_prover.scheme(),
codeword_commitment,
&fri_commitments,
&challenges,
)?;
let ring_switch_eval = ring_switch_evaluator(&challenges)?;
let final_fri_value = verifier.verify(advice, transcript)?;
if final_fri_value * ring_switch_eval != sum {
return Err(VerificationError::IncorrectSumcheckEvaluation.into());
}
Ok(())
}
}
impl<F, FDomain, FEncode, FExt, P, PE, DomainFactory, MerkleProver, DigestType, VCS>
    PolyCommitScheme<P, FExt> for FRIPCS<F, FDomain, FEncode, PE, DomainFactory, MerkleProver, VCS>
where
    F: TowerField,
    FDomain: Field,
    FEncode: BinaryField,
    FExt: TowerField
        + PackedField<Scalar = FExt>
        + ExtensionField<F>
        + ExtensionField<FDomain>
        + ExtensionField<FEncode>
        + PackedExtension<F>
        + PackedExtension<FEncode>,
    P: PackedField<Scalar = F>,
    PE: PackedFieldIndexable<Scalar = FExt>
        + PackedExtension<F, PackedSubfield = P>
        + PackedExtension<FDomain>
        + PackedExtension<FEncode>,
    DomainFactory: EvaluationDomainFactory<FDomain>,
    MerkleProver: MerkleTreeProver<FExt, Scheme = VCS> + Sync,
    DigestType: PackedField<Scalar: TowerField>,
    VCS: MerkleTreeScheme<FExt, Digest = DigestType, Proof = Vec<DigestType>>,
{
    type Commitment = VCS::Digest;
    /// The encoded codeword together with the Merkle tree data committing to it.
    type Committed = (Vec<PE>, MerkleProver::Committed);
    type Error = Error;

    /// Number of variables of the committed multilinear over `F` (packed vars plus κ).
    fn n_vars(&self) -> usize {
        self.fri_params.n_fold_rounds() + Self::kappa()
    }

    /// Commits a single multilinear by packing its `F`-evaluations into `FExt`, encoding
    /// with the interleaved Reed–Solomon code, and Merkle-committing the codeword.
    #[instrument("FRIPCS::commit", skip_all, level = "debug")]
    fn commit<Data>(
        &self,
        polys: &[MultilinearExtension<P, Data>],
    ) -> Result<(Self::Commitment, Self::Committed), Self::Error>
    where
        Data: Deref<Target = [P]> + Send + Sync,
    {
        if polys.len() != 1 {
            todo!("handle batches of size greater than 1");
        }
        let poly = &polys[0];
        if poly.n_vars() != self.n_vars() {
            return Err(Error::IncorrectPolynomialSize {
                expected: self.n_vars(),
            });
        }
        // Reinterpret the F-evaluations as packed FExt elements (ring switching).
        let packed_evals = <PE as PackedExtension<F>>::cast_exts(poly.evals());
        let fri::CommitOutput {
            commitment,
            committed,
            codeword,
        } = fri::commit_interleaved(
            &self.rs_encoder,
            &self.fri_params,
            &self.merkle_prover,
            packed_evals,
        )?;
        Ok((commitment, (codeword, committed)))
    }

    /// Proves an evaluation claim at `query` by (1) sending the tensor-algebra partial
    /// evaluation, (2) reducing to a bivariate-product sumcheck via ring switching, and
    /// (3) running the sumcheck interleaved with FRI on the committed codeword.
    #[allow(clippy::needless_borrows_for_generic_args)]
    #[instrument(skip_all, level = "debug")]
    fn prove_evaluation<Data, Transcript, Backend>(
        &self,
        advice: &mut AdviceWriter,
        transcript: &mut Transcript,
        committed: &Self::Committed,
        polys: &[MultilinearExtension<P, Data>],
        query: &[PE::Scalar],
        backend: &Backend,
    ) -> Result<(), Self::Error>
    where
        Data: Deref<Target = [P]> + Send + Sync,
        Transcript: CanObserve<PE::Scalar>
            + CanObserve<Self::Commitment>
            + CanSample<PE::Scalar>
            + CanSampleBits<usize>
            + CanWrite,
        Backend: ComputationBackend,
    {
        if query.len() != self.n_vars() {
            return Err(PolynomialError::IncorrectQuerySize {
                expected: self.n_vars(),
            }
            .into());
        }
        if polys.len() != 1 {
            todo!("handle batches of size greater than 1");
        }
        let poly = &polys[0];
        let packed_poly = MultilinearExtension::from_values_slice(
            <PE as PackedExtension<F>>::cast_exts(poly.evals()),
        )?;
        // Split off the low κ coordinates; the sumcheck runs over the remaining ones.
        let (_, query_from_kappa) = query.split_at(Self::kappa());
        let expanded_query = backend.multilinear_query::<PE>(query_from_kappa)?;
        // Partially evaluate the high variables; the 2^κ remaining evaluations form an
        // element of the tensor algebra, which is sent to the verifier.
        let partial_eval = poly.evaluate_partial_high(&expanded_query)?;
        let sumcheck_eval =
            TensorAlgebra::<F, _>::new(iter_packed_slice(partial_eval.evals()).collect());
        transcript.write_scalar_slice(sumcheck_eval.vertical_elems());
        // Mixing challenges collapse the tensor-algebra claim to a single FExt sumcheck.
        let tensor_mixing_challenges = transcript.sample_vec(Self::kappa());
        let sumcheck_claim =
            reduce_tensor_claim(self.n_vars(), sumcheck_eval, &tensor_mixing_challenges, backend);
        let rs_eq = RingSwitchEqInd::<F, _>::new(
            query_from_kappa.to_vec(),
            tensor_mixing_challenges.to_vec(),
        )?;
        let transparent = rs_eq.multilinear_extension::<PE, _>(backend)?;
        // Bivariate-product sumcheck over (packed polynomial, ring-switching eq indicator).
        let sumcheck_prover = RegularSumcheckProver::new(
            [packed_poly.to_ref(), transparent.to_ref()]
                .map(MLEDirectAdapter::from)
                .into(),
            sumcheck_claim.composite_sums().iter().cloned(),
            &self.domain_factory,
            immediate_switchover_heuristic,
            backend,
        )?;
        let (codeword, vcs_committed) = committed;
        self.prove_interleaved_fri_sumcheck(
            codeword,
            vcs_committed,
            sumcheck_prover,
            advice,
            transcript,
        )
    }

    /// Verifies an evaluation proof: checks the claimed evaluation against the transmitted
    /// tensor-algebra element, reduces to a sumcheck claim, and verifies the interleaved
    /// sumcheck/FRI transcript.
    fn verify_evaluation<Transcript, Backend>(
        &self,
        advice: &mut AdviceReader,
        transcript: &mut Transcript,
        commitment: &Self::Commitment,
        query: &[FExt],
        values: &[FExt],
        backend: &Backend,
    ) -> Result<(), Self::Error>
    where
        Transcript: CanObserve<FExt>
            + CanObserve<Self::Commitment>
            + CanSample<FExt>
            + CanSampleBits<usize>
            + CanRead,
        Backend: ComputationBackend,
    {
        if query.len() != self.n_vars() {
            return Err(PolynomialError::IncorrectQuerySize {
                expected: self.n_vars(),
            }
            .into());
        }
        if values.len() != 1 {
            todo!("handle batches of size greater than 1");
        }
        // Read the 2^κ vertical elements of the prover's tensor-algebra evaluation.
        let sumcheck_eval = transcript
            .read_scalar_slice::<FExt>(1 << Self::kappa())
            .map_err(Error::TranscriptError)?;
        let n_rounds = self.n_vars() - Self::kappa();
        assert!(n_rounds > 0, "this is checked in the constructor");
        let (query_to_kappa, query_from_kappa) = query.split_at(Self::kappa());
        let sumcheck_eval = <TensorAlgebra<F, FExt>>::new(sumcheck_eval);
        // The claimed value must equal the tensor element folded with the low coordinates.
        let expanded_query = backend.multilinear_query::<FExt>(query_to_kappa)?;
        let computed_eval =
            MultilinearExtension::from_values_slice(sumcheck_eval.vertical_elems())?
                .evaluate(&expanded_query)?;
        if values[0] != computed_eval {
            return Err(VerificationError::IncorrectEvaluation.into());
        }
        let tensor_mixing_challenges = transcript.sample_vec(Self::kappa());
        let sumcheck_claim =
            reduce_tensor_claim(self.n_vars(), sumcheck_eval, &tensor_mixing_challenges, backend);
        self.verify_interleaved_fri_sumcheck(
            &sumcheck_claim,
            commitment,
            |challenges| {
                let rs_eq = RingSwitchEqInd::<F, _>::new(
                    query_from_kappa.to_vec(),
                    tensor_mixing_challenges.to_vec(),
                )?;
                rs_eq.evaluate(challenges)
            },
            advice,
            transcript,
        )
    }

    /// Estimates the serialized proof size in bytes for a batch of `n_polys` polynomials.
    fn proof_size(&self, n_polys: usize) -> usize {
        if n_polys != 1 {
            todo!("handle batches of size greater than 1");
        }
        let fe_size = mem::size_of::<FExt>();
        let vc_size = mem::size_of::<VCS::Digest>();
        let fri_termination_log_len =
            self.fri_params.n_final_challenges() + self.fri_params.rs_code().log_inv_rate();
        // 2^κ tensor-algebra vertical elements.
        let sumcheck_eval_size = <TensorAlgebra<F, FExt>>::byte_size();
        // Two field elements (truncated round coefficients) per sumcheck round.
        let sumcheck_rounds_size = fe_size * 2 * (self.n_vars() - Self::kappa());
        // One digest per intermediate FRI oracle commitment. (Previously this multiplied
        // `vc_size` by `sumcheck_rounds_size - 1`, i.e. by a byte count rather than a
        // commitment count, which grossly overestimated the size.)
        let fri_commitments_size = vc_size * self.fri_params.n_oracles();
        // The terminal codeword is sent in full.
        let fri_terminate_codeword_size = fe_size * (1 << fri_termination_log_len);
        let len_round_vcss = self.fri_params.rs_code().log_len() - fri_termination_log_len;
        // Each test query opens one leaf (two field elements) plus a Merkle branch digest
        // per round oracle.
        let fri_query_proofs_size =
            (vc_size + 2 * fe_size) * (len_round_vcss + 1) * self.fri_params.n_test_queries();
        sumcheck_eval_size
            + sumcheck_rounds_size
            + fri_commitments_size
            + fri_terminate_codeword_size
            + fri_query_proofs_size
    }
}
/// Estimates the FRI fold arity that minimizes a proof-size cost heuristic.
///
/// For `L = log_block_length`, the cost of a candidate arity `a` is estimated as
/// `(L/2 · digest_size + 2^a · field_size) · (L − a) / a`. Candidates are scanned in
/// increasing order and the search stops at the first arity whose cost exceeds the
/// previous one, returning the arity just before the increase. Returns `1` when
/// `log_block_length` is zero (no candidates).
pub fn estimate_optimal_arity(
    log_block_length: usize,
    digest_size: usize,
    field_size: usize,
) -> usize {
    let mut best: Option<(usize, usize)> = None;
    for arity in 1..=log_block_length {
        // Proof-size estimate for this candidate arity.
        let cost = (log_block_length / 2 * digest_size + (1 << arity) * field_size)
            * (log_block_length - arity)
            / arity;
        // First local increase in cost ends the search; `best` keeps the previous arity.
        if matches!(best, Some((_, prev_cost)) if cost > prev_cost) {
            break;
        }
        best = Some((arity, cost));
    }
    best.map(|(arity, _)| arity).unwrap_or(1)
}
/// Errors that can arise during FRI-PCS construction, commitment, proving or verification.
///
/// Wraps the errors of each sub-protocol (sumcheck, FRI, NTT) and infrastructure layer
/// (backend, math, transcript) via `#[from]` conversions so `?` propagates them directly.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("the polynomial must have {expected} variables")]
    IncorrectPolynomialSize { expected: usize },
    #[error("sumcheck error: {0}")]
    Sumcheck(#[from] sumcheck::Error),
    #[error("polynomial error: {0}")]
    Polynomial(#[from] PolynomialError),
    #[error("FRI error: {0}")]
    FRI(#[from] fri::Error),
    #[error("NTT error: {0}")]
    NTT(#[from] binius_ntt::Error),
    #[error("verification failure: {0}")]
    Verification(#[from] VerificationError),
    #[error("HAL error: {0}")]
    HalError(#[from] binius_hal::Error),
    #[error("Math error: {0}")]
    MathError(#[from] binius_math::Error),
    #[error("Transcript error: {0}")]
    TranscriptError(#[from] crate::transcript::Error),
}
/// Soundness failures detected while verifying an evaluation proof.
///
/// Each variant corresponds to one consistency check in `verify_evaluation` /
/// `verify_interleaved_fri_sumcheck` failing.
#[derive(Debug, thiserror::Error)]
pub enum VerificationError {
    #[error("sumcheck verification error: {0}")]
    Sumcheck(#[from] sumcheck::VerificationError),
    #[error(
        "tensor algebra evaluation shape is incorrect; \
        expected {expected} field elements, got {actual}"
    )]
    IncorrectEvaluationShape { expected: usize, actual: usize },
    #[error("evaluation value is inconsistent with the tensor evaluation")]
    IncorrectEvaluation,
    #[error("sumcheck final evaluation is incorrect")]
    IncorrectSumcheckEvaluation,
    #[error("incorrect number of FRI commitments")]
    IncorrectNumberOfFRICommitments,
    #[error("incorrect number of FRI query proofs")]
    IncorrectNumberOfFRIQueries,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fiat_shamir::HasherChallenger,
        merkle_tree_vcs::BinaryMerkleTreeProver,
        transcript::{AdviceWriter, TranscriptWriter},
    };
    use binius_field::{
        arch::packed_polyval_128::PackedBinaryPolyval1x128b,
        as_packed_field::{PackScalar, PackedType},
        underlier::{Divisible, UnderlierType, WithUnderlier},
        BinaryField128b, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b,
    };
    use binius_hal::make_portable_backend;
    use binius_hash::{GroestlDigestCompression, GroestlHasher};
    use binius_math::IsomorphicEvaluationDomainFactory;
    use groestl_crypto::Groestl256;
    use iter::repeat_with;
    use rand::{prelude::StdRng, SeedableRng};

    // End-to-end round trip: commit a random multilinear over F, prove its evaluation at a
    // random FE point, then re-verify the proof from the serialized transcript and advice.
    // Type parameters: U = underlier, F = small (committed) field, FA = encoding field,
    // FE = extension field of the evaluation point.
    fn test_commit_prove_verify_success<U, F, FA, FE>(
        n_vars: usize,
        log_inv_rate: usize,
        fold_arities: &[usize],
    ) where
        U: UnderlierType
            + PackScalar<F>
            + PackScalar<FA>
            + PackScalar<FE>
            + PackScalar<BinaryField8b>
            + Divisible<u8>,
        F: TowerField,
        FA: BinaryField,
        FE: TowerField
            + ExtensionField<F>
            + ExtensionField<FA>
            + ExtensionField<BinaryField8b>
            + PackedField<Scalar = FE>
            + PackedExtension<F>
            + PackedExtension<FA, PackedSubfield: PackedFieldIndexable>
            + PackedExtension<BinaryField8b, PackedSubfield: PackedFieldIndexable>,
        PackedType<U, FA>: PackedFieldIndexable,
        PackedType<U, FE>: PackedFieldIndexable,
    {
        // Fixed seed so the test is deterministic.
        let mut rng = StdRng::seed_from_u64(0);
        let backend = make_portable_backend();
        let multilin = MultilinearExtension::from_values(
            repeat_with(|| <PackedType<U, F>>::random(&mut rng))
                .take(1 << (n_vars - <PackedType<U, F>>::LOG_WIDTH))
                .collect(),
        )
        .unwrap();
        assert_eq!(multilin.n_vars(), n_vars);
        // Reference evaluation computed directly, used as the claimed value to verify.
        let eval_point = repeat_with(|| <FE as Field>::random(&mut rng))
            .take(n_vars)
            .collect::<Vec<_>>();
        let eval_query = backend.multilinear_query::<FE>(&eval_point).unwrap();
        let eval = multilin.evaluate(&eval_query).unwrap();
        let merkle_prover = BinaryMerkleTreeProver::<_, GroestlHasher<_>, _>::new(
            GroestlDigestCompression::default(),
        );
        let domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField8b>::default();
        // 32 bits of security for test speed.
        let pcs = FRIPCS::<F, BinaryField8b, FA, PackedType<U, FE>, _, _, _>::new(
            n_vars,
            log_inv_rate,
            fold_arities.to_vec(),
            32,
            merkle_prover,
            domain_factory,
            NTTOptions::default(),
        )
        .unwrap();
        let (commitment, committed) = pcs.commit(&[multilin.to_ref()]).unwrap();
        let mut prover_proof = crate::transcript::Proof {
            transcript: TranscriptWriter::<HasherChallenger<Groestl256>>::default(),
            advice: AdviceWriter::default(),
        };
        // The commitment must be bound into the Fiat-Shamir state before proving.
        prover_proof.transcript.observe(commitment);
        pcs.prove_evaluation(
            &mut prover_proof.advice,
            &mut prover_proof.transcript,
            &committed,
            &[multilin],
            &eval_point,
            &backend,
        )
        .unwrap();
        // Replay the serialized proof on the verifier side with the same observations.
        let mut verifier_proof = prover_proof.into_verifier();
        verifier_proof.transcript.observe(commitment);
        pcs.verify_evaluation(
            &mut verifier_proof.advice,
            &mut verifier_proof.transcript,
            &commitment,
            &eval_point,
            &[eval],
            &backend,
        )
        .unwrap();
        // Ensure the verifier consumed the entire transcript/advice.
        verifier_proof.finalize().unwrap()
    }

    #[test]
    fn test_commit_prove_verify_success_1b_128b() {
        test_commit_prove_verify_success::<
            <PackedBinaryPolyval1x128b as WithUnderlier>::Underlier,
            BinaryField1b,
            BinaryField16b,
            BinaryField128b,
        >(18, 2, &[3, 3, 3]);
    }

    #[test]
    fn test_commit_prove_verify_success_32b_128b() {
        test_commit_prove_verify_success::<
            <PackedBinaryPolyval1x128b as WithUnderlier>::Underlier,
            BinaryField32b,
            BinaryField16b,
            BinaryField128b,
        >(12, 2, &[3, 3, 3]);
    }

    #[test]
    fn test_estimate_optimal_arity() {
        let field_size = 128;
        // Over these block lengths the cost heuristic is minimized at arity 4 / 6.
        for log_block_length in 22..35 {
            let digest_size = 256;
            assert_eq!(estimate_optimal_arity(log_block_length, digest_size, field_size), 4);
        }
        for log_block_length in 22..28 {
            let digest_size = 1024;
            assert_eq!(estimate_optimal_arity(log_block_length, digest_size, field_size), 6);
        }
    }
}