// binius_core/protocols/sumcheck/prove/batch_prove.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_field::{Field, TowerField};
6use binius_math::EvaluationOrder;
7use binius_utils::{bail, sorting::is_sorted_ascending};
8use tracing::instrument;
9
10use crate::{
11	fiat_shamir::{CanSample, Challenger},
12	protocols::sumcheck::{
13		common::{BatchSumcheckOutput, RoundCoeffs},
14		error::Error,
15	},
16	transcript::ProverTranscript,
17};
18
19/// A sumcheck prover with a round-by-round execution interface.
20///
21/// Sumcheck prover logic is accessed via a trait because important optimizations are available
22/// depending on the structure of the multivariate polynomial that the protocol targets. For
23/// example, [Gruen24] observes a significant optimization available to the sumcheck prover when
24/// the multivariate is the product of a multilinear composite and an equality indicator
25/// polynomial, which arises in the zerocheck protocol.
26///
27/// The trait exposes a round-by-round interface so that protocol execution logic that drives the
28/// prover can interleave the executions of the interactive protocol, for example in the case of
29/// batching several sumcheck protocols.
30///
31/// The caller must make a specific sequence of calls to the provers. For a prover where
32/// [`Self::n_vars`] is $n$, the caller must call [`Self::execute`] and then [`Self::fold`] $n$
33/// times, and finally call [`Self::finish`]. If the calls aren't made in that order, the caller
34/// will get an error result.
35///
36/// This trait is object-safe.
37///
38/// [Gruen24]: <https://eprint.iacr.org/2024/108>
39pub trait SumcheckProver<F: Field> {
40	/// The number of variables in the multivariate polynomial.
41	fn n_vars(&self) -> usize;
42
43	/// Sumcheck evaluation order assumed by this specific prover.
44	fn evaluation_order(&self) -> EvaluationOrder;
45
46	/// Computes the prover message for this round as a univariate polynomial.
47	///
48	/// The prover message mixes the univariate polynomials of the underlying composites using the
49	/// powers of `batch_coeff`.
50	///
51	/// Let $alpha$ refer to `batch_coeff`. If [`Self::fold`] has already been called on the prover
52	/// with the values $r_0$, ..., $r_{k-1}$ and the sumcheck prover is proving the sums of the
53	/// composite polynomials $C_0, ..., C_{m-1}$, then the output of this method will be the
54	/// polynomial
55	///
56	/// $$
57	/// \sum_{v \in B_{n - k - 1}} \sum_{i=0}^{m-1} \alpha^i C_i(r_0, ..., r_{k-1}, X, \{v\})
58	/// $$
59	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error>;
60
61	/// Folds the sumcheck multilinears with a new verifier challenge.
62	fn fold(&mut self, challenge: F) -> Result<(), Error>;
63
64	/// Finishes the sumcheck proving protocol and returns the evaluations of all multilinears at
65	/// the challenge point.
66	fn finish(self: Box<Self>) -> Result<Vec<F>, Error>;
67}
68
69// NB: auto_impl does not currently handle ?Sized bound on Box<Self> receivers correctly.
70impl<F: Field, Prover: SumcheckProver<F> + ?Sized> SumcheckProver<F> for Box<Prover> {
71	fn n_vars(&self) -> usize {
72		(**self).n_vars()
73	}
74
75	fn evaluation_order(&self) -> EvaluationOrder {
76		(**self).evaluation_order()
77	}
78
79	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
80		(**self).execute(batch_coeff)
81	}
82
83	fn fold(&mut self, challenge: F) -> Result<(), Error> {
84		(**self).fold(challenge)
85	}
86
87	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
88		(*self).finish()
89	}
90}
91
92/// Prove a batched sumcheck protocol execution.
93///
94/// The sumcheck protocol over can be batched over multiple instances by taking random linear
95/// combinations over the claimed sums and polynomials. See
96/// [`crate::protocols::sumcheck::batch_verify`] for more details.
97///
98/// The provers in the `provers` parameter must in the same order as the corresponding claims
99/// provided to [`crate::protocols::sumcheck::batch_verify`] during proof verification.
100#[instrument(skip_all, name = "sumcheck::batch_prove")]
101pub fn batch_prove<F, Prover, Challenger_>(
102	provers: Vec<Prover>,
103	transcript: &mut ProverTranscript<Challenger_>,
104) -> Result<BatchSumcheckOutput<F>, Error>
105where
106	F: TowerField,
107	Prover: SumcheckProver<F>,
108	Challenger_: Challenger,
109{
110	let start = BatchProveStart {
111		batch_coeffs: Vec::new(),
112		reduction_provers: Vec::<Prover>::new(),
113	};
114
115	batch_prove_with_start(start, provers, transcript)
116}
117
118/// A struct describing the starting state of batched sumcheck prove invocation.
119#[derive(Debug)]
120pub struct BatchProveStart<F: Field, Prover> {
121	/// Batching coefficients for the already batched provers.
122	pub batch_coeffs: Vec<F>,
123	/// Reduced provers which can complete sumchecks from an intermediate state.
124	pub reduction_provers: Vec<Prover>,
125}
126
127/// Prove a batched sumcheck protocol execution, but after some rounds have been processed.
128#[instrument(skip_all, name = "sumcheck::batch_prove")]
129pub fn batch_prove_with_start<F, Prover, Challenger_>(
130	start: BatchProveStart<F, Prover>,
131	mut provers: Vec<Prover>,
132	transcript: &mut ProverTranscript<Challenger_>,
133) -> Result<BatchSumcheckOutput<F>, Error>
134where
135	F: TowerField,
136	Prover: SumcheckProver<F>,
137	Challenger_: Challenger,
138{
139	let BatchProveStart {
140		mut batch_coeffs,
141		reduction_provers,
142	} = start;
143
144	provers.splice(0..0, reduction_provers);
145
146	let Some(first_prover) = provers.first() else {
147		return Ok(BatchSumcheckOutput {
148			challenges: Vec::new(),
149			multilinear_evals: Vec::new(),
150		});
151	};
152
153	let evaluation_order = first_prover.evaluation_order();
154
155	if provers
156		.iter()
157		.any(|prover| prover.evaluation_order() != evaluation_order)
158	{
159		bail!(Error::InconsistentEvaluationOrder);
160	}
161
162	// Check that the provers are in descending order by n_vars
163	if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
164		bail!(Error::ClaimsOutOfOrder);
165	}
166
167	if batch_coeffs.len() > provers.len() {
168		bail!(Error::TooManyPrebatchedCoeffs);
169	}
170
171	let n_rounds = provers
172		.iter()
173		.map(|prover| prover.n_vars())
174		.max()
175		.unwrap_or(0);
176
177	// active_index is an index into the provers slice.
178	let mut active_index = batch_coeffs.len();
179	let mut challenges = Vec::with_capacity(n_rounds);
180	for round_no in 0..n_rounds {
181		let n_vars = n_rounds - round_no;
182
183		// Activate new provers
184		while let Some(prover) = provers.get(active_index) {
185			if prover.n_vars() != n_vars {
186				break;
187			}
188
189			let next_batch_coeff: F = transcript.sample();
190			batch_coeffs.push(next_batch_coeff);
191			active_index += 1;
192		}
193
194		// Process the active provers
195		let mut round_coeffs = RoundCoeffs::default();
196		for (&batch_coeff, prover) in
197			iter::zip(batch_coeffs.iter(), provers[..active_index].iter_mut())
198		{
199			let prover_coeffs = prover.execute(batch_coeff)?;
200			round_coeffs += &(prover_coeffs * batch_coeff);
201		}
202
203		let round_proof = round_coeffs.truncate();
204		transcript
205			.message()
206			.write_scalar_slice(round_proof.coeffs());
207
208		let challenge = transcript.sample();
209		challenges.push(challenge);
210
211		for prover in &mut provers[..active_index] {
212			prover.fold(challenge)?;
213		}
214	}
215
216	// sample next_batch_coeffs for 0-variate (ie. constant) provers to match with verify
217	while let Some(prover) = provers.get(active_index) {
218		debug_assert_eq!(prover.n_vars(), 0);
219
220		let _next_batch_coeff: F = transcript.sample();
221		active_index += 1;
222	}
223
224	let multilinear_evals = provers
225		.into_iter()
226		.map(|prover| Box::new(prover).finish())
227		.collect::<Result<Vec<_>, _>>()?;
228
229	let mut writer = transcript.message();
230	for multilinear_evals in &multilinear_evals {
231		writer.write_scalar_slice(multilinear_evals);
232	}
233
234	if EvaluationOrder::HighToLow == evaluation_order {
235		challenges.reverse();
236	}
237
238	let output = BatchSumcheckOutput {
239		challenges,
240		multilinear_evals,
241	};
242
243	Ok(output)
244}