// binius_core/protocols/sumcheck/verify.rs
use binius_field::{Field, TowerField};
use binius_math::{evaluate_univariate, CompositionPolyOS};
use binius_utils::{bail, sorting::is_sorted_ascending};
use itertools::izip;
use tracing::instrument;
use super::{
common::{batch_weighted_value, BatchSumcheckOutput, RoundProof, SumcheckClaim},
error::{Error, VerificationError},
RoundCoeffs,
};
use crate::{fiat_shamir::CanSample, transcript::CanRead};
/// Verifies a batch of sumcheck claims from scratch.
///
/// Equivalent to calling [`batch_verify_with_start`] with no prebatched coefficients, a zero
/// initial sum, a zero initial degree bound and no skipped rounds.
///
/// # Errors
///
/// Propagates every error returned by [`batch_verify_with_start`].
pub fn batch_verify<F, Composition, Transcript>(
    claims: &[SumcheckClaim<F, Composition>],
    transcript: &mut Transcript,
) -> Result<BatchSumcheckOutput<F>, Error>
where
    F: TowerField,
    Composition: CompositionPolyOS<F>,
    Transcript: CanSample<F> + CanRead,
{
    batch_verify_with_start(
        BatchVerifyStart {
            batch_coeffs: Vec::new(),
            sum: F::ZERO,
            max_degree: 0,
            skip_rounds: 0,
        },
        claims,
        transcript,
    )
}
/// Initial state handed to [`batch_verify_with_start`].
///
/// It lets sumcheck verification resume after some claims have already been batched and some
/// leading rounds have already been handled by the caller.
#[derive(Debug)]
pub struct BatchVerifyStart<F>
where
    F: Field,
{
    /// Batching coefficients already sampled for the leading claims, in claim order.
    ///
    /// Must not contain more entries than there are claims.
    pub batch_coeffs: Vec<F>,
    /// Running batched sum that verification continues from.
    ///
    /// Claims batched in later add their coefficient-weighted sums to it.
    pub sum: F,
    /// Current degree bound for the round polynomials.
    ///
    /// It is raised to each newly batched claim's `max_individual_degree`. It also sets how
    /// many scalars are read from the transcript for each round proof.
    pub max_degree: usize,
    /// Number of leading rounds the caller has already handled.
    ///
    /// Verification starts at this round index. Must not exceed the largest `n_vars` among
    /// the claims.
    pub skip_rounds: usize,
}
/// Verifies a batch of sumcheck claims, resuming from a partially batched starting state.
///
/// The claims must be sorted by descending `n_vars`. The first `start.batch_coeffs.len()`
/// claims are treated as already batched. Every other claim receives a freshly sampled
/// batching coefficient in the round where the number of remaining variables equals its
/// `n_vars`. Zero-variate claims are batched in after the last round.
///
/// Rounds `start.skip_rounds..n_rounds` are run, where `n_rounds` is the largest `n_vars`
/// among the claims. After the rounds, the per-claim multilinear evaluations are read from the
/// transcript. The reduced sum is then compared against the batched composite evaluation at
/// those points.
///
/// # Errors
///
/// * [`Error::ClaimsOutOfOrder`] if the claims are not sorted by descending `n_vars`.
/// * [`Error::TooManyPrebatchedCoeffs`] if more batching coefficients were supplied than
///   there are claims.
/// * [`VerificationError::IncorrectSkippedRoundsCount`] in either of two cases:
///   `skip_rounds` exceeds `n_rounds`, or a claim that was not prebatched has more variables
///   than the number of rounds left to run, so no round could ever batch it in.
/// * [`VerificationError::IncorrectBatchEvaluation`] if the final reduced sum does not match
///   the expected batched evaluation.
/// * Any error raised while reading from the transcript or evaluating a composition.
#[instrument(skip_all, level = "debug")]
pub fn batch_verify_with_start<F, Composition, Transcript>(
    start: BatchVerifyStart<F>,
    claims: &[SumcheckClaim<F, Composition>],
    transcript: &mut Transcript,
) -> Result<BatchSumcheckOutput<F>, Error>
where
    F: TowerField,
    Composition: CompositionPolyOS<F>,
    Transcript: CanSample<F> + CanRead,
{
    let BatchVerifyStart {
        mut batch_coeffs,
        mut sum,
        mut max_degree,
        skip_rounds,
    } = start;

    // Checking the reversed sequence for ascending order means the claims are descending.
    if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
        bail!(Error::ClaimsOutOfOrder);
    }

    if batch_coeffs.len() > claims.len() {
        bail!(Error::TooManyPrebatchedCoeffs);
    }

    let n_rounds = claims.iter().map(|claim| claim.n_vars()).max().unwrap_or(0);

    if skip_rounds > n_rounds {
        return Err(VerificationError::IncorrectSkippedRoundsCount.into());
    }
    let n_remaining_rounds = n_rounds - skip_rounds;

    // Because claims are sorted, the first non-prebatched claim is the largest one still to
    // be batched in. Rounds only visit `n_vars` values in `1..=n_remaining_rounds`. A larger
    // claim would never be activated by a round and would instead be folded in after the
    // rounds with none of its variables reduced. Reject that explicitly; the trailing
    // `debug_assert!` below is compiled out of release builds.
    if let Some(first_unbatched) = claims.get(batch_coeffs.len()) {
        if first_unbatched.n_vars() > n_remaining_rounds {
            return Err(VerificationError::IncorrectSkippedRoundsCount.into());
        }
    }

    let mut active_index = batch_coeffs.len();
    let mut challenges = Vec::with_capacity(n_remaining_rounds);
    for round_no in skip_rounds..n_rounds {
        let n_vars = n_rounds - round_no;

        // Activate every claim whose variable count matches this round.
        while let Some(claim) = claims.get(active_index) {
            if claim.n_vars() != n_vars {
                break;
            }
            batch_in_claim(claim, transcript, &mut batch_coeffs, &mut sum);
            max_degree = max_degree.max(claim.max_individual_degree());
            active_index += 1;
        }

        let coeffs = transcript.read_scalar_slice(max_degree)?;
        let round_proof = RoundProof(RoundCoeffs(coeffs));

        let challenge = transcript.sample();
        challenges.push(challenge);

        sum = interpolate_round_proof(round_proof, sum, challenge);
    }

    // Any claims left now are 0-variate (constant). The check above guarantees this.
    while let Some(claim) = claims.get(active_index) {
        debug_assert_eq!(claim.n_vars(), 0);
        batch_in_claim(claim, transcript, &mut batch_coeffs, &mut sum);
        active_index += 1;
    }

    let mut multilinear_evals = Vec::with_capacity(claims.len());
    for claim in claims.iter() {
        let evals = transcript.read_scalar_slice::<F>(claim.n_multilinears())?;
        multilinear_evals.push(evals);
    }

    let expected_sum = compute_expected_batch_composite_evaluation_multi_claim(
        batch_coeffs,
        claims,
        &multilinear_evals,
    )?;
    if sum != expected_sum {
        return Err(VerificationError::IncorrectBatchEvaluation.into());
    }

    Ok(BatchSumcheckOutput {
        challenges,
        multilinear_evals,
    })
}

