use crate::polynomial::{CompositionPoly, MultilinearComposite, MultilinearPoly};
use binius_field::{Field, PackedField};
use binius_math::evaluate_univariate;
use binius_utils::bail;
use crate::protocols::abstract_sumcheck::{
AbstractSumcheckClaim, AbstractSumcheckReductor, AbstractSumcheckRound,
AbstractSumcheckRoundClaim, AbstractSumcheckWitness, Error as AbstractSumcheckError,
};
use super::{Error, VerificationError};
/// A sumcheck claim as consumed by the GKR sumcheck protocol.
///
/// The [`AbstractSumcheckClaim`] implementation for this type reports `n_vars`,
/// `degree` (as the maximum individual degree) and `sum` straight from these fields.
#[derive(Debug, Clone)]
pub struct GkrSumcheckClaim<F: Field> {
/// Number of variables the claim ranges over (reported by `n_vars()`).
pub n_vars: usize,
/// Maximum individual degree of the summed polynomial (reported by `max_individual_degree()`).
pub degree: usize,
/// Claimed value of the sum (reported by `sum()`).
pub sum: F,
/// NOTE(review): presumably the GKR challenge point this claim is bound to (compare
/// `GkrSumcheckReductor::gkr_challenge_point`); it is not read anywhere in this chunk — confirm.
pub r: Vec<F>,
}
/// Exposes a [`GkrSumcheckClaim`] through the generic abstract-sumcheck claim interface.
///
/// Every accessor is a direct read of the corresponding field; nothing is derived.
impl<F: Field> AbstractSumcheckClaim<F> for GkrSumcheckClaim<F> {
    fn n_vars(&self) -> usize {
        let Self { n_vars, .. } = *self;
        n_vars
    }

    fn max_individual_degree(&self) -> usize {
        // The claim stores this under the shorter field name `degree`.
        let Self { degree, .. } = *self;
        degree
    }

