use super::{Error, VerificationError};
use crate::{
oracle::{CompositePolyOracle, OracleId},
polynomial::MultilinearComposite,
protocols::{
abstract_sumcheck::{
AbstractSumcheckClaim, AbstractSumcheckProof, AbstractSumcheckReductor,
AbstractSumcheckRound, AbstractSumcheckRoundClaim, AbstractSumcheckWitness,
},
evalcheck::EvalcheckClaim,
},
witness::MultilinearWitness,
};
use binius_field::{Field, PackedField};
use binius_math::evaluate_univariate;
use binius_utils::bail;
use std::fmt::Debug;
/// A zerocheck claim about a composite polynomial oracle.
///
/// The claim asserts that `poly` evaluates to zero at every point of the boolean hypercube over
/// its variables. In terms of the generic sumcheck machinery, this is a sumcheck claim whose
/// claimed sum is always zero.
#[derive(Clone, Debug)]
pub struct ZerocheckClaim<F: Field> {
    /// The composite polynomial oracle that is claimed to vanish on the hypercube.
    pub poly: CompositePolyOracle<F>,
}
impl<F: Field> AbstractSumcheckClaim<F> for ZerocheckClaim<F> {
    /// A zerocheck claim is a sumcheck claim whose sum is fixed at zero.
    fn sum(&self) -> F {
        F::ZERO
    }

    /// Number of variables of the underlying composite polynomial.
    fn n_vars(&self) -> usize {
        self.poly.n_vars()
    }

    /// Maximum individual degree of the underlying composite polynomial.
    fn max_individual_degree(&self) -> usize {
        self.poly.max_individual_degree()
    }
}
impl<F: Field> ZerocheckClaim<F> {
pub fn n_vars(&self) -> usize {
self.poly.n_vars()
}
}
/// Per-round prover message (round polynomial coefficients) for zerocheck.
pub type ZerocheckRound<F> = AbstractSumcheckRound<F>;
/// Full zerocheck proof: the sequence of round messages.
pub type ZerocheckProof<F> = AbstractSumcheckProof<F>;
/// Intermediate verifier claim carried between zerocheck rounds.
pub type ZerocheckRoundClaim<F> = AbstractSumcheckRoundClaim<F>;
/// Zerocheck witness: a multilinear composite over arbitrary multilinear storage `M`.
pub type ZerocheckWitness<P, C, M> = MultilinearComposite<P, C, M>;
/// Zerocheck witness whose multilinears are held as `MultilinearWitness` values.
pub type ZerocheckWitnessTypeErased<'a, P, C> = MultilinearComposite<P, C, MultilinearWitness<'a, P>>;
/// Result of running the zerocheck prover.
#[derive(Debug)]
pub struct ZerocheckProveOutput<F: Field> {
    /// The evaluation claim that the zerocheck reduces to.
    pub evalcheck_claim: EvalcheckClaim<F>,
    /// The transcript of round messages sent by the prover.
    pub zerocheck_proof: ZerocheckProof<F>,
}
/// Verifier-side round reductor for zerocheck.
pub struct ZerocheckReductor<'a, F> {
    /// Expected number of transmitted coefficients per round; the constant term is not sent.
    pub max_individual_degree: usize,
    /// Zerocheck challenges; round `i > 0` uses `alphas[i - 1]`, round 0 uses none.
    pub alphas: &'a [F],
}
impl<'a, F: Field> AbstractSumcheckReductor<F> for ZerocheckReductor<'a, F> {
    type Error = Error;

    /// Checks that a round proof carries exactly `max_individual_degree` coefficients.
    ///
    /// The constant term of each round polynomial is not transmitted. It is recovered during
    /// reduction, so a valid proof has one coefficient fewer than a full polynomial of degree
    /// `max_individual_degree`.
    ///
    /// # Errors
    ///
    /// Returns `VerificationError::NumberOfCoefficients` on a length mismatch.
    fn validate_round_proof_shape(
        &self,
        _round: usize,
        proof: &AbstractSumcheckRound<F>,
    ) -> Result<(), Self::Error> {
        if proof.coeffs.len() != self.max_individual_degree {
            return Err(VerificationError::NumberOfCoefficients {
                expected: self.max_individual_degree,
            }
            .into());
        }
        Ok(())
    }

