// binius_core/protocols/evalcheck/prove.rs
use super::{
error::Error,
evalcheck::{
BatchCommittedEvalClaims, CommittedEvalClaim, EvalcheckMultilinearClaim, EvalcheckProof,
},
subclaims::MemoizedQueries,
};
use crate::{
oracle::{
ConstraintSet, ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet,
MultilinearPolyOracle, OracleId, ProjectionVariant,
},
protocols::evalcheck::subclaims::{process_packed_sumcheck, process_shifted_sumcheck},
witness::MultilinearExtensionIndex,
};
use binius_field::{
as_packed_field::{PackScalar, PackedType},
underlier::UnderlierType,
PackedFieldIndexable, TowerField,
};
use binius_hal::ComputationBackend;
use getset::{Getters, MutGetters};
use std::collections::HashMap;
use tracing::instrument;
/// Prover-side state for the evalcheck protocol.
///
/// While proving, claims on committed oracles are accumulated in
/// `batch_committed_eval_claims`. Claims on shifted/packed oracles are turned into
/// pending sumcheck constraint builders in `new_sumchecks_constraints`.
#[derive(Getters, MutGetters)]
pub struct EvalcheckProver<
    'a,
    'b,
    U: UnderlierType + PackScalar<F>,
    F: TowerField,
    Backend: ComputationBackend,
