// binius_core/protocols/sumcheck/front_loaded.rs
use std::{cmp, cmp::Ordering, collections::VecDeque, iter};
use binius_field::{Field, TowerField};
use binius_math::{evaluate_univariate, CompositionPolyOS};
use binius_utils::sorting::is_sorted_ascending;
use super::{
common::batch_weighted_value,
error::{Error, VerificationError},
verify::compute_expected_batch_composite_evaluation_single_claim,
RoundCoeffs, RoundProof,
};
use crate::{fiat_shamir::CanSample, protocols::sumcheck::SumcheckClaim, transcript::CanRead};
/// The value the batch verifier is currently holding between protocol steps.
///
/// The verifier alternates between a claimed batched sum (before a round proof is read)
/// and the coefficients of the current round polynomial (before a challenge is applied).
#[derive(Debug)]
enum CoeffsOrSums<F>
where
    F: Field,
{
    /// Round polynomial coefficients received this round, awaiting the round challenge.
    Coeffs(RoundCoeffs<F>),
    /// The running claimed sum, awaiting the next round proof or a claim to be finished.
    Sum(F),
}
/// Incremental verifier for a batch of sumcheck claims that all start in round 0.
///
/// A claim over `n` variables takes part in rounds `0..n` and is finished, via
/// `try_finish_claim`, once `n` rounds have been completed.
#[derive(Debug)]
pub struct BatchVerifier<F, C>
where
    F: Field,
{
    /// Claims that have not been finished yet, ordered by ascending number of variables.
    claims: VecDeque<SumcheckClaimWithContext<F, C>>,
    /// Number of rounds completed so far (advanced by `finish_round`).
    round: usize,
    /// Either the current claimed sum or the round coefficients awaiting a challenge.
    last_coeffs_or_sum: CoeffsOrSums<F>,
}
impl<F, C> BatchVerifier<F, C>
where
    F: TowerField,
    C: CompositionPolyOS<F> + Clone,
{
    /// Creates a verifier for a batch of sumcheck claims.
    ///
    /// Samples one batching coefficient per claim from `transcript` and combines every
    /// claimed composite sum into a single batched sum, which becomes the initial verifier
    /// state. The expected driving order afterwards is, per round:
    /// `try_finish_claim` (until it returns `Ok(None)`), `receive_round_proof`,
    /// `finish_round`; and finally `finish`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ClaimsOutOfOrder`] if `claims` is not sorted by ascending number of
    /// variables.
    pub fn new<Transcript>(
        claims: &[SumcheckClaim<F, C>],
        transcript: &mut Transcript,
    ) -> Result<Self, Error>
    where
        Transcript: CanSample<F>,
    {
        if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
            return Err(Error::ClaimsOutOfOrder);
        }

        let batch_coeffs = transcript.sample_vec(claims.len());

        let sum = iter::zip(claims, &batch_coeffs)
            .map(|(claim, &batch_coeff)| {
                batch_weighted_value(
                    batch_coeff,
                    claim
                        .composite_sums()
                        .iter()
                        .map(|composite_claim| composite_claim.sum),
                )
            })
            .sum();

        let mut claims = iter::zip(claims.iter().cloned(), batch_coeffs)
            .map(|(claim, batch_coeff)| {
                let degree = claim
                    .composite_sums()
                    .iter()
                    .map(|composite_claim| composite_claim.composition.degree())
                    .max()
                    .unwrap_or(0);
                SumcheckClaimWithContext {
                    claim,
                    batch_coeff,
                    max_degree_remaining: degree,
                }
            })
            .collect::<VecDeque<_>>();

        // Replace each claim's own degree with the maximum degree over itself and every
        // claim behind it. Claims are sorted by ascending `n_vars`, so while a claim sits at
        // the front, all claims behind it are still active and the round proof must be
        // large enough for the largest of their degrees.
        let mut suffix_max = 0;
        for entry in claims.iter_mut().rev() {
            suffix_max = cmp::max(suffix_max, entry.max_degree_remaining);
            entry.max_degree_remaining = suffix_max;
        }

        Ok(Self {
            claims,
            round: 0,
            last_coeffs_or_sum: CoeffsOrSums::Sum(sum),
        })
    }

    /// Returns the number of claims that have not yet been finished by `try_finish_claim`.
    pub fn remaining_claims(&self) -> usize {
        self.claims.len()
    }

    /// Finishes the front claim if it has no variables left in the current round.
    ///
    /// When the front claim's `n_vars` equals the number of completed rounds, this:
    /// - reads the claim's `n_multilinears` evaluations from `transcript`;
    /// - subtracts the claim's expected batch-weighted composite evaluation from the
    ///   running sum;
    /// - removes the claim and returns the evaluations.
    ///
    /// Returns `Ok(None)` if no claims remain or if the front claim still has variables
    /// left. Call it repeatedly to finish several claims that share the same number of
    /// variables.
    ///
    /// # Errors
    ///
    /// - [`Error::ExpectedFinishRound`] if round coefficients have been received but
    ///   `finish_round` has not been called yet. In that case neither the verifier state
    ///   nor the transcript is touched.
    /// - Any error from reading the transcript or from evaluating the claim's
    ///   compositions. The claim is only removed once both steps succeed.
    ///
    /// # Panics
    ///
    /// Panics if the front claim has fewer variables than the current round. The other
    /// methods prevent this state from being reached.
    pub fn try_finish_claim<Transcript>(
        &mut self,
        transcript: &mut Transcript,
    ) -> Result<Option<Vec<F>>, Error>
    where
        Transcript: CanRead,
    {
        let Some(SumcheckClaimWithContext {
            claim, batch_coeff, ..
        }) = self.claims.front()
        else {
            return Ok(None);
        };

        match claim.n_vars().cmp(&self.round) {
            // The front claim still has variables to reduce; nothing to finish yet.
            Ordering::Greater => return Ok(None),
            Ordering::Less => unreachable!(
                "round is incremented in finish_round; \
                 finish_round does not increment round until receive_round_proof is called; \
                 receive_round_proof errors unless the claim at the active index has enough \
                 variables"
            ),
            Ordering::Equal => {}
        }

        // Check the protocol state before reading from the transcript or removing the
        // claim, so a misordered call leaves the verifier intact.
        let CoeffsOrSums::Sum(ref mut sum) = self.last_coeffs_or_sum else {
            return Err(Error::ExpectedFinishRound);
        };

        let multilinear_evals = transcript.read_scalar_slice(claim.n_multilinears())?;
        *sum -= compute_expected_batch_composite_evaluation_single_claim(
            *batch_coeff,
            claim,
            &multilinear_evals,
        )?;

        // Only drop the claim once its contribution has been fully accounted for.
        self.claims.pop_front();
        Ok(Some(multilinear_evals))
    }

    /// Reads the prover's message for the current round.
    ///
    /// The message consists of `max_degree_remaining` scalars for the front claim, or none
    /// if no claims remain. The full round coefficients are then reconstructed with
    /// `RoundProof::recover`, using the current claimed sum.
    ///
    /// # Errors
    ///
    /// - [`Error::ExpectedFinishClaim`] if the front claim has no variables left and must
    ///   be finished with `try_finish_claim` first.
    /// - [`Error::ExpectedFinishRound`] if coefficients for this round were already
    ///   received.
    /// - Any error from reading the transcript.
    pub fn receive_round_proof<Transcript>(
        &mut self,
        transcript: &mut Transcript,
    ) -> Result<(), Error>
    where
        Transcript: CanRead,
    {
        let degree = match self.claims.front() {
            Some(SumcheckClaimWithContext {
                claim,
                max_degree_remaining,
                ..
            }) => {
                // A claim that has run out of variables must be finished before the next round.
                if claim.n_vars() == self.round {
                    return Err(Error::ExpectedFinishClaim);
                }
                *max_degree_remaining
            }
            None => 0,
        };

        match self.last_coeffs_or_sum {
            CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
            CoeffsOrSums::Sum(sum) => {
                let proof_vals = transcript.read_scalar_slice(degree)?;
                let round_proof = RoundProof(RoundCoeffs(proof_vals));
                self.last_coeffs_or_sum = CoeffsOrSums::Coeffs(round_proof.recover(sum));
                Ok(())
            }
        }
    }

    /// Completes the current round with the verifier's `challenge`.
    ///
    /// Evaluates the received round polynomial at `challenge`, stores the result as the
    /// new claimed sum, and advances the round counter.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ExpectedReceiveCoeffs`] if no round coefficients have been
    /// received since the last round was finished.
    pub fn finish_round(&mut self, challenge: F) -> Result<(), Error> {
        match self.last_coeffs_or_sum {
            CoeffsOrSums::Coeffs(ref coeffs) => {
                let sum = evaluate_univariate(&coeffs.0, challenge);
                self.last_coeffs_or_sum = CoeffsOrSums::Sum(sum);
                self.round += 1;
                Ok(())
            }
            CoeffsOrSums::Sum(_) => Err(Error::ExpectedReceiveCoeffs),
        }
    }

    /// Consumes the verifier and performs the final check.
    ///
    /// After every claim has been finished, each claim's expected evaluation has been
    /// subtracted from the running sum, so the remaining sum must be exactly zero.
    ///
    /// # Errors
    ///
    /// - [`Error::ExpectedFinishRound`] if claims remain or round coefficients are still
    ///   pending.
    /// - [`VerificationError::IncorrectBatchEvaluation`] if the remaining sum is non-zero.
    pub fn finish(self) -> Result<(), Error> {
        if !self.claims.is_empty() {
            return Err(Error::ExpectedFinishRound);
        }

        match self.last_coeffs_or_sum {
            CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
            CoeffsOrSums::Sum(sum) => {
                if sum != F::ZERO {
                    return Err(VerificationError::IncorrectBatchEvaluation.into());
                }
                Ok(())
            }
        }
    }
}
/// A sumcheck claim bundled with the per-claim bookkeeping used by the batch verifier.
#[derive(Debug)]
struct SumcheckClaimWithContext<F, C>
where
    F: Field,
{
    /// The claim being verified.
    claim: SumcheckClaim<F, C>,
    /// Transcript-sampled coefficient weighting this claim within the batch.
    batch_coeff: F,
    /// Largest composition degree over this claim and every claim queued after it.
    max_degree_remaining: usize,
}