use std::{marker::PhantomData, ops::Range, sync::Arc};
use binius_field::{
util::{eq, powers},
ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, RepackedExtension,
TowerField,
};
use binius_hal::{ComputationBackend, SumcheckEvaluator};
use binius_math::{
CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter,
MultilinearPoly, MultilinearQuery,
};
use binius_utils::bail;
use bytemuck::zeroed_vec;
use getset::Getters;
use itertools::izip;
use rayon::prelude::*;
use stackalloc::stackalloc_with_default;
use tracing::instrument;
use crate::{
polynomial::{Error as PolynomialError, MultilinearComposite},
protocols::sumcheck::{
common::{determine_switchovers, equal_n_vars_check, small_field_embedding_degree_check},
prove::{
common::fold_partial_eq_ind,
univariate::{
zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
ZerocheckUnivariateFoldResult,
},
ProverState, SumcheckInterpolator, SumcheckProver, UnivariateZerocheckProver,
},
univariate::LagrangeRoundEvals,
univariate_zerocheck::domain_size,
Error, RoundCoeffs,
},
witness::MultilinearWitness,
};
pub fn validate_witness<'a, F, P, M, Composition>(
multilinears: &[M],
zero_claims: impl IntoIterator<Item = &'a (Arc<str>, Composition)>,
) -> Result<(), Error>
where
F: Field,
P: PackedField<Scalar = F>,
M: MultilinearPoly<P> + Send + Sync,
Composition: CompositionPolyOS<P> + 'a,
{
let n_vars = multilinears
.first()
.map(|multilinear| multilinear.n_vars())
.unwrap_or_default();
for multilinear in multilinears.iter() {
if multilinear.n_vars() != n_vars {
bail!(Error::NumberOfVariablesMismatch);
}
}
let multilinears = multilinears.iter().collect::<Vec<_>>();
for (name, composition) in zero_claims.into_iter() {
let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
(0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
if witness.evaluate_on_hypercube(j)? != F::ZERO {
return Err(Error::ZerocheckNaiveValidationFailure {
composition_name: name.to_string(),
vertex_index: j,
});
}
Ok(())
})?;
}
Ok(())
}
/// Zerocheck prover that starts with a univariate skip round.
///
/// It holds the witness multilinears and the zero claims. It can either run the univariate
/// round (via `UnivariateZerocheckProver`) and then fold into a regular [`ZerocheckProver`],
/// or be converted into a regular prover directly with `into_regular_zerocheck`.
#[derive(Debug, Getters)]
pub struct UnivariateZerocheck<'a, 'm, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
    FDomain: Field,
    PBase: PackedField,
    P: PackedField,
    Backend: ComputationBackend,
{
    // Number of variables shared by all multilinears (validated in `new`).
    n_vars: usize,
    // Witness multilinears; a public getter is generated by `getset`.
    #[getset(get = "pub")]
    multilinears: Vec<M>,
    // Per-multilinear switchover rounds as produced by `determine_switchovers`.
    switchover_rounds: Vec<usize>,
    // One entry per zero claim: (name, base-field composition, extension-field composition).
    compositions: Vec<(Arc<str>, CompositionBase, Composition)>,
    // Zerocheck challenges copied from the constructor argument; the regular prover later
    // requires these to line up with the remaining variables.
    zerocheck_challenges: Vec<P::Scalar>,
    // One interpolation domain per composition, of size `degree + 1`.
    domains: Vec<InterpolationDomain<FDomain>>,
    backend: &'a Backend,
    // `Some` once `execute_univariate_round` has run; consumed by `fold_univariate_round`.
    univariate_evals_output: Option<ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>>,
    _p_base_marker: PhantomData<PBase>,
    _m_marker: PhantomData<&'m ()>,
}
impl<'a, 'm, F, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
    UnivariateZerocheck<'a, 'm, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
    F: Field + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
    FDomain: Field,
    PBase: PackedField<Scalar: ExtensionField<FDomain>> + PackedExtension<FDomain>,
    P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
    CompositionBase: CompositionPolyOS<PBase>,
    Composition: CompositionPolyOS<P>,
    M: MultilinearPoly<P> + Send + Sync + 'm,
    Backend: ComputationBackend,
{
    /// Builds a univariate zerocheck prover over `multilinears`.
    ///
    /// Every claim's base and extension compositions must take exactly `multilinears.len()`
    /// arguments and must have equal degrees. An interpolation domain of size `degree + 1` is
    /// created for each claim.
    ///
    /// # Errors
    ///
    /// * Propagates failures from `equal_n_vars_check` and
    ///   `small_field_embedding_degree_check`.
    /// * [`Error::InvalidComposition`] for the first malformed claim.
    /// * [`Error::MathError`] if an evaluation domain cannot be created.
    pub fn new(
        multilinears: Vec<M>,
        zero_claims: impl IntoIterator<Item = (Arc<str>, CompositionBase, Composition)>,
        zerocheck_challenges: &[F],
        evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
        switchover_fn: impl Fn(usize) -> usize,
        backend: &'a Backend,
    ) -> Result<Self, Error> {
        let n_vars = equal_n_vars_check(&multilinears)?;
        let n_multilinears = multilinears.len();
        let compositions: Vec<_> = zero_claims.into_iter().collect();

        let malformed = compositions
            .iter()
            .find(|(_, composition_base, composition)| {
                composition_base.n_vars() != n_multilinears
                    || composition.n_vars() != n_multilinears
                    || composition_base.degree() != composition.degree()
            });
        if let Some((_, _, composition)) = malformed {
            bail!(Error::InvalidComposition {
                actual: composition.n_vars(),
                expected: n_multilinears,
            });
        }

        #[cfg(feature = "debug_validate_sumcheck")]
        {
            let compositions = compositions
                .iter()
                .map(|(name, _, a)| (name.clone(), a))
                .collect::<Vec<_>>();
            validate_witness(&multilinears, &compositions)?;
        }

        small_field_embedding_degree_check::<PBase, P, _>(&multilinears)?;
        let switchover_rounds = determine_switchovers(&multilinears, switchover_fn);

        let mut domains: Vec<InterpolationDomain<FDomain>> =
            Vec::with_capacity(compositions.len());
        for (_, _, composition) in &compositions {
            let domain = evaluation_domain_factory
                .create(composition.degree() + 1)
                .map_err(Error::MathError)?;
            domains.push(domain.into());
        }

        Ok(Self {
            n_vars,
            multilinears,
            switchover_rounds,
            compositions,
            zerocheck_challenges: zerocheck_challenges.to_vec(),
            domains,
            backend,
            univariate_evals_output: None,
            _p_base_marker: PhantomData,
            _m_marker: PhantomData,
        })
    }

    /// Converts this prover into a regular [`ZerocheckProver`] without a univariate round.
    ///
    /// The resulting prover evaluates its first round with the base-field compositions, and
    /// every claimed sum starts at zero.
    ///
    /// # Errors
    ///
    /// * [`Error::ExpectedFold`] if `execute_univariate_round` has already run.
    /// * Propagates backend and `ZerocheckProver::new` errors.
    #[instrument(skip_all, level = "debug")]
    #[allow(clippy::type_complexity)]
    pub fn into_regular_zerocheck(
        self,
    ) -> Result<
        ZerocheckProver<
            'a,
            FDomain,
            PBase,
            P,
            CompositionBase,
            Composition,
            MultilinearWitness<'m, P>,
            Backend,
        >,
        Error,
    > {
        if self.univariate_evals_output.is_some() {
            bail!(Error::ExpectedFold);
        }

        let Self {
            n_vars,
            multilinears,
            switchover_rounds,
            compositions,
            zerocheck_challenges,
            domains,
            backend,
            ..
        } = self;

        let multilinears = multilinears
            .into_iter()
            .map(|multilinear| Arc::new(multilinear) as MultilinearWitness<'_, P>)
            .collect::<Vec<_>>();

        #[cfg(feature = "debug_validate_sumcheck")]
        {
            let compositions = compositions
                .iter()
                .map(|(name, _, a)| (name.clone(), a))
                .collect::<Vec<_>>();
            validate_witness(&multilinears, &compositions)?;
        }

        // The current round's eq factor is applied separately in `ZerocheckProver::execute`,
        // so the partial eq indicator skips the first challenge (if there is one).
        let skip = usize::from(n_vars > 0);
        let partial_eq_ind_evals =
            backend.tensor_product_full_query(&zerocheck_challenges[skip..])?;
        let claimed_sums = vec![F::ZERO; compositions.len()];
        let composition_pairs = compositions
            .into_iter()
            .map(|(_, base, ext)| (base, ext))
            .collect();

        ZerocheckProver::new(
            multilinears,
            switchover_rounds,
            composition_pairs,
            partial_eq_ind_evals,
            zerocheck_challenges,
            claimed_sums,
            domains,
            RegularFirstRound::BaseField,
            backend,
        )
    }
}
impl<'a, 'm, F, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
UnivariateZerocheckProver<'a, F>
for UnivariateZerocheck<'a, 'm, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
F: TowerField + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
FDomain: TowerField,
PBase: PackedFieldIndexable<Scalar: ExtensionField<FDomain>>
+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
P: PackedFieldIndexable<Scalar = F> + RepackedExtension<PBase> + PackedExtension<FDomain>,
CompositionBase: CompositionPolyOS<PBase> + 'static,
Composition: CompositionPolyOS<P> + 'static,
M: MultilinearPoly<P> + Send + Sync + 'm,
Backend: ComputationBackend,
{
fn n_vars(&self) -> usize {
self.n_vars
}
fn domain_size(&self, skip_rounds: usize) -> usize {
self.compositions
.iter()
.map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
.max()
.unwrap_or(0)
}
#[instrument(skip_all, level = "debug")]
fn execute_univariate_round(
&mut self,
skip_rounds: usize,
max_domain_size: usize,
batch_coeff: F,
) -> Result<LagrangeRoundEvals<F>, Error> {
if self.univariate_evals_output.is_some() {
bail!(Error::ExpectedFold);
}
let compositions_base = self
.compositions
.iter()
.map(|(_, composition_base, _)| composition_base)
.collect::<Vec<_>>();
let univariate_evals_output = zerocheck_univariate_evals(
&self.multilinears,
&compositions_base,
&self.zerocheck_challenges,
skip_rounds,
max_domain_size,
self.backend,
)?;
let zeros_prefix_len = 1 << skip_rounds;
let batched_round_evals = univariate_evals_output
.round_evals
.iter()
.zip(powers(batch_coeff))
.map(|(evals, scalar)| {
let round_evals = LagrangeRoundEvals {
zeros_prefix_len,
evals: evals.clone(),
};
round_evals * scalar
})
.try_fold(
LagrangeRoundEvals::zeros(max_domain_size),
|mut accum, evals| -> Result<_, Error> {
accum.add_assign_lagrange(&evals)?;
Ok(accum)
},
)?;
self.univariate_evals_output = Some(univariate_evals_output);
Ok(batched_round_evals)
}
#[instrument(skip_all, level = "debug")]
fn fold_univariate_round(
self: Box<Self>,
challenge: F,
) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
if self.univariate_evals_output.is_none() {
bail!(Error::ExpectedExecution);
}
let ZerocheckUnivariateFoldResult {
skip_rounds,
subcube_lagrange_coeffs,
claimed_prime_sums,
partial_eq_ind_evals,
} = self
.univariate_evals_output
.expect("validated to be Some")
.fold::<FDomain>(challenge)?;
let mut packed_subcube_lagrange_coeffs =
zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
P::unpack_scalars_mut(&mut packed_subcube_lagrange_coeffs)[..1 << skip_rounds]
.copy_from_slice(&subcube_lagrange_coeffs);
let lagrange_coeffs_query =
MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
let partial_low_multilinears = self
.multilinears
.into_par_iter()
.map(|multilinear| -> Result<_, Error> {
let multilinear =
multilinear.evaluate_partial_low(lagrange_coeffs_query.to_ref())?;
let mle_adapter = Arc::new(MLEDirectAdapter::from(multilinear));
Ok(mle_adapter as MultilinearWitness<'static, P>)
})
.collect::<Result<Vec<_>, _>>()?;
let switchover_rounds = self
.switchover_rounds
.into_iter()
.map(|switchover_round| switchover_round.saturating_sub(skip_rounds))
.collect();
let zerocheck_challenges = self.zerocheck_challenges.to_vec();
let regular_prover = ZerocheckProver::new(
partial_low_multilinears,
switchover_rounds,
self.compositions
.into_iter()
.map(|(_, a, b)| (a, b))
.collect(),
partial_eq_ind_evals,
zerocheck_challenges,
claimed_prime_sums,
self.domains,
RegularFirstRound::LargeField,
self.backend,
)?;
Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
}
}
/// Selects how the regular zerocheck prover evaluates its very first round.
#[derive(Clone, Copy, Debug)]
enum RegularFirstRound {
    /// Round 0 is evaluated with the base-field compositions, relying on `r(0) = r(1) = 0`.
    BaseField,
    /// Round 0 is evaluated like every later round, with the extension-field compositions.
    LargeField,
}
/// Regular (multilinear) zerocheck prover built on top of the shared sumcheck [`ProverState`].
#[derive(Debug)]
pub struct ZerocheckProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
    FDomain: Field,
    PBase: PackedField,
    P: PackedField,
    M: MultilinearPoly<P> + Send + Sync,
    Backend: ComputationBackend,
{
    // Number of variables at construction time (the state's `n_vars` shrinks as rounds fold).
    n_vars: usize,
    state: ProverState<'a, FDomain, P, M, Backend>,
    // Running product of `eq(alpha_i, r_i)` over the folded rounds; starts at ONE.
    eq_ind_eval: P::Scalar,
    // Partial eq indicator evaluations, folded after every round.
    partial_eq_ind_evals: Backend::Vec<P>,
    // Zerocheck challenges, one per variable, indexed by round.
    zerocheck_challenges: Vec<P::Scalar>,
    // (base-field composition, extension-field composition) per claim.
    compositions: Vec<(CompositionBase, Composition)>,
    // Interpolation domain per claim, used to turn round evals into coefficients.
    domains: Vec<InterpolationDomain<FDomain>>,
    // Whether round 0 uses the base-field evaluator.
    first_round: RegularFirstRound,
    _p_base_marker: PhantomData<PBase>,
}
impl<'a, F, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
    ZerocheckProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
    F: Field + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
    FDomain: Field,
    PBase: PackedField<Scalar: ExtensionField<FDomain>> + PackedExtension<FDomain>,
    P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
    CompositionBase: CompositionPolyOS<PBase>,
    Composition: CompositionPolyOS<P>,
    M: MultilinearPoly<P> + Send + Sync,
    Backend: ComputationBackend,
{
    /// Builds a regular zerocheck prover.
    ///
    /// The evaluation points are the points of the largest interpolation domain (empty if
    /// there are no domains).
    ///
    /// # Errors
    ///
    /// * [`Error::IncorrectClaimedPrimeSumsLength`] if there is not one sum per composition.
    /// * [`Error::IncorrectZerocheckChallengesLength`] if there is not one challenge per
    ///   variable.
    /// * [`Error::IncorrectZerocheckPartialEqIndSize`] if `partial_eq_ind_evals` does not hold
    ///   `2^(n_vars - 1 - LOG_WIDTH)` packed elements (saturating at one).
    /// * Propagates `ProverState` construction errors.
    #[allow(clippy::too_many_arguments)]
    fn new(
        multilinears: Vec<M>,
        switchover_rounds: Vec<usize>,
        compositions: Vec<(CompositionBase, Composition)>,
        partial_eq_ind_evals: Backend::Vec<P>,
        zerocheck_challenges: Vec<F>,
        claimed_prime_sums: Vec<F>,
        domains: Vec<InterpolationDomain<FDomain>>,
        first_round: RegularFirstRound,
        backend: &'a Backend,
    ) -> Result<Self, Error> {
        let evaluation_points = domains
            .iter()
            .map(|domain| domain.points())
            .max_by_key(|points| points.len())
            .map(|points| points.to_vec())
            .unwrap_or_default();

        if claimed_prime_sums.len() != compositions.len() {
            bail!(Error::IncorrectClaimedPrimeSumsLength);
        }

        let state = ProverState::new_with_switchover_rounds(
            multilinears,
            &switchover_rounds,
            claimed_prime_sums,
            evaluation_points,
            backend,
        )?;
        let n_vars = state.n_vars();

        if zerocheck_challenges.len() != n_vars {
            bail!(Error::IncorrectZerocheckChallengesLength);
        }
        let expected_eq_ind_len: usize = 1 << n_vars.saturating_sub(1 + P::LOG_WIDTH);
        if partial_eq_ind_evals.len() != expected_eq_ind_len {
            bail!(Error::IncorrectZerocheckPartialEqIndSize);
        }

        Ok(Self {
            n_vars,
            state,
            eq_ind_eval: F::ONE,
            partial_eq_ind_evals,
            zerocheck_challenges,
            compositions,
            domains,
            first_round,
            _p_base_marker: PhantomData,
        })
    }

    /// Index of the current round: number of variables already folded.
    fn round(&self) -> usize {
        let remaining = self.n_rounds_remaining();
        self.n_vars - remaining
    }

    /// Number of variables still to be folded, as reported by the sumcheck state.
    fn n_rounds_remaining(&self) -> usize {
        self.state.n_vars()
    }

    /// Multiplies the running eq indicator by `eq(alpha, challenge)` for the current round.
    fn update_eq_ind_eval(&mut self, challenge: F) {
        let round = self.round();
        let factor = eq(self.zerocheck_challenges[round], challenge);
        self.eq_ind_eval *= factor;
    }

    /// Folds the partial eq indicator according to the remaining round count.
    #[instrument(skip_all, level = "debug")]
    fn fold_partial_eq_ind(&mut self) {
        let n_remaining = self.n_rounds_remaining();
        fold_partial_eq_ind::<P, Backend>(n_remaining, &mut self.partial_eq_ind_evals);
    }
}
/// [`SumcheckProver`] implementation driving the regular zerocheck rounds.
///
/// Each round polynomial comes from the sumcheck state's evaluations. It is then multiplied
/// by `eq(X, alpha)` for the current round's zerocheck challenge `alpha`, and by the running
/// `eq_ind_eval` accumulated over the previously folded rounds.
impl<F, FDomain, PBase, P, CompositionBase, Composition, M, Backend> SumcheckProver<F>
    for ZerocheckProver<'_, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
    F: Field + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
    FDomain: Field,
    PBase: PackedField<Scalar: ExtensionField<FDomain>> + PackedExtension<FDomain>,
    P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<PBase>,
    CompositionBase: CompositionPolyOS<PBase>,
    Composition: CompositionPolyOS<P>,
    M: MultilinearPoly<P> + Send + Sync,
    Backend: ComputationBackend,
{
    /// Number of variables the prover was constructed with.
    fn n_vars(&self) -> usize {
        self.n_vars
    }

    /// Binds the current round's variable to `challenge`.
    ///
    /// The call order matters:
    /// * `update_eq_ind_eval` indexes the challenges by `self.round()`, which is derived from
    ///   `state.n_vars()`. It therefore runs before `state.fold`, assuming that call
    ///   decrements `state.n_vars()`.
    /// * `fold_partial_eq_ind` uses `n_rounds_remaining()`, so it runs after the state fold.
    #[instrument(skip_all, name = "ZerocheckProver::fold", level = "debug")]
    fn fold(&mut self, challenge: F) -> Result<(), Error> {
        self.update_eq_ind_eval(challenge);
        self.state.fold(challenge)?;
        self.fold_partial_eq_ind();
        Ok(())
    }

    /// Computes the batched round polynomial coefficients for the current round.
    ///
    /// Round 0 of a prover created with `RegularFirstRound::BaseField` uses the base-field
    /// evaluator; every other round uses the extension-field evaluator.
    #[instrument(skip_all, name = "ZerocheckProver::execute", level = "debug")]
    fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
        let round = self.round();
        let base_field_first_round =
            round == 0 && matches!(self.first_round, RegularFirstRound::BaseField);
        let coeffs = if base_field_first_round {
            let evaluators = izip!(&self.compositions, &self.domains)
                .map(|((composition_base, composition), interpolation_domain)| {
                    ZerocheckFirstRoundEvaluator {
                        composition_base,
                        composition,
                        interpolation_domain,
                        partial_eq_ind_evals: &self.partial_eq_ind_evals,
                        _p_base_marker: PhantomData,
                    }
                })
                .collect::<Vec<_>>();
            let evals = self
                .state
                .calculate_first_round_evals::<PBase, _, Composition>(&evaluators)?;
            self.state
                .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?
        } else {
            let evaluators = izip!(&self.compositions, &self.domains)
                .map(|((_, composition), interpolation_domain)| ZerocheckLaterRoundEvaluator {
                    composition,
                    interpolation_domain,
                    partial_eq_ind_evals: &self.partial_eq_ind_evals,
                    round_zerocheck_challenge: self.zerocheck_challenges[round],
                })
                .collect::<Vec<_>>();
            let evals = self.state.calculate_later_round_evals(&evaluators)?;
            self.state
                .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?
        };

        // Multiply the polynomial by eq(X, alpha) = (1 - alpha) + (2 alpha - 1) X.
        let alpha = self.zerocheck_challenges[round];
        let constant_scalar = F::ONE - alpha;
        let linear_scalar = alpha.double() - F::ONE;
        let coeffs_scaled_by_constant_term = coeffs.clone() * constant_scalar;
        let mut coeffs_scaled_by_linear_term = coeffs * linear_scalar;
        // Multiply by X by shifting the coefficients up one degree (presumably the
        // coefficients are stored lowest degree first).
        coeffs_scaled_by_linear_term.0.insert(0, F::ZERO);
        let sumcheck_coeffs = coeffs_scaled_by_constant_term + &coeffs_scaled_by_linear_term;
        // Account for the eq factors of the rounds already folded.
        Ok(sumcheck_coeffs * self.eq_ind_eval)
    }

    /// Returns the evaluations from `state.finish()` with the accumulated `eq_ind_eval`
    /// appended as the final element.
    #[instrument(skip_all, name = "ZerocheckProver::finish", level = "debug")]
    fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
        let mut evals = self.state.finish()?;
        evals.push(self.eq_ind_eval);
        Ok(evals)
    }
}
/// Evaluator for round 0 of a regular zerocheck that evaluates the composition over
/// base-field (`PBase`) queries and weights the results with the partial eq indicator.
struct ZerocheckFirstRoundEvaluator<'a, PBase, P, FDomain, CompositionBase, Composition>
where
    PBase: PackedField,
    P: PackedField,
    FDomain: Field,
{
    // Composition evaluated over the base-field batch query.
    composition_base: &'a CompositionBase,
    // Extension-field composition; supplies the degree and is returned by `composition()`.
    composition: &'a Composition,
    // Domain used to interpolate the round polynomial from its evaluations.
    interpolation_domain: &'a InterpolationDomain<FDomain>,
    // Partial eq indicator evaluations used as weights.
    partial_eq_ind_evals: &'a [P],
    _p_base_marker: PhantomData<PBase>,
}
impl<F, PBase, P, FDomain, CompositionBase, Composition> SumcheckEvaluator<PBase, P, Composition>
    for ZerocheckFirstRoundEvaluator<'_, PBase, P, FDomain, CompositionBase, Composition>
