binius_core/protocols/sumcheck/
front_loaded.rs1use std::{cmp, cmp::Ordering, collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_math::{CompositionPoly, evaluate_univariate};
7use binius_utils::sorting::is_sorted_ascending;
8use bytes::Buf;
9
10use super::{
11 RoundCoeffs, RoundProof,
12 common::batch_weighted_value,
13 error::{Error, VerificationError},
14 verify_sumcheck::compute_expected_batch_composite_evaluation_single_claim,
15};
16use crate::{
17 fiat_shamir::{CanSample, Challenger},
18 protocols::sumcheck::{BatchSumcheckOutput, SumcheckClaim},
19 transcript::{TranscriptReader, VerifierTranscript},
20};
21
22#[derive(Debug)]
23enum CoeffsOrSums<F: Field> {
24 Coeffs(RoundCoeffs<F>),
25 Sum(F),
26}
27
28#[derive(Debug)]
56pub struct BatchVerifier<F: Field, C> {
57 claims: VecDeque<SumcheckClaimWithContext<F, C>>,
58 round: usize,
59 last_coeffs_or_sum: CoeffsOrSums<F>,
60}
61
62impl<F, C> BatchVerifier<F, C>
63where
64 F: TowerField,
65 C: CompositionPoly<F> + Clone,
66{
67 pub fn new<Transcript>(
75 claims: &[SumcheckClaim<F, C>],
76 transcript: &mut Transcript,
77 ) -> Result<Self, Error>
78 where
79 Transcript: CanSample<F>,
80 {
81 let batch_coeffs = transcript.sample_vec(claims.len());
83
84 let sum = iter::zip(claims, &batch_coeffs)
86 .map(|(claim, &batch_coeff)| {
87 batch_weighted_value(
88 batch_coeff,
89 claim
90 .composite_sums()
91 .iter()
92 .map(|composite_claim| composite_claim.sum),
93 )
94 })
95 .sum();
96
97 Self::new_prebatched(batch_coeffs, sum, claims)
98 }
99
100 pub fn new_prebatched(
107 batch_coeffs: Vec<F>,
108 sum: F,
109 claims: &[SumcheckClaim<F, C>],
110 ) -> Result<Self, Error> {
111 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
112 return Err(Error::ClaimsOutOfOrder);
113 }
114
115 if batch_coeffs.len() != claims.len() {
116 return Err(Error::IncorrectNumberOfBatchCoeffs);
117 }
118
119 let mut claims = iter::zip(claims.iter().cloned(), batch_coeffs)
120 .map(|(claim, batch_coeff)| {
121 let degree = claim
122 .composite_sums()
123 .iter()
124 .map(|composite_claim| composite_claim.composition.degree())
125 .max()
126 .unwrap_or(0);
127 SumcheckClaimWithContext {
128 claim,
129 batch_coeff,
130 max_degree_remaining: degree,
131 }
132 })
133 .collect::<VecDeque<_>>();
134
135 for i in (0..claims.len()).rev().skip(1) {
137 claims[i].max_degree_remaining =
138 cmp::max(claims[i].max_degree_remaining, claims[i + 1].max_degree_remaining);
139 }
140
141 Ok(Self {
142 claims,
143 round: 0,
144 last_coeffs_or_sum: CoeffsOrSums::Sum(sum),
145 })
146 }
147
148 pub fn total_rounds(&self) -> usize {
150 self.claims
151 .back()
152 .map(|claim_with_context| claim_with_context.claim.n_vars())
153 .unwrap_or(0)
154 }
155
156 pub fn remaining_claims(&self) -> usize {
158 self.claims.len()
159 }
160
161 pub fn try_finish_claim<B>(
163 &mut self,
164 transcript: &mut TranscriptReader<B>,
165 ) -> Result<Option<Vec<F>>, Error>
166 where
167 B: Buf,
168 {
169 let Some(SumcheckClaimWithContext { claim, .. }) = self.claims.front() else {
170 return Ok(None);
171 };
172 let multilinear_evals = match claim.n_vars().cmp(&self.round) {
173 Ordering::Equal => {
174 let SumcheckClaimWithContext {
175 claim, batch_coeff, ..
176 } = self.claims.pop_front().expect("front returned Some");
177 let multilinear_evals = transcript.read_scalar_slice(claim.n_multilinears())?;
178 match self.last_coeffs_or_sum {
179 CoeffsOrSums::Coeffs(_) => {
180 return Err(Error::ExpectedFinishRound);
181 }
182 CoeffsOrSums::Sum(ref mut sum) => {
183 *sum -= compute_expected_batch_composite_evaluation_single_claim(
188 batch_coeff,
189 &claim,
190 &multilinear_evals,
191 )?;
192 }
193 }
194 Some(multilinear_evals)
195 }
196 Ordering::Less => {
197 unreachable!(
198 "round is incremented in finish_round; \
199 finish_round does not increment round until receive_round_proof is called; \
200 receive_round_proof errors unless the claim at the active index has enough \
201 variables"
202 );
203 }
204 Ordering::Greater => None,
205 };
206 Ok(multilinear_evals)
207 }
208
209 pub fn receive_round_proof<B>(
211 &mut self,
212 transcript: &mut TranscriptReader<B>,
213 ) -> Result<(), Error>
214 where
215 B: Buf,
216 {
217 match self.last_coeffs_or_sum {
218 CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
219 CoeffsOrSums::Sum(sum) => {
220 let degree = match self.claims.front() {
221 Some(SumcheckClaimWithContext {
222 claim,
223 max_degree_remaining,
224 ..
225 }) => {
226 if claim.n_vars() == self.round {
229 return Err(Error::ExpectedFinishClaim);
230 }
231 *max_degree_remaining
232 }
233 None => 0,
234 };
235
236 let proof_vals = transcript.read_scalar_slice(degree)?;
237 let round_proof = RoundProof(RoundCoeffs(proof_vals));
238 self.last_coeffs_or_sum = CoeffsOrSums::Coeffs(round_proof.recover(sum));
239 Ok(())
240 }
241 }
242 }
243
244 pub fn finish_round(&mut self, challenge: F) -> Result<(), Error> {
246 match self.last_coeffs_or_sum {
247 CoeffsOrSums::Coeffs(ref coeffs) => {
248 let sum = evaluate_univariate(&coeffs.0, challenge);
249 self.last_coeffs_or_sum = CoeffsOrSums::Sum(sum);
250 self.round += 1;
251 Ok(())
252 }
253 CoeffsOrSums::Sum(_) => Err(Error::ExpectedReceiveCoeffs),
254 }
255 }
256
257 pub fn finish(self) -> Result<(), Error> {
259 if !self.claims.is_empty() {
260 return Err(Error::ExpectedFinishRound);
261 }
262
263 match self.last_coeffs_or_sum {
264 CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
265 CoeffsOrSums::Sum(sum) => {
266 if sum != F::ZERO {
267 return Err(VerificationError::IncorrectBatchEvaluation.into());
268 }
269 Ok(())
270 }
271 }
272 }
273
274 pub fn run<Challenger_>(
276 mut self,
277 transcript: &mut VerifierTranscript<Challenger_>,
278 ) -> Result<BatchSumcheckOutput<F>, Error>
279 where
280 Challenger_: Challenger,
281 {
282 let rounds_count = self.total_rounds();
283
284 let mut multilinear_evals = Vec::with_capacity(self.remaining_claims());
285 let mut challenges = Vec::with_capacity(rounds_count);
286
287 for _round_no in 0..rounds_count {
288 let mut reader = transcript.message();
289 while let Some(claim_multilinear_evals) = self.try_finish_claim(&mut reader)? {
290 multilinear_evals.push(claim_multilinear_evals);
291 }
292 self.receive_round_proof(&mut reader)?;
293
294 let challenge = transcript.sample();
295 challenges.push(challenge);
296
297 self.finish_round(challenge)?;
298 }
299
300 let mut reader = transcript.message();
301 while let Some(claim_multilinear_evals) = self.try_finish_claim(&mut reader)? {
302 multilinear_evals.push(claim_multilinear_evals);
303 }
304 self.finish()?;
305
306 Ok(BatchSumcheckOutput {
307 challenges,
308 multilinear_evals,
309 })
310 }
311}
312
313#[derive(Debug)]
314struct SumcheckClaimWithContext<F: Field, C> {
315 claim: SumcheckClaim<F, C>,
316 batch_coeff: F,
317 max_degree_remaining: usize,
318}