binius_core/protocols/sumcheck/prove/
batch_sumcheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use itertools::izip;
7use tracing::instrument;
8
9use crate::{
10	fiat_shamir::{CanSample, Challenger},
11	protocols::sumcheck::{
12		common::{BatchSumcheckOutput, RoundCoeffs},
13		error::Error,
14	},
15	transcript::ProverTranscript,
16};
17
18/// A sumcheck prover with a round-by-round execution interface.
19///
20/// Sumcheck prover logic is accessed via a trait because important optimizations are available
21/// depending on the structure of the multivariate polynomial that the protocol targets. For
22/// example, [Gruen24] observes a significant optimization available to the sumcheck prover when
23/// the multivariate is the product of a multilinear composite and an equality indicator
24/// polynomial, which arises in the zerocheck protocol.
25///
26/// The trait exposes a round-by-round interface so that protocol execution logic that drives the
27/// prover can interleave the executions of the interactive protocol, for example in the case of
28/// batching several sumcheck protocols.
29///
30/// The caller must make a specific sequence of calls to the provers. For a prover where
31/// [`Self::n_vars`] is $n$, the caller must call [`Self::execute`] and then [`Self::fold`] $n$
32/// times, and finally call [`Self::finish`]. If the calls aren't made in that order, the caller
33/// will get an error result.
34///
35/// This trait is object-safe.
36///
37/// [Gruen24]: <https://eprint.iacr.org/2024/108>
38pub trait SumcheckProver<F: Field> {
39	/// The number of variables in the multivariate polynomial.
40	fn n_vars(&self) -> usize;
41
42	/// Sumcheck evaluation order assumed by this specific prover.
43	fn evaluation_order(&self) -> EvaluationOrder;
44
45	/// Computes the prover message for this round as a univariate polynomial.
46	///
47	/// The prover message mixes the univariate polynomials of the underlying composites using the
48	/// powers of `batch_coeff`.
49	///
50	/// Let $alpha$ refer to `batch_coeff`. If [`Self::fold`] has already been called on the prover
51	/// with the values $r_0$, ..., $r_{k-1}$ and the sumcheck prover is proving the sums of the
52	/// composite polynomials $C_0, ..., C_{m-1}$, then the output of this method will be the
53	/// polynomial
54	///
55	/// $$
56	/// \sum_{v \in B_{n - k - 1}} \sum_{i=0}^{m-1} \alpha^i C_i(r_0, ..., r_{k-1}, X, \{v\})
57	/// $$
58	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error>;
59
60	/// Folds the sumcheck multilinears with a new verifier challenge.
61	fn fold(&mut self, challenge: F) -> Result<(), Error>;
62
63	/// Finishes the sumcheck proving protocol and returns the evaluations of all multilinears at
64	/// the challenge point.
65	fn finish(self: Box<Self>) -> Result<Vec<F>, Error>;
66}
67
68// NB: auto_impl does not currently handle ?Sized bound on Box<Self> receivers correctly.
69impl<F: Field, Prover: SumcheckProver<F> + ?Sized> SumcheckProver<F> for Box<Prover> {
70	fn n_vars(&self) -> usize {
71		(**self).n_vars()
72	}
73
74	fn evaluation_order(&self) -> EvaluationOrder {
75		(**self).evaluation_order()
76	}
77
78	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
79		(**self).execute(batch_coeff)
80	}
81
82	fn fold(&mut self, challenge: F) -> Result<(), Error> {
83		(**self).fold(challenge)
84	}
85
86	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
87		(*self).finish()
88	}
89}
90
91/// Prove a batched sumcheck protocol execution.
92///
93/// The sumcheck protocol over can be batched over multiple instances by taking random linear
94/// combinations over the claimed sums and polynomials. See
95/// [`crate::protocols::sumcheck::batch_verify`] for more details.
96///
97/// The provers in the `provers` parameter must in the same order as the corresponding claims
98/// provided to [`crate::protocols::sumcheck::batch_verify`] during proof verification.
99#[instrument(skip_all, name = "sumcheck::batch_prove")]
100pub fn batch_prove<F, Prover, Challenger_>(
101	mut provers: Vec<Prover>,
102	transcript: &mut ProverTranscript<Challenger_>,
103) -> Result<BatchSumcheckOutput<F>, Error>
104where
105	F: TowerField,
106	Prover: SumcheckProver<F>,
107	Challenger_: Challenger,
108{
109	let Some(first_prover) = provers.first() else {
110		return Ok(BatchSumcheckOutput {
111			challenges: Vec::new(),
112			multilinear_evals: Vec::new(),
113		});
114	};
115
116	let evaluation_order = first_prover.evaluation_order();
117
118	if provers
119		.iter()
120		.any(|prover| prover.evaluation_order() != evaluation_order)
121	{
122		bail!(Error::InconsistentEvaluationOrder);
123	}
124
125	// Check that the provers are in non-ascending order by n_vars
126	if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
127		bail!(Error::ClaimsOutOfOrder);
128	}
129
130	let n_rounds = provers
131		.iter()
132		.map(|prover| prover.n_vars())
133		.max()
134		.unwrap_or(0);
135
136	let mut batch_coeffs = Vec::with_capacity(provers.len());
137	let mut challenges = Vec::with_capacity(n_rounds);
138	for round_no in 0..n_rounds {
139		let n_vars = n_rounds - round_no;
140
141		// Activate new provers
142		while let Some(prover) = provers.get(batch_coeffs.len()) {
143			if prover.n_vars() != n_vars {
144				break;
145			}
146
147			let next_batch_coeff = transcript.sample();
148			batch_coeffs.push(next_batch_coeff);
149		}
150
151		// Process the active provers
152		let mut round_coeffs = RoundCoeffs::default();
153		for (&batch_coeff, prover) in izip!(&batch_coeffs, &mut provers) {
154			let prover_coeffs = prover.execute(batch_coeff)?;
155			round_coeffs += &(prover_coeffs * batch_coeff);
156		}
157
158		let round_proof = round_coeffs.truncate();
159		transcript
160			.message()
161			.write_scalar_slice(round_proof.coeffs());
162
163		let challenge = transcript.sample();
164		challenges.push(challenge);
165
166		for prover in &mut provers[..batch_coeffs.len()] {
167			prover.fold(challenge)?;
168		}
169	}
170
171	// sample next_batch_coeffs for 0-variate (ie. constant) provers to match with verify
172	for prover in &provers[batch_coeffs.len()..] {
173		debug_assert_eq!(prover.n_vars(), 0);
174		let _next_batch_coeff: F = transcript.sample();
175	}
176
177	let multilinear_evals = provers
178		.into_iter()
179		.map(|prover| Box::new(prover).finish())
180		.collect::<Result<Vec<_>, _>>()?;
181
182	let mut writer = transcript.message();
183	for multilinear_evals in &multilinear_evals {
184		writer.write_scalar_slice(multilinear_evals);
185	}
186
187	if EvaluationOrder::HighToLow == evaluation_order {
188		challenges.reverse();
189	}
190
191	let output = BatchSumcheckOutput {
192		challenges,
193		multilinear_evals,
194	};
195
196	Ok(output)
197}