where
    F: Field + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
    PBase: PackedField,
    P: PackedField<Scalar = F>,
    FDomain: Field,
    CompositionBase: CompositionPolyOS<PBase>,
    Composition: CompositionPolyOS<P>,
{
    /// Only the points `X = 2, ..., d` are evaluated.
    ///
    /// `r(0)` and `r(1)` are known to be zero in the first round; the interpolator prepends
    /// them.
    fn eval_point_indices(&self) -> Range<usize> {
        let degree = self.composition.degree();
        2..degree + 1
    }

    /// Returns the eq-weighted sum of the base-field composition over one subcube, placed in
    /// a single packed slot.
    ///
    /// A degree-1 composition contributes zero, since an honest linear composite vanishes
    /// identically.
    fn process_subcube_at_eval_point(
        &self,
        subcube_vars: usize,
        subcube_index: usize,
        batch_query: &[&[PBase]],
    ) -> P {
        if self.composition.degree() == 1 {
            return P::zero();
        }

        let row_len = batch_query.first().map_or(0, |row| row.len());
        stackalloc_with_default(row_len, |composed| {
            self.composition_base
                .batch_evaluate(batch_query, composed)
                .expect("correct by query construction invariant");

            let offset = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
            let weighted_sum = PackedField::iter_slice(&self.partial_eq_ind_evals[offset..])
                .zip(PackedField::iter_slice(composed))
                .fold(F::ZERO, |acc, (eq_ind, base)| acc + eq_ind * base);

            P::set_single(weighted_sum)
        })
    }

    /// The extension-field composition backing this evaluator.
    fn composition(&self) -> &Composition {
        self.composition
    }

    /// The partial eq indicator used to weight the evaluations.
    fn eq_ind_partial_eval(&self) -> Option<&[P]> {
        Some(self.partial_eq_ind_evals)
    }
}
impl<F, PBase, P, FDomain, CompositionBase, Composition> SumcheckInterpolator<F>
    for ZerocheckFirstRoundEvaluator<'_, PBase, P, FDomain, CompositionBase, Composition>