    fn sum(&self) -> F {
        let Self { sum, .. } = *self;
        sum
    }
}
/// Prover-side witness for a GKR sumcheck claim.
///
/// The [`AbstractSumcheckWitness`] implementation for this type draws its composition and
/// multilinears from `poly`.
#[derive(Debug, Clone)]
pub struct GkrSumcheckWitness<P, C, M>
where
P: PackedField,
C: CompositionPoly<P>,
M: MultilinearPoly<P> + Clone + Send + Sync,
{
/// The composite polynomial whose sum is being proven; its multilinears are what the
/// abstract-sumcheck witness interface hands out.
pub poly: MultilinearComposite<P, C, M>,
/// NOTE(review): presumably the multilinear for the current GKR layer. It is not yielded by
/// `multilinears()` and is not read elsewhere in this chunk — confirm how callers use it.
pub current_layer: M,
}
/// Exposes a [`GkrSumcheckWitness`] to the generic abstract-sumcheck machinery.
///
/// Only the multilinears inside `poly` are exposed; `current_layer` is not yielded here.
impl<P, C, M> AbstractSumcheckWitness<P> for GkrSumcheckWitness<P, C, M>
where
    P: PackedField,
    C: CompositionPoly<P>,
    M: MultilinearPoly<P> + Clone + Send + Sync,
{
    /// `(seq_id, position)`: the witness sequence id paired with the multilinear's index
    /// inside `poly.multilinears`.
    type MultilinearId = (usize, usize);
    type Composition = C;
    type Multilinear = M;

    fn composition(&self) -> &C {
        let composite = &self.poly;
        &composite.composition
    }

    /// Yields a clone of every multilinear in `poly`, tagged with `(seq_id, position)`.
    ///
    /// `_claim_multilinear_ids` is ignored: ids come purely from enumeration order, and this
    /// never returns an error.
    fn multilinears(
        &self,
        seq_id: usize,
        _claim_multilinear_ids: &[(usize, usize)],
    ) -> Result<impl IntoIterator<Item = ((usize, usize), M)>, AbstractSumcheckError> {
        let tagged = self
            .poly
            .multilinears
            .iter()
            .enumerate()
            .map(move |(position, multilinear)| ((seq_id, position), multilinear.clone()));
        Ok(tagged)
    }
}
/// A single GKR sumcheck round proof; the same type as the abstract sumcheck round proof.
///
/// Its `coeffs` are the round polynomial's coefficients without the constant term, which the
/// verifier reconstructs from the running claim (see `reduce_round_claim_helper`).
pub type GkrSumcheckRound<F> = AbstractSumcheckRound<F>;
/// The verifier's running claim between GKR sumcheck rounds; the same type as the abstract
/// sumcheck round claim (the challenges drawn so far plus the current round sum).
pub type GkrSumcheckRoundClaim<F> = AbstractSumcheckRoundClaim<F>;
/// Verifier-side round reductor for the GKR sumcheck.
///
/// Each round it checks the shape of the round proof, then folds the proof into the running
/// claim using that round's coordinate of the GKR challenge point.
pub struct GkrSumcheckReductor<'a, F> {
/// Exact number of coefficients every round proof must carry. The constant term is omitted,
/// so this equals the round polynomial's degree.
pub max_individual_degree: usize,
/// Challenge point indexed by round number; entry `i` is the `alpha_i` used in round `i`.
/// Must have at least as many entries as there are rounds, otherwise indexing panics.
pub gkr_challenge_point: &'a [F],
}
impl<'a, F: Field> AbstractSumcheckReductor<F> for GkrSumcheckReductor<'a, F> {
type Error = Error;
fn validate_round_proof_shape(
&self,
_round: usize,
proof: &AbstractSumcheckRound<F>,
) -> Result<(), Self::Error> {
if proof.coeffs.len() != self.max_individual_degree {
return Err(VerificationError::NumberOfCoefficients {
expected: self.max_individual_degree,
}
.into());
}
Ok(())
}
fn reduce_round_claim(
&self,
round: usize,
claim: AbstractSumcheckRoundClaim<F>,
challenge: F,
round_proof: AbstractSumcheckRound<F>,
) -> Result<AbstractSumcheckRoundClaim<F>, Self::Error> {
if round != claim.partial_point.len() {
bail!(Error::RoundArgumentRoundClaimMismatch);
}
let alpha_i = self.gkr_challenge_point[round];
reduce_round_claim_helper(claim, challenge, round_proof, alpha_i)
}
}
/// Reduces one GKR sumcheck round: rebuilds the full round polynomial `r(X)` from the prover's
/// non-constant coefficients, then evaluates it at `challenge` to form the next round's sum.
///
/// `proof.coeffs` is treated as `[c_1, c_2, ..., c_d]`, with index `k - 1` holding the `X^k`
/// coefficient. The missing constant term `c_0` is chosen so that
/// `current_round_sum = (1 - alpha_i) * r(0) + alpha_i * r(1)`, because
/// `r(1) - r(0) = c_1 + ... + c_d`. That gives
/// `c_0 = current_round_sum - alpha_i * (c_1 + ... + c_d)`.
///
/// `challenge` is appended to the partial point. This helper never returns `Err`.
fn reduce_round_claim_helper<F: Field>(
    round_claim: GkrSumcheckRoundClaim<F>,
    challenge: F,
    proof: GkrSumcheckRound<F>,
    alpha_i: F,
) -> Result<GkrSumcheckRoundClaim<F>, Error> {
    let GkrSumcheckRoundClaim {
        mut partial_point,
        current_round_sum,
    } = round_claim;
    let higher_coeffs = proof.coeffs;

    let higher_sum: F = higher_coeffs.iter().sum();
    let constant_term = current_round_sum - alpha_i * higher_sum;

    // Full coefficient list in ascending-degree order: [c_0, c_1, ..., c_d].
    let mut full_coeffs = Vec::with_capacity(higher_coeffs.len() + 1);
    full_coeffs.push(constant_term);
    full_coeffs.extend(higher_coeffs);

    let next_round_sum = evaluate_univariate(&full_coeffs, challenge);
    partial_point.push(challenge);

    Ok(GkrSumcheckRoundClaim {
        partial_point,
        current_round_sum: next_round_sum,
    })
}