    /// Reduces the claim for `round` to the claim for the next round, using `challenge`.
    ///
    /// # Errors
    ///
    /// - `Error::RoundArgumentRoundClaimMismatch` if `round` disagrees with the number of
    ///   challenges already absorbed into `claim.partial_point`.
    /// - `VerificationError::ExpectedZerocheckChallengeNotFound` if `round > 0` and `alphas`
    ///   holds no entry for this round. This error is reported by the helper.
    /// - Any error from `reduce_intermediate_round_claim_helper`.
    fn reduce_round_claim(
        &self,
        round: usize,
        claim: AbstractSumcheckRoundClaim<F>,
        challenge: F,
        round_proof: AbstractSumcheckRound<F>,
    ) -> Result<AbstractSumcheckRoundClaim<F>, Self::Error> {
        if round != claim.partial_point.len() {
            bail!(Error::RoundArgumentRoundClaimMismatch);
        }

        // Round 0 uses no zerocheck challenge; round `i > 0` uses `alphas[i - 1]`. An absent
        // challenge (too few alphas) maps to `None`, which the helper turns into a verification
        // error instead of panicking on an out-of-bounds slice index.
        let alpha_i = round
            .checked_sub(1)
            .and_then(|index| self.alphas.get(index).copied());

        reduce_intermediate_round_claim_helper(
            claim,
            challenge,
            round_proof,
            alpha_i,
            self.max_individual_degree,
        )
    }
}
/// Reconstructs the full round polynomial from the transmitted coefficients, checks it against
/// the current claim, and evaluates it at `challenge` to produce the next round's claim.
///
/// The transmitted coefficients omit the constant term. It is recovered as follows:
///
/// - Round 0 (`alpha_i` must be `None`): the constant term is zero. The coefficients must also
///   sum to zero, i.e. `r(0) = r(1) = 0`. This is checked by requiring the linear coefficient
///   to equal the negated sum of the higher-order coefficients.
/// - Later rounds (`alpha_i` must be `Some(alpha)`): the constant term is chosen so that
///   `(1 - alpha) * r(0) + alpha * r(1) == current_round_sum`. This gives
///   `c0 = current_round_sum - alpha * (sum of transmitted coefficients)`.
///
/// # Errors
///
/// - `VerificationError::NumberOfCoefficients` if no coefficients were sent.
/// - `VerificationError::UnexpectedZerocheckChallengeFound` if round 0 receives a challenge.
/// - `Error::RoundPolynomialCheckFailed` if the round-0 coefficients do not sum to zero.
/// - `VerificationError::ExpectedZerocheckChallengeNotFound` if a later round receives none.
fn reduce_intermediate_round_claim_helper<F: Field>(
    claim: ZerocheckRoundClaim<F>,
    challenge: F,
    proof: ZerocheckRound<F>,
    alpha_i: Option<F>,
    degree_bound: usize,
) -> Result<ZerocheckRoundClaim<F>, Error> {
    let ZerocheckRoundClaim {
        mut partial_point,
        current_round_sum,
    } = claim;
    let ZerocheckRound {
        coeffs: mut round_coeffs,
    } = proof;

    // Both branches need at least one transmitted coefficient.
    if round_coeffs.is_empty() {
        return Err(VerificationError::NumberOfCoefficients {
            expected: degree_bound,
        }
        .into());
    }

    let is_first_round = partial_point.is_empty();
    let constant_term = if is_first_round {
        if alpha_i.is_some() {
            return Err(VerificationError::UnexpectedZerocheckChallengeFound.into());
        }
        // r(1) = 0 together with r(0) = 0 forces the linear coefficient to cancel the rest.
        let higher_order_sum: F = round_coeffs[1..].iter().sum();
        if round_coeffs[0] != F::ZERO - higher_order_sum {
            bail!(Error::RoundPolynomialCheckFailed);
        }
        F::ZERO
    } else {
        let alpha = alpha_i.ok_or(VerificationError::ExpectedZerocheckChallengeNotFound)?;
        let transmitted_sum: F = round_coeffs.iter().sum();
        current_round_sum - alpha * transmitted_sum
    };
    round_coeffs.insert(0, constant_term);

    partial_point.push(challenge);
    Ok(ZerocheckRoundClaim {
        partial_point,
        current_round_sum: evaluate_univariate(&round_coeffs, challenge),
    })
}
/// Naively checks that `witness` satisfies `claim` by evaluating the composite polynomial at
/// every point of the boolean hypercube.
///
/// This performs `2^claim.n_vars()` evaluations. It presumably serves as a debugging or testing
/// aid rather than part of the protocol; confirm with callers.
///
/// # Errors
///
/// - `Error::NaiveValidation { index }` for the first hypercube index at which the composite is
///   nonzero.
/// - Any error from fetching the witness multilinears, building the composite, or evaluating
///   it.
pub fn validate_witness<F, PW, W>(claim: &ZerocheckClaim<F>, witness: W) -> Result<(), Error>
where
    F: Field,
    PW: PackedField,
    W: AbstractSumcheckWitness<PW, MultilinearId = OracleId>,
{
    let log_size = claim.n_vars();
    let oracle_ids = claim.poly.inner_polys_oracle_ids().collect::<Vec<_>>();
    let multilinears = witness
        .multilinears(0, oracle_ids.as_slice())?
        .into_iter()
        .map(|(_, multilinear)| multilinear)
        .collect::<Vec<_>>();
    // Distinct name from `witness` to avoid confusing the input witness with the built composite.
    let composite = MultilinearComposite::new(log_size, witness.composition(), multilinears)?;

    for index in 0..(1 << log_size) {
        // Use the `Field::ZERO` constant, consistent with the rest of this module.
        if composite.evaluate_on_hypercube(index)? != PW::Scalar::ZERO {
            bail!(Error::NaiveValidation { index });
        }
    }
    Ok(())
}