use super::{Error, VerificationError};
use crate::{
oracle::{CompositePolyOracle, OracleId},
polynomial::MultilinearComposite,
protocols::{
abstract_sumcheck::{
AbstractSumcheckClaim, AbstractSumcheckProof, AbstractSumcheckReductor,
AbstractSumcheckRound, AbstractSumcheckRoundClaim, AbstractSumcheckWitness,
},
evalcheck::EvalcheckClaim,
},
};
use binius_field::{Field, PackedField};
use binius_math::evaluate_univariate;
use binius_utils::bail;
/// One round message of the sumcheck protocol (the round polynomial's
/// coefficients, minus the highest-degree one); alias for [`AbstractSumcheckRound`].
pub type SumcheckRound<F> = AbstractSumcheckRound<F>;
/// Sumcheck proof type; alias for [`AbstractSumcheckProof`].
pub type SumcheckProof<F> = AbstractSumcheckProof<F>;
/// Bundle pairing an [`EvalcheckClaim`] with a [`SumcheckProof`].
///
/// NOTE(review): judging by the name, this is what the sumcheck prover returns,
/// with `evalcheck_claim` being the claim left to check after the reduction —
/// confirm against the prover implementation.
#[derive(Debug)]
pub struct SumcheckProveOutput<F: Field> {
/// Evalcheck claim carried alongside the proof.
pub evalcheck_claim: EvalcheckClaim<F>,
/// The sumcheck proof.
pub sumcheck_proof: SumcheckProof<F>,
}
/// A sumcheck claim: the composite polynomial `poly`, summed over every point of
/// the boolean hypercube of its `n_vars` variables, equals `sum`.
///
/// [`validate_witness`] checks this claim naively against a witness.
#[derive(Debug, Clone)]
pub struct SumcheckClaim<F: Field> {
/// Oracle for the composite polynomial being summed.
pub poly: CompositePolyOracle<F>,
/// Claimed value of the hypercube sum.
pub sum: F,
}
impl<F: Field> AbstractSumcheckClaim<F> for SumcheckClaim<F> {
    /// Number of variables of the claimed polynomial; delegates to the oracle.
    fn n_vars(&self) -> usize {
        let Self { poly, .. } = self;
        poly.n_vars()
    }

    /// Maximum individual degree of the claimed polynomial; delegates to the oracle.
    fn max_individual_degree(&self) -> usize {
        let Self { poly, .. } = self;
        poly.max_individual_degree()
    }

    /// The value the polynomial is claimed to sum to over the hypercube.
    fn sum(&self) -> F {
        let Self { sum, .. } = self;
        *sum
    }
}
/// Witness type for a sumcheck claim; alias for [`MultilinearComposite`].
pub type SumcheckWitness<P, C, M> = MultilinearComposite<P, C, M>;
/// Round claim carrying the partial evaluation point and the current round sum;
/// alias for [`AbstractSumcheckRoundClaim`].
pub type SumcheckRoundClaim<F> = AbstractSumcheckRoundClaim<F>;
/// Verifier-side reducer of sumcheck round claims.
pub struct SumcheckReductor {
/// Exact number of coefficients every round message must contain. The prover
/// omits the highest-degree coefficient, which is reconstructed during the
/// reduction.
pub max_individual_degree: usize,
}
impl<F: Field> AbstractSumcheckReductor<F> for SumcheckReductor {
    type Error = Error;

    /// Rejects a round message whose coefficient count differs from
    /// `self.max_individual_degree`.
    ///
    /// The prover leaves out the top coefficient of each round polynomial (the
    /// verifier rebuilds it in `reduce_round_claim`), so exactly
    /// `max_individual_degree` coefficients are expected. The round index is unused.
    ///
    /// # Errors
    /// `VerificationError::NumberOfCoefficients` on a length mismatch.
    fn validate_round_proof_shape(
        &self,
        _round: usize,
        proof: &AbstractSumcheckRound<F>,
    ) -> Result<(), Self::Error> {
        let expected = self.max_individual_degree;
        match proof.coeffs.len() {
            received if received == expected => Ok(()),
            _ => Err(VerificationError::NumberOfCoefficients { expected }.into()),
        }
    }

