binius_core/protocols/sumcheck/
verify_sumcheck.rs1use binius_field::{Field, TowerField};
4use binius_math::{evaluate_univariate, CompositionPoly, EvaluationOrder};
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use itertools::izip;
7
8use super::{
9 common::{batch_weighted_value, BatchSumcheckOutput, RoundProof, SumcheckClaim},
10 error::{Error, VerificationError},
11 RoundCoeffs,
12};
13use crate::{
14 fiat_shamir::{CanSample, Challenger},
15 transcript::VerifierTranscript,
16};
17
18pub fn batch_verify<F, Composition, Challenger_>(
31 evaluation_order: EvaluationOrder,
32 claims: &[SumcheckClaim<F, Composition>],
33 transcript: &mut VerifierTranscript<Challenger_>,
34) -> Result<BatchSumcheckOutput<F>, Error>
35where
36 F: TowerField,
37 Composition: CompositionPoly<F>,
38 Challenger_: Challenger,
39{
40 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
42 bail!(Error::ClaimsOutOfOrder);
43 }
44
45 let n_rounds = claims.iter().map(|claim| claim.n_vars()).max().unwrap_or(0);
46
47 let mut active_index = 0;
50 let mut batch_coeffs = Vec::with_capacity(claims.len());
51 let mut challenges = Vec::with_capacity(n_rounds);
52 let mut sum = F::ZERO;
53 let mut max_degree = 0; for round_no in 0..n_rounds {
55 let n_vars = n_rounds - round_no;
56
57 while let Some(claim) = claims.get(active_index) {
58 if claim.n_vars() != n_vars {
59 break;
60 }
61
62 let next_batch_coeff = transcript.sample();
63 batch_coeffs.push(next_batch_coeff);
64
65 sum += batch_weighted_value(
67 next_batch_coeff,
68 claim
69 .composite_sums()
70 .iter()
71 .map(|inner_claim| inner_claim.sum),
72 );
73 max_degree = max_degree.max(claim.max_individual_degree());
74 active_index += 1;
75 }
76
77 let coeffs = transcript.message().read_scalar_slice(max_degree)?;
78 let round_proof = RoundProof(RoundCoeffs(coeffs));
79
80 let challenge = transcript.sample();
81 challenges.push(challenge);
82
83 sum = interpolate_round_proof(round_proof, sum, challenge);
84 }
85
86 while let Some(claim) = claims.get(active_index) {
88 debug_assert_eq!(claim.n_vars(), 0);
89
90 let next_batch_coeff = transcript.sample();
91 batch_coeffs.push(next_batch_coeff);
92
93 sum += batch_weighted_value(
95 next_batch_coeff,
96 claim
97 .composite_sums()
98 .iter()
99 .map(|inner_claim| inner_claim.sum),
100 );
101 active_index += 1;
102 }
103
104 let mut multilinear_evals = Vec::with_capacity(claims.len());
105 let mut reader = transcript.message();
106 for claim in claims {
107 let evals = reader.read_scalar_slice::<F>(claim.n_multilinears())?;
108 multilinear_evals.push(evals);
109 }
110
111 let expected_sum = compute_expected_batch_composite_evaluation_multi_claim(
112 batch_coeffs,
113 claims,
114 &multilinear_evals,
115 )?;
116
117 if sum != expected_sum {
118 return Err(VerificationError::IncorrectBatchEvaluation.into());
119 }
120
121 if EvaluationOrder::HighToLow == evaluation_order {
122 challenges.reverse();
123 }
124
125 Ok(BatchSumcheckOutput {
126 challenges,
127 multilinear_evals,
128 })
129}
130
131pub fn compute_expected_batch_composite_evaluation_single_claim<F: Field, Composition>(
132 batch_coeff: F,
133 claim: &SumcheckClaim<F, Composition>,
134 multilinear_evals: &[F],
135) -> Result<F, Error>
136where
137 Composition: CompositionPoly<F>,
138{
139 let composite_evals = claim
140 .composite_sums()
141 .iter()
142 .map(|sum_claim| sum_claim.composition.evaluate(multilinear_evals))
143 .collect::<Result<Vec<_>, _>>()?;
144 Ok(batch_weighted_value(batch_coeff, composite_evals.into_iter()))
145}
146
147fn compute_expected_batch_composite_evaluation_multi_claim<F: Field, Composition>(
148 batch_coeffs: Vec<F>,
149 claims: &[SumcheckClaim<F, Composition>],
150 multilinear_evals: &[Vec<F>],
151) -> Result<F, Error>
152where
153 Composition: CompositionPoly<F>,
154{
155 izip!(batch_coeffs, claims, multilinear_evals.iter())
156 .map(|(batch_coeff, claim, multilinear_evals)| {
157 compute_expected_batch_composite_evaluation_single_claim(
158 batch_coeff,
159 claim,
160 multilinear_evals,
161 )
162 })
163 .try_fold(F::ZERO, |sum, term| Ok(sum + term?))
164}
165
166pub fn interpolate_round_proof<F: Field>(round_proof: RoundProof<F>, sum: F, challenge: F) -> F {
167 let coeffs = round_proof.recover(sum);
168 evaluate_univariate(&coeffs.0, challenge)
169}