use crate::{
challenger::{CanObserve, CanSample, CanSampleBits},
poly_commit::PolyCommitScheme,
polynomial::Error as PolynomialError,
};
use crate::transcript::{CanRead, CanWrite};
use binius_field::{util::inner_product_unchecked, ExtensionField, Field, PackedField, TowerField};
use binius_hal::ComputationBackend;
use binius_math::{MultilinearExtension, MultilinearQuery};
use binius_utils::bail;
use bytemuck::zeroed_vec;
use std::{marker::PhantomData, ops::Deref};
use tracing::instrument;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("number of polynomials must be less than {max}")]
NumPolys { max: usize },
#[error("expected all polynomials to have {expected} variables")]
NumVars { expected: usize },
#[error("number of variables in the inner PCS, {n_inner} is not what is expected, {n_vars} + {log_num_polys}")]
NumVarsInnerOuter {
n_inner: usize,
n_vars: usize,
log_num_polys: usize,
},
#[error("inner PCS error: {0}")]
InnerPCS(#[source] Box<dyn std::error::Error + Send + Sync>),
#[error("polynomial error: {0}")]
Polynomial(#[from] PolynomialError),
#[error("HAL error: {0}")]
HalError(#[from] binius_hal::Error),
#[error("Math error: {0}")]
MathError(#[from] binius_math::Error),
}
#[instrument(skip_all, level = "trace")]
fn merge_polynomials<P, Data>(
n_vars: usize,
log_n_polys: usize,
polys: &[MultilinearExtension<P, Data>],
) -> Result<MultilinearExtension<P>, Error>
where
P: PackedField,
Data: Deref<Target = [P]> + Send + Sync,
{
if polys.len() > 1 << log_n_polys {
bail!(Error::NumPolys {
max: 1 << log_n_polys
});
}
if polys.iter().any(|poly| poly.n_vars() != n_vars) {
bail!(Error::NumVars { expected: n_vars });
}
let poly_packed_size = 1 << (n_vars - P::LOG_WIDTH);
let mut packed_merged = zeroed_vec(poly_packed_size << log_n_polys);
for (u, poly) in polys.iter().enumerate() {
packed_merged[u * poly_packed_size..(u + 1) * poly_packed_size]
.copy_from_slice(poly.evals())
}
Ok(MultilinearExtension::from_values(packed_merged)?)
}
/// A polynomial commitment scheme for batches of multilinear polynomials, built on top of
/// an inner scheme that commits to a single, larger polynomial.
///
/// The constructor requires the inner scheme to have `n_vars + log_num_polys` variables.
#[derive(Debug)]
pub struct BatchPCS<P, FE, InnerPCS>
where
    P: PackedField,
    FE: ExtensionField<P::Scalar> + TowerField,
    InnerPCS: PolyCommitScheme<P, FE>,
{
    /// The wrapped commitment scheme that all operations are delegated to.
    inner: InnerPCS,
    /// Number of variables of each polynomial in a batch.
    n_vars: usize,
    /// Base-2 logarithm of the maximum number of polynomials in a batch.
    log_num_polys: usize,
    /// Ties the otherwise unused type parameters `P` and `FE` to the struct.
    _marker: PhantomData<(P, FE)>,
}
impl<F, FE, P, Inner> BatchPCS<P, FE, Inner>
where
    F: Field,
    P: PackedField<Scalar = F>,
    FE: ExtensionField<F> + TowerField,
    Inner: PolyCommitScheme<P, FE>,
{
    /// Wraps `inner` into a batch scheme for up to `2^log_num_polys` polynomials of
    /// `n_vars` variables each.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NumVarsInnerOuter`] unless the inner scheme has exactly
    /// `n_vars + log_num_polys` variables.
    pub fn new(inner: Inner, n_vars: usize, log_num_polys: usize) -> Result<Self, Error> {
        let n_inner = inner.n_vars();
        if n_inner != n_vars + log_num_polys {
            bail!(Error::NumVarsInnerOuter {
                n_inner,
                n_vars,
                log_num_polys,
            });
        }

        Ok(Self {
            inner,
            n_vars,
            log_num_polys,
            _marker: PhantomData,
        })
    }
}
/// Batch commitment scheme: every operation merges the batch into a single polynomial and
/// delegates to the inner scheme, extending the evaluation point with sampled challenges.
impl<F, FE, P, Inner> PolyCommitScheme<P, FE> for BatchPCS<P, FE, Inner>
where
    F: Field,
    P: PackedField<Scalar = F>,
    FE: ExtensionField<F> + TowerField,
    Inner: PolyCommitScheme<P, FE>,
{
    type Commitment = Inner::Commitment;
    type Committed = Inner::Committed;
    type Proof = Proof<Inner::Proof>;
    type Error = Error;

    /// Number of variables of each polynomial in a batch (not of the merged polynomial).
    fn n_vars(&self) -> usize {
        self.n_vars
    }

    /// Merges `polys` into one polynomial and commits to it with the inner scheme.
    ///
    /// # Errors
    ///
    /// Fails if the batch cannot be merged (too many polynomials, or a wrong variable
    /// count), or with [`Error::InnerPCS`] if the inner commit fails.
    fn commit<Data>(
        &self,
        polys: &[MultilinearExtension<P, Data>],
    ) -> Result<(Self::Commitment, Self::Committed), Self::Error>
    where
        Data: Deref<Target = [P]> + Send + Sync,
    {
        let merged_poly = merge_polynomials(self.n_vars, self.log_num_polys, polys)?;
        self.inner
            .commit(&[merged_poly])
            .map_err(|err| Error::InnerPCS(Box::new(err)))
    }

    /// Proves the evaluations of all `polys` at `query` with a single inner proof.
    ///
    /// `log_num_polys` mixing challenges are sampled from `transcript` and appended to
    /// `query`; the inner scheme then proves an evaluation of the merged polynomial at
    /// that extended point.
    ///
    /// NOTE(review): this method does not observe the claimed evaluations before sampling
    /// the challenges — presumably the caller has already done so; confirm.
    ///
    /// # Errors
    ///
    /// Fails with `PolynomialError::IncorrectQuerySize` (converted into [`Error`]) if
    /// `query.len() != n_vars`, with a merge error if the batch is malformed, or with
    /// [`Error::InnerPCS`] if the inner prover fails.
    fn prove_evaluation<Data, Transcript, Backend>(
        &self,
        transcript: &mut Transcript,
        committed: &Self::Committed,
        polys: &[MultilinearExtension<P, Data>],
        query: &[FE],
        backend: &Backend,
    ) -> Result<Self::Proof, Self::Error>
    where
        Data: Deref<Target = [P]> + Send + Sync,
        Transcript: CanObserve<FE>
            + CanObserve<Self::Commitment>
            + CanSample<FE>
            + CanSampleBits<usize>
            + CanWrite,
        Backend: ComputationBackend,
    {
        if query.len() != self.n_vars {
            bail!(PolynomialError::IncorrectQuerySize {
                expected: self.n_vars
            });
        }

        let mixing_challenges = transcript.sample_vec(self.log_num_polys);
        // Extended point: the caller's query followed by the mixing challenges.
        let new_query = query
            .iter()
            .copied()
            .chain(mixing_challenges.iter().copied())
            .collect::<Vec<_>>();

        let merged_poly = merge_polynomials(self.n_vars, self.log_num_polys, polys)?;
        let inner_pcs_proof = self
            .inner
            .prove_evaluation(transcript, committed, &[merged_poly], &new_query, backend)
            .map_err(|err| Error::InnerPCS(Box::new(err)))?;
        Ok(Proof(inner_pcs_proof))
    }

    /// Verifies a batch evaluation proof for the claimed `values` at `query`.
    ///
    /// The same number of mixing challenges as in [`Self::prove_evaluation`] is sampled
    /// from `transcript`. The claimed values are combined into one value via an inner
    /// product with the expansion of those challenges, and the inner scheme verifies that
    /// combined value at the extended point.
    ///
    /// NOTE(review): fewer than `2^log_num_polys` values are accepted; presumably the
    /// missing trailing entries are treated as zero by `inner_product_unchecked`, matching
    /// the zero padding applied when merging — confirm.
    ///
    /// # Errors
    ///
    /// Fails with [`Error::NumPolys`] if more than `2^log_num_polys` values are given,
    /// with `PolynomialError::IncorrectQuerySize` (converted into [`Error`]) if
    /// `query.len() != n_vars`, or with [`Error::InnerPCS`] if the inner verifier rejects.
    fn verify_evaluation<Transcript, Backend>(
        &self,
        transcript: &mut Transcript,
        commitment: &Self::Commitment,
        query: &[FE],
        proof: Self::Proof,
        values: &[FE],
        backend: &Backend,
    ) -> Result<(), Self::Error>
    where
        Transcript: CanObserve<FE>
            + CanObserve<Self::Commitment>
            + CanSample<FE>
            + CanSampleBits<usize>
            + CanRead,
        Backend: ComputationBackend,
    {
        if values.len() > 1 << self.log_num_polys {
            bail!(Error::NumPolys {
                max: 1 << self.log_num_polys
            });
        }
        // Mirror the prover-side validation so a malformed query is reported in terms of
        // this scheme's `n_vars` rather than surfacing as an inner-scheme failure.
        if query.len() != self.n_vars {
            bail!(PolynomialError::IncorrectQuerySize {
                expected: self.n_vars
            });
        }

        let mixing_challenges = transcript.sample_vec(self.log_num_polys);
        let mixed_evaluation = inner_product_unchecked(
            MultilinearQuery::expand(&mixing_challenges).into_expansion(),
            values.iter().copied(),
        );
        let mixed_value = &[mixed_evaluation];

        // Extended point: the caller's query followed by the mixing challenges.
        let new_query = query
            .iter()
            .copied()
            .chain(mixing_challenges.iter().copied())
            .collect::<Vec<_>>();

        self.inner
            .verify_evaluation(transcript, commitment, &new_query, proof.0, mixed_value, backend)
            .map_err(|err| Error::InnerPCS(Box::new(err)))?;
        Ok(())
    }

    /// Size of a batch proof: the inner scheme's proof size for a single polynomial,
    /// independent of `_n_polys`.
    fn proof_size(&self, _n_polys: usize) -> usize {
        self.inner.proof_size(1)
    }
}
/// Evaluation proof of [`BatchPCS`]: a thin wrapper around the inner scheme's proof.
#[derive(Debug, Clone)]
pub struct Proof<Inner>(Inner);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fiat_shamir::HasherChallenger,
        merkle_tree_vcs::BinaryMerkleTreeProver,
        poly_commit::FRIPCS,
        transcript::{AdviceWriter, TranscriptWriter},
    };
    use binius_field::{
        arch::OptimalUnderlier128b, as_packed_field::PackedType, BinaryField128b, BinaryField32b,
        BinaryField8b,
    };
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use binius_hash::{GroestlDigestCompression, GroestlHasher};
    use binius_math::IsomorphicEvaluationDomainFactory;
    use binius_ntt::NTTOptions;
    use groestl_crypto::Groestl256;
    use p3_util::log2_ceil_usize;
    use rand::{prelude::StdRng, SeedableRng};
    use std::iter::repeat_with;

    /// Commit / prove / verify round trip where the polynomials and the evaluation point
    /// both live in `BinaryField128b`.
    #[test]
    fn test_commit_prove_verify_success_128b() {
        type U = OptimalUnderlier128b;
        type F = BinaryField128b;

        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 7;
        let n_polys = 6;
        let m = log2_ceil_usize(n_polys);
        let total_new_vars = n_vars + m;

        // The polynomials are drawn before the evaluation point, so the seeded RNG stream
        // is consumed in a fixed order.
        let packed_len = 1 << n_vars;
        let multilins = (0..n_polys)
            .map(|_| {
                let packed_evals = repeat_with(|| <PackedType<U, F>>::random(&mut rng))
                    .take(packed_len)
                    .collect::<Vec<_>>();
                MultilinearExtension::from_values(packed_evals).unwrap()
            })
            .collect::<Vec<_>>();
        let eval_point = (0..n_vars)
            .map(|_| <F as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        // Claimed evaluations, computed directly from the polynomials.
        let query_backend = make_portable_backend();
        let eval_query = query_backend
            .multilinear_query::<F>(&eval_point)
            .unwrap();
        let values = multilins
            .iter()
            .map(|multilin| multilin.evaluate(&eval_query).unwrap())
            .collect::<Vec<_>>();

        // The inner scheme is sized for the merged polynomial.
        let inner_pcs = FRIPCS::<F, BinaryField8b, F, PackedType<U, F>, _, _, _>::new(
            total_new_vars,
            2,
            vec![3, 3, 3],
            32,
            BinaryMerkleTreeProver::<_, GroestlHasher<_>, _>::new(
                GroestlDigestCompression::default(),
            ),
            IsomorphicEvaluationDomainFactory::<BinaryField8b>::default(),
            NTTOptions::default(),
        )
        .unwrap();

        let backend = make_portable_backend();
        let pcs = BatchPCS::new(inner_pcs, n_vars, m).unwrap();

        let polys = multilins
            .iter()
            .map(|multilin| multilin.to_ref())
            .collect::<Vec<_>>();
        let (commitment, committed) = pcs.commit(&polys).unwrap();

        // Prover side.
        let mut prover_proof = crate::transcript::Proof {
            transcript: TranscriptWriter::<HasherChallenger<Groestl256>>::default(),
            advice: AdviceWriter::new(),
        };
        prover_proof.transcript.observe(commitment);
        let proof = pcs
            .prove_evaluation(
                &mut prover_proof.transcript,
                &committed,
                &polys,
                &eval_point,
                &backend,
            )
            .unwrap();

        // Verifier side replays the same transcript.
        let mut verifier_proof = prover_proof.into_verifier();
        verifier_proof.transcript.observe(commitment);
        pcs.verify_evaluation(
            &mut verifier_proof.transcript,
            &commitment,
            &eval_point,
            proof,
            &values,
            &backend,
        )
        .unwrap();
    }

    /// Commit / prove / verify round trip with polynomials over `BinaryField32b` and an
    /// evaluation point in the extension field `BinaryField128b`.
    #[test]
    fn test_commit_prove_verify_success_32b() {
        type U = OptimalUnderlier128b;
        type F = BinaryField32b;
        type FE = BinaryField128b;
        type Packed = PackedType<U, F>;

        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 6;
        let n_polys = 6;
        let m = log2_ceil_usize(n_polys);
        let total_new_vars = n_vars + m;

        // The polynomials are drawn before the evaluation point, so the seeded RNG stream
        // is consumed in a fixed order.
        let packed_len = 1 << (n_vars - Packed::LOG_WIDTH);
        let multilins = (0..n_polys)
            .map(|_| {
                let packed_evals = repeat_with(|| <Packed>::random(&mut rng))
                    .take(packed_len)
                    .collect::<Vec<_>>();
                MultilinearExtension::from_values(packed_evals).unwrap()
            })
            .collect::<Vec<_>>();
        let eval_point = (0..n_vars)
            .map(|_| <FE as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        // Claimed evaluations, computed directly from the polynomials.
        let query_backend = make_portable_backend();
        let eval_query = query_backend
            .multilinear_query::<FE>(&eval_point)
            .unwrap();
        let values = multilins
            .iter()
            .map(|multilin| multilin.evaluate(&eval_query).unwrap())
            .collect::<Vec<_>>();

        // The inner scheme is sized for the merged polynomial.
        let inner_pcs = FRIPCS::<F, BinaryField8b, F, PackedType<U, FE>, _, _, _>::new(
            total_new_vars,
            2,
            vec![2, 2, 2],
            32,
            BinaryMerkleTreeProver::<_, GroestlHasher<_>, _>::new(
                GroestlDigestCompression::default(),
            ),
            IsomorphicEvaluationDomainFactory::<BinaryField8b>::default(),
            NTTOptions::default(),
        )
        .unwrap();

        let backend = make_portable_backend();
        let pcs = BatchPCS::new(inner_pcs, n_vars, m).unwrap();

        let polys = multilins
            .iter()
            .map(|multilin| multilin.to_ref())
            .collect::<Vec<_>>();
        let (commitment, committed) = pcs.commit(&polys).unwrap();

        // Prover side.
        let mut prover_proof = crate::transcript::Proof {
            transcript: TranscriptWriter::<HasherChallenger<Groestl256>>::default(),
            advice: AdviceWriter::default(),
        };
        prover_proof.transcript.observe(commitment);
        let proof = pcs
            .prove_evaluation(
                &mut prover_proof.transcript,
                &committed,
                &polys,
                &eval_point,
                &backend,
            )
            .unwrap();

        // Verifier side replays the same transcript.
        let mut verifier_proof = prover_proof.into_verifier();
        verifier_proof.transcript.observe(commitment);
        pcs.verify_evaluation(
            &mut verifier_proof.transcript,
            &commitment,
            &eval_point,
            proof,
            &values,
            &backend,
        )
        .unwrap();
    }
}