where
    F: Field + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
    PBase: PackedField,
    P: PackedField<Scalar = F>,
    FDomain: Field,
{
    /// Interpolates the first-round polynomial from `r(2), ..., r(d)`.
    ///
    /// In the first round the claimed sum must be zero (asserted). Both `r(0)` and `r(1)`
    /// are zero, so two zeros are prepended before interpolating.
    fn round_evals_to_coeffs(
        &self,
        last_round_sum: F,
        round_evals: Vec<F>,
    ) -> Result<Vec<F>, PolynomialError> {
        assert_eq!(last_round_sum, F::ZERO);

        let mut full_evals = vec![F::ZERO; 2];
        full_evals.extend(round_evals);

        Ok(self.interpolation_domain.interpolate(&full_evals)?)
    }
}
/// Evaluator for regular zerocheck rounds that work entirely over the extension field: every
/// round after the first, and round 0 when the prover was built with
/// `RegularFirstRound::LargeField`.
struct ZerocheckLaterRoundEvaluator<'a, P, FDomain, Composition>
where
    P: PackedField,
    FDomain: Field,
{
    // Extension-field composition evaluated over the batch query.
    composition: &'a Composition,
    // Domain used to interpolate the round polynomial from its evaluations.
    interpolation_domain: &'a InterpolationDomain<FDomain>,
    // Partial eq indicator evaluations used as weights.
    partial_eq_ind_evals: &'a [P],
    // Zerocheck challenge `alpha` of the current round, used to recover `r(0)`.
    round_zerocheck_challenge: P::Scalar,
}
impl<F, P, FDomain, Composition> SumcheckEvaluator<P, P, Composition>
    for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition>