> {
    /// Oracle set that the evaluation claims refer to.
    pub(crate) oracles: &'a mut MultilinearOracleSet<F>,
    /// Witness data used to evaluate oracles whose value is not already memoized.
    pub(crate) witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
    /// Evaluation claims on committed oracles gathered so far.
    #[getset(get = "pub", get_mut = "pub")]
    pub(crate) batch_committed_eval_claims: BatchCommittedEvalClaims<F>,
    /// Claimed evaluations keyed by `(oracle id, evaluation point)`.
    eval_memoization: HashMap<(OracleId, Vec<F>), F>,
    /// Sumcheck constraint builders filled by the shifted/packed reductions.
    new_sumchecks_constraints: Vec<ConstraintSetBuilder<PackedType<U, F>>>,
    /// Cache of multilinear queries, presumably keyed by evaluation point (TODO confirm).
    memoized_queries: MemoizedQueries<PackedType<U, F>, Backend>,
    /// Computation backend used to build queries and run the sumcheck reductions.
    backend: &'a Backend,
}
impl<'a, 'b, U, F, Backend> EvalcheckProver<'a, 'b, U, F, Backend>
where
    U: UnderlierType + PackScalar<F>,
    PackedType<U, F>: PackedFieldIndexable,
    F: TowerField,
    Backend: ComputationBackend,
{
    /// Creates a prover over `oracles` and `witness_index`.
    ///
    /// The committed-claim accumulator is initialised from
    /// `oracles.committed_batches()`. The evaluation memo, the pending sumcheck
    /// constraint builders and the query cache all start empty.
    pub fn new(
        oracles: &'a mut MultilinearOracleSet<F>,
        witness_index: &'a mut MultilinearExtensionIndex<'b, U, F>,
        backend: &'a Backend,
    ) -> Self {
        let memoized_queries = MemoizedQueries::new();
        let batch_committed_eval_claims =
            BatchCommittedEvalClaims::new(&oracles.committed_batches());
        Self {
            oracles,
            witness_index,
            batch_committed_eval_claims,
            eval_memoization: HashMap::new(),
            new_sumchecks_constraints: Vec::new(),
            memoized_queries,
            backend,
        }
    }

    /// Builds the sumcheck constraint sets queued up while proving.
    ///
    /// Each pending builder is replaced in place by a default (empty) builder
    /// and built against the oracle set. Builders that yield
    /// `OracleError::EmptyConstraintSet` are skipped. The remaining sets are
    /// returned in reverse order of the builder list.
    ///
    /// # Errors
    ///
    /// Propagates any other `OracleError` produced by `build_one`.
    pub fn take_new_sumchecks_constraints(
        &mut self,
    ) -> Result<Vec<ConstraintSet<PackedType<U, F>>>, OracleError> {
        self.new_sumchecks_constraints
            .iter_mut()
            .map(|builder| std::mem::take(builder).build_one(self.oracles))
            .filter(|constraint| !matches!(constraint, Err(OracleError::EmptyConstraintSet)))
            // NOTE(review): reversal is presumably the order the consumer expects — confirm.
            .rev()
            .collect()
    }

    /// Proves each evalcheck claim in turn, returning one proof per claim in
    /// input order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by any claim.
    #[instrument(skip_all, name = "EvalcheckProverState::prove", level = "debug")]
    pub fn prove(
        &mut self,
        evalcheck_claims: Vec<EvalcheckMultilinearClaim<F>>,
    ) -> Result<Vec<EvalcheckProof<F>>, Error> {
        evalcheck_claims
            .into_iter()
            .map(|claim| self.prove_multilinear(claim))
            .collect::<Result<Vec<_>, _>>()
    }

    /// Proves a single evalcheck claim, recursing into the inner oracles of
    /// virtual polynomials as needed.
    ///
    /// Committed claims are queued in `batch_committed_eval_claims`. Shifted and
    /// packed claims are delegated to the sumcheck helpers, which add to
    /// `new_sumchecks_constraints`.
    #[instrument(
        skip_all,
        name = "EvalcheckProverState::prove_multilinear",
        level = "debug"
    )]
    fn prove_multilinear(
        &mut self,
        evalcheck_claim: EvalcheckMultilinearClaim<F>,
    ) -> Result<EvalcheckProof<F>, Error> {
        let EvalcheckMultilinearClaim {
            poly: multilinear,
            eval_point,
            eval,
        } = evalcheck_claim;

        // Remember the claimed value so `eval_and_proof` can skip re-evaluating
        // the witness for the same (oracle, point) pair.
        self.eval_memoization
            .insert((multilinear.id(), eval_point.clone()), eval);

        use MultilinearPolyOracle::*;

        let proof = match multilinear {
            Transparent { .. } => EvalcheckProof::Transparent,
            Committed { id, .. } => {
                let subclaim = CommittedEvalClaim {
                    id,
                    eval_point,
                    eval,
                };
                self.batch_committed_eval_claims.insert(subclaim);
                EvalcheckProof::Committed
            }
            Repeating { inner, .. } => {
                // The inner oracle only sees the first `n_vars` coordinates.
                let n_vars = inner.n_vars();
                let inner_eval_point = eval_point[..n_vars].to_vec();
                let subclaim = EvalcheckMultilinearClaim {
                    poly: *inner,
                    eval_point: inner_eval_point,
                    eval,
                };
                let subproof = self.prove_multilinear(subclaim)?;
                EvalcheckProof::Repeating(Box::new(subproof))
            }
            Shifted { shifted, .. } => {
                process_shifted_sumcheck(
                    self.oracles,
                    &shifted,
                    eval_point.as_slice(),
                    eval,
                    self.witness_index,
                    &mut self.memoized_queries,
                    &mut self.new_sumchecks_constraints,
                    self.backend,
                )?;
                EvalcheckProof::Shifted
            }
            Packed { packed, .. } => {
                process_packed_sumcheck(
                    self.oracles,
                    &packed,
                    eval_point.as_slice(),
                    eval,
                    self.witness_index,
                    &mut self.memoized_queries,
                    &mut self.new_sumchecks_constraints,
                    self.backend,
                )?;
                EvalcheckProof::Packed
            }
            Projected { projected, .. } => {
                let (inner, values) = (projected.inner(), projected.values());
                // `eval_point` is owned and unused after this arm, so it is
                // extended/consumed directly instead of being cloned first.
                let new_eval_point = match projected.projection_variant() {
                    ProjectionVariant::LastVars => {
                        let mut new_eval_point = eval_point;
                        new_eval_point.extend(values);
                        new_eval_point
                    }
                    ProjectionVariant::FirstVars => {
                        values.iter().cloned().chain(eval_point).collect()
                    }
                };
                // Clone the oracle itself rather than cloning the whole Box and
                // then moving out of it.
                let subclaim = EvalcheckMultilinearClaim {
                    poly: (**inner).clone(),
                    eval_point: new_eval_point,
                    eval,
                };
                self.prove_multilinear(subclaim)?
            }
            LinearCombination {
                linear_combination, ..
            } => {
                // Every summand is evaluated at the same point; each subproof
                // carries that summand's evaluation.
                let subproofs = linear_combination
                    .polys()
                    .cloned()
                    .map(|suboracle| self.eval_and_proof(suboracle, &eval_point))
                    .collect::<Result<_, Error>>()?;
                EvalcheckProof::LinearCombination { subproofs }
            }
            ZeroPadded { inner, .. } => {
                // The inner oracle is evaluated on the first `inner_n_vars` coordinates.
                let inner_n_vars = inner.n_vars();
                let inner_eval_point = &eval_point[..inner_n_vars];
                let (eval, subproof) = self.eval_and_proof(*inner, inner_eval_point)?;
                EvalcheckProof::ZeroPadded(eval, Box::new(subproof))
            }
        };
        Ok(proof)
    }

    /// Returns the evaluation of `poly` at `eval_point` together with an
    /// evalcheck proof of that claim.
    ///
    /// If this `(oracle, point)` pair was already claimed, the memoized value is
    /// reused. Otherwise the value is computed from the witness polynomial, using
    /// the cached multilinear query for `eval_point`.
    ///
    /// # Errors
    ///
    /// Fails if the query cannot be built, the witness for `poly` is missing,
    /// the evaluation fails, or proving the resulting subclaim fails.
    fn eval_and_proof(
        &mut self,
        poly: MultilinearPolyOracle<F>,
        eval_point: &[F],
    ) -> Result<(F, EvalcheckProof<F>), Error> {
        // Build the memo key once and reuse its point vector for the subclaim
        // instead of allocating a second copy.
        let key = (poly.id(), eval_point.to_vec());
        let eval = match self.eval_memoization.get(&key) {
            Some(&eval) => eval,
            None => {
                let eval_query = self.memoized_queries.full_query(eval_point, self.backend)?;
                let witness_poly = self
                    .witness_index
                    .get_multilin_poly(poly.id())
                    .map_err(Error::Witness)?;
                witness_poly
                    .evaluate(eval_query.to_ref())
                    .map_err(Error::from)?
            }
        };
        let (_, eval_point) = key;
        let subclaim = EvalcheckMultilinearClaim {
            poly,
            eval_point,
            eval,
        };
        let subproof = self.prove_multilinear(subclaim)?;
        Ok((eval, subproof))
    }
}