binius_core/protocols/sumcheck/
front_loaded.rs1use std::{cmp, cmp::Ordering, collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_math::{evaluate_univariate, CompositionPoly};
7use binius_utils::sorting::is_sorted_ascending;
8use bytes::Buf;
9
10use super::{
11 common::batch_weighted_value,
12 error::{Error, VerificationError},
13 verify_sumcheck::compute_expected_batch_composite_evaluation_single_claim,
14 RoundCoeffs, RoundProof,
15};
16use crate::{
17 fiat_shamir::{CanSample, Challenger},
18 protocols::sumcheck::{BatchSumcheckOutput, SumcheckClaim},
19 transcript::{TranscriptReader, VerifierTranscript},
20};
21
22#[derive(Debug)]
23enum CoeffsOrSums<F: Field> {
24 Coeffs(RoundCoeffs<F>),
25 Sum(F),
26}
27
28#[derive(Debug)]
56pub struct BatchVerifier<F: Field, C> {
57 claims: VecDeque<SumcheckClaimWithContext<F, C>>,
58 round: usize,
59 last_coeffs_or_sum: CoeffsOrSums<F>,
60}
61
62impl<F, C> BatchVerifier<F, C>
63where
64 F: TowerField,
65 C: CompositionPoly<F> + Clone,
66{
67 pub fn new<Transcript>(
75 claims: &[SumcheckClaim<F, C>],
76 transcript: &mut Transcript,
77 ) -> Result<Self, Error>
78 where
79 Transcript: CanSample<F>,
80 {
81 let batch_coeffs = transcript.sample_vec(claims.len());
83
84 let sum = iter::zip(claims, &batch_coeffs)
86 .map(|(claim, &batch_coeff)| {
87 batch_weighted_value(
88 batch_coeff,
89 claim
90 .composite_sums()
91 .iter()
92 .map(|composite_claim| composite_claim.sum),
93 )
94 })
95 .sum();
96
97 Self::new_prebatched(batch_coeffs, sum, claims)
98 }
99
100 pub fn new_prebatched(
107 batch_coeffs: Vec<F>,
108 sum: F,
109 claims: &[SumcheckClaim<F, C>],
110 ) -> Result<Self, Error> {
111 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
112 return Err(Error::ClaimsOutOfOrder);
113 }
114
115 if batch_coeffs.len() != claims.len() {
116 return Err(Error::IncorrectNumberOfBatchCoeffs);
117 }
118
119 let mut claims = iter::zip(claims.iter().cloned(), batch_coeffs)
120 .map(|(claim, batch_coeff)| {
121 let degree = claim
122 .composite_sums()
123 .iter()
124 .map(|composite_claim| composite_claim.composition.degree())
125 .max()
126 .unwrap_or(0);
127 SumcheckClaimWithContext {
128 claim,
129 batch_coeff,
130 max_degree_remaining: degree,
131 }
132 })
133 .collect::<VecDeque<_>>();
134
135 for i in (0..claims.len()).rev().skip(1) {
137 claims[i].max_degree_remaining =
138 cmp::max(claims[i].max_degree_remaining, claims[i + 1].max_degree_remaining);
139 }
140
141 Ok(Self {
142 claims,
143 round: 0,
144 last_coeffs_or_sum: CoeffsOrSums::Sum(sum),
145 })
146 }
147
148 pub fn total_rounds(&self) -> usize {
150 self.claims
151 .back()
152 .map(|claim_with_context| claim_with_context.claim.n_vars())
153 .unwrap_or(0)
154 }
155
156 pub fn remaining_claims(&self) -> usize {
158 self.claims.len()
159 }
160
161 pub fn try_finish_claim<B>(
163 &mut self,
164 transcript: &mut TranscriptReader<B>,
165 ) -> Result<Option<Vec<F>>, Error>
166 where
167 B: Buf,
168 {
169 let Some(SumcheckClaimWithContext { claim, .. }) = self.claims.front() else {
170 return Ok(None);
171 };
172 let multilinear_evals = match claim.n_vars().cmp(&self.round) {
173 Ordering::Equal => {
174 let SumcheckClaimWithContext {
175 claim, batch_coeff, ..
176 } = self.claims.pop_front().expect("front returned Some");
177 let multilinear_evals = transcript.read_scalar_slice(claim.n_multilinears())?;
178 match self.last_coeffs_or_sum {
179 CoeffsOrSums::Coeffs(_) => {
180 return Err(Error::ExpectedFinishRound);
181 }
182 CoeffsOrSums::Sum(ref mut sum) => {
183 *sum -= compute_expected_batch_composite_evaluation_single_claim(
188 batch_coeff,
189 &claim,
190 &multilinear_evals,
191 )?;
192 }
193 }
194 Some(multilinear_evals)
195 }
196 Ordering::Less => {
197 unreachable!(
198 "round is incremented in finish_round; \
199 finish_round does not increment round until receive_round_proof is called; \
200 receive_round_proof errors unless the claim at the active index has enough \
201 variables"
202 );
203 }
204 Ordering::Greater => None,
205 };
206 Ok(multilinear_evals)
207 }
208
209 pub fn receive_round_proof<B>(
211 &mut self,
212 transcript: &mut TranscriptReader<B>,
213 ) -> Result<(), Error>
214 where
215 B: Buf,
216 {
217 match self.last_coeffs_or_sum {
218 CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
219 CoeffsOrSums::Sum(sum) => {
220 let degree = match self.claims.front() {
221 Some(SumcheckClaimWithContext {
222 claim,
223 max_degree_remaining,
224 ..
225 }) => {
226 if claim.n_vars() == self.round {
228 return Err(Error::ExpectedFinishClaim);
229 }
230 *max_degree_remaining
231 }
232 None => 0,
233 };
234
235 let proof_vals = transcript.read_scalar_slice(degree)?;
236 let round_proof = RoundProof(RoundCoeffs(proof_vals));
237 self.last_coeffs_or_sum = CoeffsOrSums::Coeffs(round_proof.recover(sum));
238 Ok(())
239 }
240 }
241 }
242
243 pub fn finish_round(&mut self, challenge: F) -> Result<(), Error> {
245 match self.last_coeffs_or_sum {
246 CoeffsOrSums::Coeffs(ref coeffs) => {
247 let sum = evaluate_univariate(&coeffs.0, challenge);
248 self.last_coeffs_or_sum = CoeffsOrSums::Sum(sum);
249 self.round += 1;
250 Ok(())
251 }
252 CoeffsOrSums::Sum(_) => Err(Error::ExpectedReceiveCoeffs),
253 }
254 }
255
256 pub fn finish(self) -> Result<(), Error> {
258 if !self.claims.is_empty() {
259 return Err(Error::ExpectedFinishRound);
260 }
261
262 match self.last_coeffs_or_sum {
263 CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
264 CoeffsOrSums::Sum(sum) => {
265 if sum != F::ZERO {
266 return Err(VerificationError::IncorrectBatchEvaluation.into());
267 }
268 Ok(())
269 }
270 }
271 }
272
273 pub fn run<Challenger_>(
275 mut self,
276 transcript: &mut VerifierTranscript<Challenger_>,
277 ) -> Result<BatchSumcheckOutput<F>, Error>
278 where
279 Challenger_: Challenger,
280 {
281 let rounds_count = self.total_rounds();
282
283 let mut multilinear_evals = Vec::with_capacity(self.remaining_claims());
284 let mut challenges = Vec::with_capacity(rounds_count);
285
286 for _round_no in 0..rounds_count {
287 let mut reader = transcript.message();
288 while let Some(claim_multilinear_evals) = self.try_finish_claim(&mut reader)? {
289 multilinear_evals.push(claim_multilinear_evals);
290 }
291 self.receive_round_proof(&mut reader)?;
292
293 let challenge = transcript.sample();
294 challenges.push(challenge);
295
296 self.finish_round(challenge)?;
297 }
298
299 let mut reader = transcript.message();
300 while let Some(claim_multilinear_evals) = self.try_finish_claim(&mut reader)? {
301 multilinear_evals.push(claim_multilinear_evals);
302 }
303 self.finish()?;
304
305 Ok(BatchSumcheckOutput {
306 challenges,
307 multilinear_evals,
308 })
309 }
310}
311
312#[derive(Debug)]
313struct SumcheckClaimWithContext<F: Field, C> {
314 claim: SumcheckClaim<F, C>,
315 batch_coeff: F,
316 max_degree_remaining: usize,
317}