binius_core/protocols/sumcheck/front_loaded.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cmp, cmp::Ordering, collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_math::{evaluate_univariate, CompositionPoly};
7use binius_utils::sorting::is_sorted_ascending;
8use bytes::Buf;
9
10use super::{
11	common::batch_weighted_value,
12	error::{Error, VerificationError},
13	verify::compute_expected_batch_composite_evaluation_single_claim,
14	RoundCoeffs, RoundProof,
15};
16use crate::{
17	fiat_shamir::CanSample, protocols::sumcheck::SumcheckClaim, transcript::TranscriptReader,
18};
19
20#[derive(Debug)]
21enum CoeffsOrSums<F: Field> {
22	Coeffs(RoundCoeffs<F>),
23	Sum(F),
24}
25
26/// Verifier for a front-loaded batch sumcheck protocol execution.
27///
28/// The sumcheck protocol over can be batched over multiple instances by taking random linear
29/// combinations over the claimed sums and polynomials. When the sumcheck instances are not all
30/// over polynomials with the same number of variables, we can still batch them together.
31///
32/// This version of the protocols sharing the round challenges of the early rounds across sumcheck
33/// claims with different numbers of variables. In contrast, the
34/// [`crate::protocols::sumcheck::verify`] module implements batches sumcheck sharing later
35/// round challenges. We call this version a "front-loaded" sumcheck.
36///
37/// For each sumcheck claim, we sample one random mixing coefficient. The multiple composite claims
38/// within each claim over a group of multilinears are mixed using the powers of the mixing
39/// coefficient.
40///
41/// This exposes a round-by-round interface so that the protocol call be interleaved with other
42/// interactive protocols, sharing the same sequence of challenges. The verification logic must be
43/// invoked with a specific sequence of calls, continuing for as many rounds as necessary until all
44/// claims are finished.
45///
46/// 1. construct a new verifier with [`BatchVerifier::new`]
47/// 2. call [`BatchVerifier::try_finish_claim`] until it returns `None`
48/// 3. if [`BatchVerifier::remaining_claims`] is 0, call [`BatchVerifier::finish`], otherwise
49///    proceed to step 4
50/// 3. call [`BatchVerifier::receive_round_proof`]
51/// 4. sample a random challenge and call [`BatchVerifier::finish_round`] with it
52/// 5. repeat from step 2
53#[derive(Debug)]
54pub struct BatchVerifier<F: Field, C> {
55	claims: VecDeque<SumcheckClaimWithContext<F, C>>,
56	round: usize,
57	last_coeffs_or_sum: CoeffsOrSums<F>,
58}
59
60impl<F, C> BatchVerifier<F, C>
61where
62	F: TowerField,
63	C: CompositionPoly<F> + Clone,
64{
65	/// Constructs a new verifier for the front-loaded batched sumcheck.
66	///
67	/// The constructor samples batching coefficients from the proof transcript.
68	///
69	/// ## Throws
70	///
71	/// * if the claims are not sorted in ascending order by number of variables
72	pub fn new<Transcript>(
73		claims: &[SumcheckClaim<F, C>],
74		transcript: &mut Transcript,
75	) -> Result<Self, Error>
76	where
77		Transcript: CanSample<F>,
78	{
79		if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
80			return Err(Error::ClaimsOutOfOrder);
81		}
82
83		// Sample batch mixing coefficients
84		let batch_coeffs = transcript.sample_vec(claims.len());
85
86		// Compute the batched sum
87		let sum = iter::zip(claims, &batch_coeffs)
88			.map(|(claim, &batch_coeff)| {
89				batch_weighted_value(
90					batch_coeff,
91					claim
92						.composite_sums()
93						.iter()
94						.map(|composite_claim| composite_claim.sum),
95				)
96			})
97			.sum();
98
99		let mut claims = iter::zip(claims.iter().cloned(), batch_coeffs)
100			.map(|(claim, batch_coeff)| {
101				let degree = claim
102					.composite_sums()
103					.iter()
104					.map(|composite_claim| composite_claim.composition.degree())
105					.max()
106					.unwrap_or(0);
107				SumcheckClaimWithContext {
108					claim,
109					batch_coeff,
110					max_degree_remaining: degree,
111				}
112			})
113			.collect::<VecDeque<_>>();
114
115		// Identify the maximum composition degrees
116		for i in (0..claims.len()).rev().skip(1) {
117			claims[i].max_degree_remaining =
118				cmp::max(claims[i].max_degree_remaining, claims[i + 1].max_degree_remaining);
119		}
120
121		Ok(Self {
122			claims,
123			round: 0,
124			last_coeffs_or_sum: CoeffsOrSums::Sum(sum),
125		})
126	}
127
128	/// Returns the number of sumcheck claims that have not finished.
129	pub fn remaining_claims(&self) -> usize {
130		self.claims.len()
131	}
132
133	/// Processes the next finished sumcheck claim, if all of its rounds are complete.
134	pub fn try_finish_claim<B>(
135		&mut self,
136		transcript: &mut TranscriptReader<B>,
137	) -> Result<Option<Vec<F>>, Error>
138	where
139		B: Buf,
140	{
141		let Some(SumcheckClaimWithContext { claim, .. }) = self.claims.front() else {
142			return Ok(None);
143		};
144		let multilinear_evals = match claim.n_vars().cmp(&self.round) {
145			Ordering::Equal => {
146				let SumcheckClaimWithContext {
147					claim, batch_coeff, ..
148				} = self.claims.pop_front().expect("front returned Some");
149				let multilinear_evals = transcript.read_scalar_slice(claim.n_multilinears())?;
150
151				match self.last_coeffs_or_sum {
152					CoeffsOrSums::Coeffs(_) => {
153						return Err(Error::ExpectedFinishRound);
154					}
155					CoeffsOrSums::Sum(ref mut sum) => {
156						// Compute the batched multivariate evaluation at the sumcheck point, using
157						// the prover's claimed multilinear evaluations and subtract it from the
158						// running sum. We defer checking the consistency of the multilinear
159						// evaluations until the end of the protocol.
160						*sum -= compute_expected_batch_composite_evaluation_single_claim(
161							batch_coeff,
162							&claim,
163							&multilinear_evals,
164						)?;
165					}
166				}
167				Some(multilinear_evals)
168			}
169			Ordering::Less => {
170				unreachable!(
171					"round is incremented in finish_round; \
172					finish_round does not increment round until receive_round_proof is called; \
173					receive_round_proof errors unless the claim at the active index has enough \
174					variables"
175				);
176			}
177			Ordering::Greater => None,
178		};
179		Ok(multilinear_evals)
180	}
181
182	/// Reads the round message from the proof transcript.
183	pub fn receive_round_proof<B>(
184		&mut self,
185		transcript: &mut TranscriptReader<B>,
186	) -> Result<(), Error>
187	where
188		B: Buf,
189	{
190		match self.last_coeffs_or_sum {
191			CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
192			CoeffsOrSums::Sum(sum) => {
193				let degree = match self.claims.front() {
194					Some(SumcheckClaimWithContext {
195						claim,
196						max_degree_remaining,
197						..
198					}) => {
199						// Must finish all claims that are ready this round before receiving the round proof.
200						if claim.n_vars() == self.round {
201							return Err(Error::ExpectedFinishClaim);
202						}
203						*max_degree_remaining
204					}
205					None => 0,
206				};
207
208				let proof_vals = transcript.read_scalar_slice(degree)?;
209				let round_proof = RoundProof(RoundCoeffs(proof_vals));
210				self.last_coeffs_or_sum = CoeffsOrSums::Coeffs(round_proof.recover(sum));
211				Ok(())
212			}
213		}
214	}
215
216	/// Finishes an interaction round by reducing the instance with a random challenge.
217	pub fn finish_round(&mut self, challenge: F) -> Result<(), Error> {
218		match self.last_coeffs_or_sum {
219			CoeffsOrSums::Coeffs(ref coeffs) => {
220				let sum = evaluate_univariate(&coeffs.0, challenge);
221				self.last_coeffs_or_sum = CoeffsOrSums::Sum(sum);
222				self.round += 1;
223				Ok(())
224			}
225			CoeffsOrSums::Sum(_) => Err(Error::ExpectedReceiveCoeffs),
226		}
227	}
228
229	/// Performs the final sumcheck verification checks, consuming the verifier.
230	pub fn finish(self) -> Result<(), Error> {
231		if !self.claims.is_empty() {
232			return Err(Error::ExpectedFinishRound);
233		}
234
235		match self.last_coeffs_or_sum {
236			CoeffsOrSums::Coeffs(_) => Err(Error::ExpectedFinishRound),
237			CoeffsOrSums::Sum(sum) => {
238				if sum != F::ZERO {
239					return Err(VerificationError::IncorrectBatchEvaluation.into());
240				}
241				Ok(())
242			}
243		}
244	}
245}
246
247#[derive(Debug)]
248struct SumcheckClaimWithContext<F: Field, C> {
249	claim: SumcheckClaim<F, C>,
250	batch_coeff: F,
251	max_degree_remaining: usize,
252}