use super::{
error::Error,
evalcheck::{
BatchCommittedEvalClaims, CommittedEvalClaim, EvalcheckMultilinearClaim, EvalcheckProof,
},
subclaims::MemoizedQueries,
};
use crate::{
oracle::{
ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
MultilinearPolyOracle, ProjectionVariant,
},
polynomial::MultilinearPoly,
protocols::evalcheck_v2::subclaims::{process_packed_sumcheck, process_shifted_sumcheck},
witness::MultilinearExtensionIndex,
};
use binius_field::{
as_packed_field::{PackScalar, PackedType},
underlier::UnderlierType,
PackedFieldIndexable, TowerField,
};
use binius_hal::ComputationBackend;
use getset::{Getters, MutGetters};
use tracing::instrument;
/// State carried by the evalcheck prover while it reduces
/// [`EvalcheckMultilinearClaim`]s into [`EvalcheckProof`]s.
///
/// Besides borrowing the oracle set and the witness index, it accumulates two kinds of
/// by-products while proving: evaluation claims on committed polynomials and constraint-set
/// builders for follow-up sumchecks.
#[derive(Getters, MutGetters)]
pub struct EvalcheckProver<'a, 'b, U, F, Backend>
where
U: UnderlierType + PackScalar<F>,
F: TowerField,
Backend: ComputationBackend,
{
/// Oracle set; handed to the shifted/packed sumcheck helpers and used when building
/// the accumulated constraint sets.
pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
/// Witness index; multilinear witnesses are looked up in it by oracle id and it is also
/// handed to the shifted/packed sumcheck helpers.
pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
/// Evaluation claims on committed polynomials recorded while proving. Initialised from
/// the oracle set's committed batches; exposed through generated `pub` getters.
#[getset(get = "pub", get_mut = "pub")]
pub(crate) batch_committed_eval_claims: BatchCommittedEvalClaims<F>,
/// Constraint-set builders passed mutably to the shifted/packed sumcheck helpers and
/// later consumed by `take_new_sumchecks_constraints`.
new_sumchecks_constraints: Vec<ConstraintSetBuilder<PackedType<U, F>>>,
/// Store from which the query for an evaluation point is obtained
/// (presumably cached per point, going by the type name — confirm in `subclaims`).
memoized_queries: MemoizedQueries<PackedType<U, F>, Backend>,
/// Computation backend; cloned for every call that needs one.
backend: Backend,
}
impl<'a, 'b, U, F, Backend> EvalcheckProver<'a, 'b, U, F, Backend>
where
U: UnderlierType + PackScalar<F>,
PackedType<U, F>: PackedFieldIndexable,
F: TowerField,
Backend: ComputationBackend,
{
/// Creates a prover over the given oracle set, witness index and computation backend.
///
/// The committed-claim accumulator is initialised from `oracles.committed_batches()`;
/// the query store and the list of pending sumcheck constraint builders start out empty.
pub fn new(
    oracles: &'a mut MultilinearOracleSet<F>,
    witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
    backend: Backend,
) -> Self {
    // Must be computed before `oracles` is moved into the struct below.
    let batch_committed_eval_claims =
        BatchCommittedEvalClaims::new(&oracles.committed_batches());

    Self {
        oracles,
        witness_index,
        batch_committed_eval_claims,
        new_sumchecks_constraints: Vec::new(),
        memoized_queries: MemoizedQueries::new(),
        backend,
    }
}
/// Converts the accumulated sumcheck constraint builders into built constraint sets.
///
/// Builders are visited from last to first. Each visited builder is replaced by its
/// `Default` value and built against the oracle set; builders that report
/// `OracleError::EmptyConstraintSet` are skipped.
///
/// # Errors
///
/// Returns the first `OracleError` other than `EmptyConstraintSet`. Builders that have
/// not been visited at that point are left untouched.
pub fn take_new_sumchecks_constraints(
    &mut self,
) -> Result<Vec<ConstraintSet<PackedType<U, F>>>, OracleError> {
    let mut constraint_sets = Vec::new();

    for builder in self.new_sumchecks_constraints.iter_mut().rev() {
        match std::mem::take(builder).build(self.oracles) {
            Ok(constraint_set) => constraint_sets.push(constraint_set),
            // An empty builder contributes nothing and is not an error here.
            Err(OracleError::EmptyConstraintSet) => {}
            Err(err) => return Err(err),
        }
    }

    Ok(constraint_sets)
}
/// Proves every claim in `evalcheck_claims`, in order, returning one proof per claim.
///
/// # Errors
///
/// Stops at, and returns, the first error produced while proving a claim.
#[instrument(skip_all, name = "EvalcheckProverState::prove", level = "debug")]
pub fn prove(
    &mut self,
    evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
) -> Result<Vec<EvalcheckProof<F>>, Error> {
    let mut proofs = Vec::with_capacity(evalcheck_claims.len());
    for claim in evalcheck_claims {
        proofs.push(self.prove_multilinear(claim)?);
    }
    Ok(proofs)
}
/// Produces the evalcheck proof for a single claim, dispatching on the oracle variant.
///
/// Derived oracles are handled by recursing into their inner oracle(s) with a suitably
/// adjusted evaluation point. Side effects on `self`:
///
/// * `Committed` claims are inserted into `batch_committed_eval_claims`;
/// * `Shifted` and `Packed` claims are delegated to `process_shifted_sumcheck` /
///   `process_packed_sumcheck`, which receive `new_sumchecks_constraints` mutably.
///
/// # Errors
///
/// Propagates errors from the sumcheck helpers, from witness evaluation (via
/// `eval_and_proof`) and from recursive calls.
///
/// # Panics
///
/// Panics if the two halves of a `Merged` or `Interleaved` oracle disagree on `n_vars`,
/// or if `eval_point` is too short for the slice taken in the matching arm.
#[instrument(
    skip_all,
    name = "EvalcheckProverState::prove_multilinear",
    level = "debug"
)]
fn prove_multilinear(
    &mut self,
    evalcheck_claim: EvalcheckMultilinearClaim<F>,
) -> Result<EvalcheckProof<F>, Error> {
    let EvalcheckMultilinearClaim {
        poly: multilinear,
        eval_point,
        eval,
        is_random_point,
    } = evalcheck_claim;

    use MultilinearPolyOracle::*;

    let proof = match multilinear {
        // Nothing is recorded for transparent oracles; only the marker is emitted.
        Transparent { .. } => EvalcheckProof::Transparent,

        // Committed oracles: remember the claim in the batch accumulator.
        Committed { id, .. } => {
            let subclaim = CommittedEvalClaim {
                id,
                eval_point,
                eval,
                is_random_point,
            };
            self.batch_committed_eval_claims.insert(subclaim);
            EvalcheckProof::Committed
        }

        // Same evaluation, restricted to the first `inner.n_vars()` coordinates.
        Repeating { inner, .. } => {
            let n_vars = inner.n_vars();
            let inner_eval_point = eval_point[..n_vars].to_vec();
            let subclaim = EvalcheckMultilinearClaim {
                poly: *inner,
                eval_point: inner_eval_point,
                eval,
                is_random_point,
            };
            let subproof = self.prove_multilinear(subclaim)?;
            EvalcheckProof::Repeating(Box::new(subproof))
        }

        // Both halves are evaluated at the first `n_vars` coordinates of the point.
        Merged { poly0, poly1, .. } => {
            let n_vars = poly0.n_vars();
            assert_eq!(poly0.n_vars(), poly1.n_vars());
            let inner_eval_point = &eval_point[..n_vars];
            let (eval1, subproof1) =
                self.eval_and_proof(*poly0, inner_eval_point, is_random_point)?;
            let (eval2, subproof2) =
                self.eval_and_proof(*poly1, inner_eval_point, is_random_point)?;
            EvalcheckProof::Merged {
                eval1,
                eval2,
                subproof1: Box::new(subproof1),
                subproof2: Box::new(subproof2),
            }
        }

        // Both halves are evaluated at the point with its first coordinate dropped.
        Interleaved { poly0, poly1, .. } => {
            assert_eq!(poly0.n_vars(), poly1.n_vars());
            let inner_eval_point = &eval_point[1..];
            let (eval1, subproof1) =
                self.eval_and_proof(*poly0, inner_eval_point, is_random_point)?;
            let (eval2, subproof2) =
                self.eval_and_proof(*poly1, inner_eval_point, is_random_point)?;
            EvalcheckProof::Interleaved {
                eval1,
                eval2,
                subproof1: Box::new(subproof1),
                subproof2: Box::new(subproof2),
            }
        }

        Shifted { shifted, .. } => {
            process_shifted_sumcheck(
                self.oracles,
                &shifted,
                eval_point.as_slice(),
                eval,
                self.witness_index,
                &mut self.memoized_queries,
                &mut self.new_sumchecks_constraints,
                self.backend.clone(),
            )?;
            EvalcheckProof::Shifted
        }

        Packed { packed, .. } => {
            process_packed_sumcheck(
                self.oracles,
                &packed,
                eval_point.as_slice(),
                eval,
                self.witness_index,
                &mut self.memoized_queries,
                &mut self.new_sumchecks_constraints,
                self.backend.clone(),
            )?;
            EvalcheckProof::Packed
        }

        // Re-insert the projected-out values into the point and recurse on the inner
        // oracle. The inner proof is returned as-is, without a wrapper variant.
        Projected { projected, .. } => {
            let (inner, values) = (projected.inner(), projected.values());
            let new_eval_point = match projected.projection_variant() {
                ProjectionVariant::LastVars => {
                    // `eval_point` is owned and not needed afterwards, so it is extended
                    // in place instead of being cloned first.
                    let mut new_eval_point = eval_point;
                    new_eval_point.extend(values);
                    new_eval_point
                }
                ProjectionVariant::FirstVars => {
                    values.iter().cloned().chain(eval_point).collect()
                }
            };
            let new_poly = *inner.clone();
            let subclaim = EvalcheckMultilinearClaim {
                poly: new_poly,
                eval_point: new_eval_point,
                eval,
                is_random_point,
            };
            self.prove_multilinear(subclaim)?
        }

        // Every constituent polynomial is evaluated and proven at the same point.
        LinearCombination {
            linear_combination, ..
        } => {
            let subproofs = linear_combination
                .polys()
                .cloned()
                .map(|suboracle| self.eval_and_proof(suboracle, &eval_point, is_random_point))
                .collect::<Result<_, Error>>()?;
            EvalcheckProof::Composite { subproofs }
        }

        // The inner oracle is evaluated at the first `inner.n_vars()` coordinates.
        ZeroPadded { inner, .. } => {
            let inner_n_vars = inner.n_vars();
            let inner_eval_point = &eval_point[..inner_n_vars];
            let (eval, subproof) =
                self.eval_and_proof(*inner, inner_eval_point, is_random_point)?;
            EvalcheckProof::ZeroPadded(eval, Box::new(subproof))
        }
    };

    Ok(proof)
}
/// Evaluates the witness of `poly` at `eval_point` and proves the resulting claim.
///
/// Returns the computed evaluation together with the proof for
/// `(poly, eval_point, evaluation)`.
///
/// # Errors
///
/// Fails if the query for `eval_point` cannot be obtained, if the witness lookup fails
/// (reported as `Error::Witness`), if the evaluation fails, or if proving the resulting
/// claim fails.
fn eval_and_proof(
    &mut self,
    poly: MultilinearPolyOracle<F>,
    eval_point: &[F],
    is_random_point: bool,
) -> Result<(F, EvalcheckProof<F>), Error> {
    let query = self
        .memoized_queries
        .full_query(eval_point, self.backend.clone())?;
    let witness = self
        .witness_index
        .get_multilin_poly(poly.id())
        .map_err(Error::Witness)?;
    let eval = witness.evaluate(query.to_ref())?;

    // The borrows taken above end here, which frees `self` for the recursive call.
    let claim = EvalcheckMultilinearClaim {
        poly,
        eval_point: eval_point.to_vec(),
        eval,
        is_random_point,
    };
    self.prove_multilinear(claim)
        .map(|subproof| (eval, subproof))
}
}