where
    F: Field + ExtensionField<FDomain>,
    P: PackedField<Scalar = F> + PackedExtension<FDomain>,
    FDomain: Field,
    Composition: CompositionPolyOS<P>,
{
    /// The points `X = 1, ..., d` are evaluated.
    ///
    /// `r(0)` is recovered from the claimed sum by the interpolator.
    fn eval_point_indices(&self) -> Range<usize> {
        let degree = self.composition.degree();
        1..degree + 1
    }

    /// Returns the packed eq-weighted sum of the composition over one subcube.
    ///
    /// A degree-1 composition contributes zero, since an honest linear composite vanishes
    /// identically.
    fn process_subcube_at_eval_point(
        &self,
        subcube_vars: usize,
        subcube_index: usize,
        batch_query: &[&[P]],
    ) -> P {
        if self.composition.degree() == 1 {
            return P::zero();
        }

        let row_len = batch_query.first().map_or(0, |row| row.len());
        stackalloc_with_default(row_len, |composed| {
            self.composition
                .batch_evaluate(batch_query, composed)
                .expect("correct by query construction invariant");

            let offset = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
            composed
                .iter()
                .enumerate()
                .map(|(i, &value)| value * self.partial_eq_ind_evals[offset + i])
                .sum::<P>()
        })
    }

    /// The composition backing this evaluator.
    fn composition(&self) -> &Composition {
        self.composition
    }

    /// The partial eq indicator used to weight the evaluations.
    fn eq_ind_partial_eval(&self) -> Option<&[P]> {
        Some(self.partial_eq_ind_evals)
    }
}
impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
    for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition>
where
    F: Field + ExtensionField<FDomain>,
    P: PackedField<Scalar = F> + PackedExtension<FDomain>,
    FDomain: Field,
{
    /// Interpolates the round polynomial from `r(1), ..., r(d)` and the previous claimed sum.
    ///
    /// The claimed sum satisfies `sum = (1 - alpha) * r(0) + alpha * r(1)`, so
    /// `r(0) = (sum - alpha * r(1)) / (1 - alpha)`. When `1 - alpha` is not invertible, the
    /// inverse is taken as zero.
    fn round_evals_to_coeffs(
        &self,
        last_round_sum: F,
        mut round_evals: Vec<F>,
    ) -> Result<Vec<F>, PolynomialError> {
        let alpha = self.round_zerocheck_challenge;
        let r_one = round_evals[0];
        let r_zero = (last_round_sum - r_one * alpha) * (F::ONE - alpha).invert_or_zero();
        round_evals.insert(0, r_zero);

        Ok(self.interpolation_domain.interpolate(&round_evals)?)
    }
}