    /// Folds `challenge` into `claim` using this round's message; the round
    /// index is unused. All the work happens in
    /// `reduce_intermediate_round_claim_helper`.
    fn reduce_round_claim(
        &self,
        _round: usize,
        claim: AbstractSumcheckRoundClaim<F>,
        challenge: F,
        round_proof: AbstractSumcheckRound<F>,
    ) -> Result<AbstractSumcheckRoundClaim<F>, Self::Error> {
        reduce_intermediate_round_claim_helper(claim, challenge, round_proof)
    }
}
/// Reduces one round claim by folding in the verifier's challenge.
///
/// The round message carries the coefficients `a_0, …, a_{d-1}` of the round
/// polynomial `r(X) = Σ_{j=0}^{d} a_j X^j`; the top coefficient `a_d` is not
/// sent. With `s` the current round sum, the verifier requires
/// `s = r(0) + r(1) = a_0 + Σ_{j=0}^{d} a_j`, which pins down
/// `a_d = s - a_0 - Σ_{j=0}^{d-1} a_j`.
///
/// After appending `a_d`, `r` is evaluated at `challenge` to give the next round
/// sum, and `challenge` is appended to the partial evaluation point. The
/// function never returns `Err` in its current form.
///
/// NOTE(review): an empty coefficient list makes `a_0` default to zero and yields
/// the constant polynomial `s`, for which `r(0) + r(1) = 2s`. That satisfies the
/// identity only when `s = 0`. Presumably the shape check in
/// `validate_round_proof_shape` rules this case out — confirm.
fn reduce_intermediate_round_claim_helper<F: Field>(
    claim: SumcheckRoundClaim<F>,
    challenge: F,
    proof: SumcheckRound<F>,
) -> Result<SumcheckRoundClaim<F>, Error> {
    let SumcheckRound {
        coeffs: mut round_coeffs,
    } = proof;
    let SumcheckRoundClaim {
        partial_point: mut point,
        current_round_sum: claimed_sum,
    } = claim;

    // `a_0` is r(0) and also a term of r(1), so it is subtracted twice in total:
    // once on its own and once inside the sum of the sent coefficients.
    let constant_term = match round_coeffs.first() {
        Some(&a0) => a0,
        None => F::ZERO,
    };
    let sent_coeffs_total = round_coeffs.iter().sum::<F>();
    round_coeffs.push(claimed_sum - constant_term - sent_coeffs_total);

    point.push(challenge);
    let next_sum = evaluate_univariate(&round_coeffs, challenge);

    Ok(SumcheckRoundClaim {
        partial_point: point,
        current_round_sum: next_sum,
    })
}
/// Naively checks that `witness` satisfies `claim`.
///
/// The check proceeds as follows:
/// - fetch, from `witness`, the multilinears for the oracles inside `claim.poly`
///   (passing `0` as the first argument of `multilinears`);
/// - build a [`MultilinearComposite`] with the witness composition;
/// - evaluate it at every hypercube index in `0..2^n_vars` and add up the results;
/// - compare the total with `claim.sum()`, converted into the witness scalar type.
///
/// The work grows with `2^n_vars`, so this is presumably meant for tests/debugging.
///
/// # Errors
/// Propagates failures from `multilinears`, `MultilinearComposite::new` and
/// `evaluate_on_hypercube`. Returns `Error::NaiveValidation` when the computed
/// sum differs from the claimed one.
pub fn validate_witness<F, PW, W>(claim: &SumcheckClaim<F>, witness: W) -> Result<(), Error>
where
    F: Field,
    PW: PackedField<Scalar: From<F> + Into<F>>,
    W: AbstractSumcheckWitness<PW, MultilinearId = OracleId>,
{
    let n_vars = claim.n_vars();

    let ids = claim.poly.inner_polys_oracle_ids().collect::<Vec<_>>();
    let mut inner = Vec::new();
    for (_, multilinear) in witness.multilinears(0, ids.as_slice())? {
        inner.push(multilinear);
    }
    let composite = MultilinearComposite::new(n_vars, witness.composition(), inner)?;

    // Stops at the first evaluation error, which is converted into `Error` by `?`.
    let mut total = PW::Scalar::ZERO;
    for index in 0..(1 << n_vars) {
        total += composite.evaluate_on_hypercube(index)?;
    }

    let expected: PW::Scalar = claim.sum().into();
    if total != expected {
        bail!(Error::NaiveValidation)
    }
    Ok(())
}