binius_core/protocols/sumcheck/front_loaded.rs

// Copyright 2024-2025 Irreducible Inc.

3use std::{cmp, cmp::Ordering, collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_math::{evaluate_univariate, CompositionPoly};
7use binius_utils::sorting::is_sorted_ascending;
8use bytes::Buf;
9
10use super::{
11	common::batch_weighted_value,
12	error::{Error, VerificationError},
13	verify_sumcheck::compute_expected_batch_composite_evaluation_single_claim,
14	RoundCoeffs, RoundProof,
15};
16use crate::{
17	fiat_shamir::{CanSample, Challenger},
18	protocols::sumcheck::{BatchSumcheckOutput, SumcheckClaim},
19	transcript::{TranscriptReader, VerifierTranscript},
20};
21
/// Phase of the batch verifier's round state machine.
///
/// The verifier alternates between holding a running sum (`Sum`) and holding the coefficients of
/// a round polynomial that has been received but not yet evaluated at a challenge (`Coeffs`).
#[derive(Debug)]
enum CoeffsOrSums<F: Field> {
	/// Round polynomial coefficients set by `receive_round_proof`, waiting for `finish_round` to
	/// evaluate them at the round challenge.
	Coeffs(RoundCoeffs<F>),
	/// Current running batched sum. Finished claims' weighted evaluations are subtracted from it
	/// in `try_finish_claim`, and `finish` requires it to be zero.
	Sum(F),
}
27
28/// Verifier for a front-loaded batch sumcheck protocol execution.
29///
30/// The sumcheck protocol over can be batched over multiple instances by taking random linear
31/// combinations over the claimed sums and polynomials. When the sumcheck instances are not all
32/// over polynomials with the same number of variables, we can still batch them together.
33///
34/// This version of the protocol is sharing the round challenges of the early rounds across sumcheck
35/// claims with different numbers of variables. In contrast, the
36/// [`crate::protocols::sumcheck::verify_sumcheck`] module implements batches sumcheck sharing
37/// later round challenges. We call this version a "front-loaded" sumcheck.
38///
39/// For each sumcheck claim, we sample one random mixing coefficient. The multiple composite claims
40/// within each claim over a group of multilinears are mixed using the powers of the mixing
41/// coefficient.
42///
43/// This exposes a round-by-round interface so that the protocol call be interleaved with other
44/// interactive protocols, sharing the same sequence of challenges. The verification logic must be
45/// invoked with a specific sequence of calls, continuing for as many rounds as necessary until all
46/// claims are finished.
47///
48/// 1. construct a new verifier with [`BatchVerifier::new`]
49/// 2. call [`BatchVerifier::try_finish_claim`] until it returns `None`
50/// 3. if [`BatchVerifier::remaining_claims`] is 0, call [`BatchVerifier::finish`], otherwise
51///    proceed to step 4
52/// 3. call [`BatchVerifier::receive_round_proof`]
53/// 4. sample a random challenge and call [`BatchVerifier::finish_round`] with it
54/// 5. repeat from step 2
55#[derive(Debug)]
56pub struct BatchVerifier<F: Field, C> {
57	claims: VecDeque<SumcheckClaimWithContext<F, C>>,
58	round: usize,
59	last_coeffs_or_sum: CoeffsOrSums<F>,
60}
61
62impl<F, C> BatchVerifier<F, C>
63where
64	F: TowerField,
65	C: CompositionPoly<F> + Clone,
66{
67	/// Constructs a new verifier for the front-loaded batched sumcheck.
68	///
69	/// The constructor samples batching coefficients from the proof transcript.
70	///
71	/// ## Throws
72	///
73	/// * if the claims are not sorted in non-descending order by number of variables
74	pub fn new<Transcript>(
75		claims: &[SumcheckClaim<F, C>],
76		transcript: &mut Transcript,
77	) -> Result<Self, Error>
78	where
79		Transcript: CanSample<F>,
80	{
81		// Sample batch mixing coefficients
82		let batch_coeffs = transcript.sample_vec(claims.len());
83
84		// Compute the batched sum
85		let sum = iter::zip(claims, &batch_coeffs)
86			.map(|(claim, &batch_coeff)| {
87				batch_weighted_value(
88					batch_coeff,
89					claim
90						.composite_sums()
91						.iter()
92						.map(|composite_claim| composite_claim.sum),
93				)
94			})
95			.sum();
96
97		Self::new_prebatched(batch_coeffs, sum, claims)
98	}
99
100	/// Constructs a new verifier for the front-loaded batched sumcheck with
101	/// specified batching coefficients and a batched claims sum.
102	///
103	/// ## Throws
104	///
105	/// * if the claims are not sorted in non-descending order by number of variables
106	pub fn new_prebatched(
107		batch_coeffs: Vec<F>,
108		sum: F,
109		claims: &[SumcheckClaim<F, C>],
110	) -> Result<Self, Error> {
111		if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
112			return Err(Error::ClaimsOutOfOrder);
113		}
114
115		if batch_coeffs.len() != claims.len() {
116			return Err(Error::IncorrectNumberOfBatchCoeffs);
117		}
118
119		let mut claims = iter::zip(claims.iter().cloned(), batch_coeffs)
120			.map(|(claim, batch_coeff)| {
121				let degree = claim
122					.composite_sums()
123					.iter()
124					.map(|composite_claim| composite_claim.composition.degree())
125					.max()
126					.unwrap_or(0);
127				SumcheckClaimWithContext {
128					claim,
129					batch_coeff,
130					max_degree_remaining: degree,
131				}
132			})
133			.collect::<VecDeque<_>>();
134
135		// Identify the maximum composition degrees
136		for i in (0..claims.len()).rev().skip(1) {
137			claims[i].max_degree_remaining =
138				cmp::max(claims[i].max_degree_remaining, claims[i + 1].max_degree_remaining);
139		}
140
141		Ok(Self {
142			claims,
143			round: 0,
144			last_coeffs_or_sum: CoeffsOrSums::Sum(sum),
145		})
146	}
147
148	/// Returns total number of batched sumcheck rounds
149	pub fn total_rounds(&self) -> usize {
150		self.claims
151			.back()
152			.map(|claim_with_context| claim_with_context.claim.n_vars())
153			.unwrap_or(0)
154	}
155
156	/// Returns the number of sumcheck claims that have not finished.
157	pub fn remaining_claims(&self) -> usize {
158		self.claims.len()
159	}
160
161	/// Processes the next finished sumcheck claim, if all of its rounds are complete.
162	pub fn try_finish_claim<B>(
163		&mut self,
164		transcript: &mut TranscriptReader<B>,
165	) -> Result<Option<Vec<F>>, Error>
166	where
167		B: Buf,
168	{
169		let Some(SumcheckClaimWithContext { claim, .. }) = self.claims.front() else {
170			return Ok(None);
171		};
172		let multilinear_evals = match claim.n_vars().cmp(&self.round) {
173			Ordering::Equal => {
174				let SumcheckClaimWithContext {
175					claim, batch_coeff, ..
176				} = self.claims.pop_front().expect("front returned Some");
177				let multilinear_evals = transcript.read_scalar_slice(claim.n_multilinears())?;
178				match self.last_coeffs_or_sum {
179					CoeffsOrSums::Coeffs(_) => {
180						return Err(Error::ExpectedFinishRound);
181					}
182					CoeffsOrSums::Sum(ref mut sum) => {
183						// Compute the batched multivariate evaluation at the sumcheck point, using
184						// the prover's claimed multilinear evaluations and subtract it from the
185						// running sum. We defer checking the consistency of the multilinear
186						// evaluations until the end of the protocol.
187						*sum -= compute_expected_batch_composite_evaluation_single_claim(
188							batch_coeff,
189							&claim,
190							&multilinear_evals,
191						)?;
192					}
193				}
194				Some(multilinear_evals)
195			}
196			Ordering::Less => {
197				unreachable!(
198					"round is incremented in finish_round; \
199					finish_round does not increment round until receive_round_proof is called; \
200					receive_round_proof errors unless the claim at the active index has enough \
201					variables"
202				);
203			}
204			Ordering::Greater => None,
205		};
206		Ok(multilinear_evals)
207	}
208
209	/// Reads the round message from the proof transcript.
210	pub fn receive_round_proof<B>(
211		&mut self,
212		transcript: &mut TranscriptReader<B>,
213	) -> Result<(), Error>
214	where
215		B: Buf,
216	{
217		match self.last_coeffs_or_sum {
218			CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
219			CoeffsOrSums::Sum(sum) => {
220				let degree = match self.claims.front() {
221					Some(SumcheckClaimWithContext {
222						claim,
223						max_degree_remaining,
224						..
225					}) => {
226						// Must finish all claims that are ready this round before receiving the round proof.
227						if claim.n_vars() == self.round {
228							return Err(Error::ExpectedFinishClaim);
229						}
230						*max_degree_remaining
231					}
232					None => 0,
233				};
234
235				let proof_vals = transcript.read_scalar_slice(degree)?;
236				let round_proof = RoundProof(RoundCoeffs(proof_vals));
237				self.last_coeffs_or_sum = CoeffsOrSums::Coeffs(round_proof.recover(sum));
238				Ok(())
239			}
240		}
241	}
242
243	/// Finishes an interaction round by reducing the instance with a random challenge.
244	pub fn finish_round(&mut self, challenge: F) -> Result<(), Error> {
245		match self.last_coeffs_or_sum {
246			CoeffsOrSums::Coeffs(ref coeffs) => {
247				let sum = evaluate_univariate(&coeffs.0, challenge);
248				self.last_coeffs_or_sum = CoeffsOrSums::Sum(sum);
249				self.round += 1;
250				Ok(())
251			}
252			CoeffsOrSums::Sum(_) => Err(Error::ExpectedReceiveCoeffs),
253		}
254	}
255
256	/// Performs the final sumcheck verification checks, consuming the verifier.
257	pub fn finish(self) -> Result<(), Error> {
258		if !self.claims.is_empty() {
259			return Err(Error::ExpectedFinishRound);
260		}
261
262		match self.last_coeffs_or_sum {
263			CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
264			CoeffsOrSums::Sum(sum) => {
265				if sum != F::ZERO {
266					return Err(VerificationError::IncorrectBatchEvaluation.into());
267				}
268				Ok(())
269			}
270		}
271	}
272
273	/// Verifies a front-loaded batch sumcheck protocol execution.
274	pub fn run<Challenger_>(
275		mut self,
276		transcript: &mut VerifierTranscript<Challenger_>,
277	) -> Result<BatchSumcheckOutput<F>, Error>
278	where
279		Challenger_: Challenger,
280	{
281		let rounds_count = self.total_rounds();
282
283		let mut multilinear_evals = Vec::with_capacity(self.remaining_claims());
284		let mut challenges = Vec::with_capacity(rounds_count);
285
286		for _round_no in 0..rounds_count {
287			let mut reader = transcript.message();
288			while let Some(claim_multilinear_evals) = self.try_finish_claim(&mut reader)? {
289				multilinear_evals.push(claim_multilinear_evals);
290			}
291			self.receive_round_proof(&mut reader)?;
292
293			let challenge = transcript.sample();
294			challenges.push(challenge);
295
296			self.finish_round(challenge)?;
297		}
298
299		let mut reader = transcript.message();
300		while let Some(claim_multilinear_evals) = self.try_finish_claim(&mut reader)? {
301			multilinear_evals.push(claim_multilinear_evals);
302		}
303		self.finish()?;
304
305		Ok(BatchSumcheckOutput {
306			challenges,
307			multilinear_evals,
308		})
309	}
310}
311
/// A sumcheck claim bundled with the per-claim data the batch verifier tracks.
#[derive(Debug)]
struct SumcheckClaimWithContext<F: Field, C> {
	/// The sumcheck claim itself.
	claim: SumcheckClaim<F, C>,
	/// Random coefficient weighting this claim within the batch.
	batch_coeff: F,
	/// Maximum composition degree over this claim and every claim after it in the queue. While
	/// this claim is at the front, this is the number of round-proof scalars read per round.
	max_degree_remaining: usize,
}