/// Batches one claim into the running batched sum.
///
/// Samples a fresh batching coefficient for `claim` from the transcript and appends it to
/// `batch_coeffs`. Then adds the claim's coefficient-weighted composite sums to `sum`.
fn batch_in_claim<F, Composition, Transcript>(
    claim: &SumcheckClaim<F, Composition>,
    transcript: &mut Transcript,
    batch_coeffs: &mut Vec<F>,
    sum: &mut F,
) where
    F: TowerField,
    Composition: CompositionPolyOS<F>,
    Transcript: CanSample<F>,
{
    let next_batch_coeff = transcript.sample();
    batch_coeffs.push(next_batch_coeff);
    *sum += batch_weighted_value(
        next_batch_coeff,
        claim
            .composite_sums()
            .iter()
            .map(|inner_claim| inner_claim.sum),
    );
}
/// Computes the expected batched composite evaluation for a single claim.
///
/// Each composition of `claim` is evaluated at `multilinear_evals`. The results are combined
/// with [`batch_weighted_value`] under `batch_coeff`.
///
/// # Errors
///
/// Returns the first error produced by a composition's `evaluate`.
pub fn compute_expected_batch_composite_evaluation_single_claim<F: Field, Composition>(
    batch_coeff: F,
    claim: &SumcheckClaim<F, Composition>,
    multilinear_evals: &[F],
) -> Result<F, Error>
where
    Composition: CompositionPolyOS<F>,
{
    let sum_claims = claim.composite_sums();
    let mut evaluations = Vec::with_capacity(sum_claims.len());
    for sum_claim in sum_claims.iter() {
        let value = sum_claim.composition.evaluate(multilinear_evals)?;
        evaluations.push(value);
    }
    Ok(batch_weighted_value(batch_coeff, evaluations.into_iter()))
}
/// Sums the expected batched composite evaluations over all claims.
///
/// The `i`-th claim is weighted by `batch_coeffs[i]` and evaluated at `multilinear_evals[i]`.
/// Iteration stops at the shortest of the three inputs.
///
/// # Errors
///
/// Returns the first error produced while evaluating any claim's compositions.
fn compute_expected_batch_composite_evaluation_multi_claim<F: Field, Composition>(
    batch_coeffs: Vec<F>,
    claims: &[SumcheckClaim<F, Composition>],
    multilinear_evals: &[Vec<F>],
) -> Result<F, Error>
where
    Composition: CompositionPolyOS<F>,
{
    let mut total = F::ZERO;
    let per_claim = batch_coeffs
        .into_iter()
        .zip(claims)
        .zip(multilinear_evals);
    for ((batch_coeff, claim), evals) in per_claim {
        total += compute_expected_batch_composite_evaluation_single_claim(
            batch_coeff,
            claim,
            evals,
        )?;
    }
    Ok(total)
}
pub fn interpolate_round_proof<F: Field>(round_proof: RoundProof<F>, sum: F, challenge: F) -> F {
let coeffs = round_proof.recover(sum);
evaluate_univariate(&coeffs.0, challenge)
}