binius_core/protocols/sumcheck/prove/batch_sumcheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use itertools::izip;
7use tracing::instrument;
8
9use crate::{
10	fiat_shamir::{CanSample, Challenger},
11	protocols::sumcheck::{
12		common::{BatchSumcheckOutput, RoundCoeffs},
13		error::Error,
14	},
15	transcript::ProverTranscript,
16};
17
18/// A sumcheck prover with a round-by-round execution interface.
19///
20/// Sumcheck prover logic is accessed via a trait because important optimizations are available
21/// depending on the structure of the multivariate polynomial that the protocol targets. For
22/// example, [Gruen24] observes a significant optimization available to the sumcheck prover when
23/// the multivariate is the product of a multilinear composite and an equality indicator
24/// polynomial, which arises in the zerocheck protocol.
25///
26/// The trait exposes a round-by-round interface so that protocol execution logic that drives the
27/// prover can interleave the executions of the interactive protocol, for example in the case of
28/// batching several sumcheck protocols.
29///
30/// The caller must make a specific sequence of calls to the provers. For a prover where
31/// [`Self::n_vars`] is $n$, the caller must call [`Self::execute`] and then [`Self::fold`] $n$
32/// times, and finally call [`Self::finish`]. If the calls aren't made in that order, the caller
33/// will get an error result.
34///
35/// This trait is object-safe.
36///
37/// [Gruen24]: <https://eprint.iacr.org/2024/108>
38pub trait SumcheckProver<F: Field> {
39	/// The number of variables remaining in the multivariate polynomial.
40	///
41	/// This value decrements each time [`Self::fold`] is called on the instance.
42	fn n_vars(&self) -> usize;
43
44	/// Sumcheck evaluation order assumed by this specific prover.
45	fn evaluation_order(&self) -> EvaluationOrder;
46
47	/// Computes the prover message for this round as a univariate polynomial.
48	///
49	/// The prover message mixes the univariate polynomials of the underlying composites using the
50	/// powers of `batch_coeff`.
51	///
52	/// Let $alpha$ refer to `batch_coeff`. If [`Self::fold`] has already been called on the prover
53	/// with the values $r_0$, ..., $r_{k-1}$ and the sumcheck prover is proving the sums of the
54	/// composite polynomials $C_0, ..., C_{m-1}$, then the output of this method will be the
55	/// polynomial
56	///
57	/// $$
58	/// \sum_{v \in B_{n - k - 1}} \sum_{i=0}^{m-1} \alpha^i C_i(r_0, ..., r_{k-1}, X, \{v\})
59	/// $$
60	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error>;
61
62	/// Folds the sumcheck multilinears with a new verifier challenge.
63	fn fold(&mut self, challenge: F) -> Result<(), Error>;
64
65	/// Finishes the sumcheck proving protocol and returns the evaluations of all multilinears at
66	/// the challenge point.
67	fn finish(self: Box<Self>) -> Result<Vec<F>, Error>;
68}
69
70// NB: auto_impl does not currently handle ?Sized bound on Box<Self> receivers correctly.
71impl<F: Field, Prover: SumcheckProver<F> + ?Sized> SumcheckProver<F> for Box<Prover> {
72	fn n_vars(&self) -> usize {
73		(**self).n_vars()
74	}
75
76	fn evaluation_order(&self) -> EvaluationOrder {
77		(**self).evaluation_order()
78	}
79
80	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
81		(**self).execute(batch_coeff)
82	}
83
84	fn fold(&mut self, challenge: F) -> Result<(), Error> {
85		(**self).fold(challenge)
86	}
87
88	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
89		(*self).finish()
90	}
91}
92
93/// Prove a batched sumcheck protocol execution.
94///
95/// The sumcheck protocol over can be batched over multiple instances by taking random linear
96/// combinations over the claimed sums and polynomials. See
97/// [`crate::protocols::sumcheck::batch_verify`] for more details.
98///
99/// The provers in the `provers` parameter must in the same order as the corresponding claims
100/// provided to [`crate::protocols::sumcheck::batch_verify`] during proof verification.
101#[instrument(skip_all, name = "sumcheck::batch_prove")]
102pub fn batch_prove<F, Prover, Challenger_>(
103	mut provers: Vec<Prover>,
104	transcript: &mut ProverTranscript<Challenger_>,
105) -> Result<BatchSumcheckOutput<F>, Error>
106where
107	F: TowerField,
108	Prover: SumcheckProver<F>,
109	Challenger_: Challenger,
110{
111	let Some(first_prover) = provers.first() else {
112		return Ok(BatchSumcheckOutput {
113			challenges: Vec::new(),
114			multilinear_evals: Vec::new(),
115		});
116	};
117
118	let evaluation_order = first_prover.evaluation_order();
119
120	if provers
121		.iter()
122		.any(|prover| prover.evaluation_order() != evaluation_order)
123	{
124		bail!(Error::InconsistentEvaluationOrder);
125	}
126
127	// Check that the provers are in non-ascending order by n_vars
128	if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
129		bail!(Error::ClaimsOutOfOrder);
130	}
131
132	let n_rounds = provers
133		.iter()
134		.map(|prover| prover.n_vars())
135		.max()
136		.unwrap_or(0);
137
138	let mut batch_coeffs = Vec::with_capacity(provers.len());
139	let mut challenges = Vec::with_capacity(n_rounds);
140	for round_no in 0..n_rounds {
141		let n_vars = n_rounds - round_no;
142
143		// Activate new provers
144		while let Some(prover) = provers.get(batch_coeffs.len()) {
145			if prover.n_vars() != n_vars {
146				break;
147			}
148
149			let next_batch_coeff = transcript.sample();
150			batch_coeffs.push(next_batch_coeff);
151		}
152
153		// Process the active provers
154		let mut round_coeffs = RoundCoeffs::default();
155		for (&batch_coeff, prover) in izip!(&batch_coeffs, &mut provers) {
156			let prover_coeffs = prover.execute(batch_coeff)?;
157			round_coeffs += &(prover_coeffs * batch_coeff);
158		}
159
160		let round_proof = round_coeffs.truncate();
161		transcript
162			.message()
163			.write_scalar_slice(round_proof.coeffs());
164
165		let challenge = transcript.sample();
166		challenges.push(challenge);
167
168		for prover in &mut provers[..batch_coeffs.len()] {
169			prover.fold(challenge)?;
170		}
171	}
172
173	// sample next_batch_coeffs for 0-variate (ie. constant) provers to match with verify
174	for prover in &provers[batch_coeffs.len()..] {
175		debug_assert_eq!(prover.n_vars(), 0);
176		let _next_batch_coeff: F = transcript.sample();
177	}
178
179	let multilinear_evals = provers
180		.into_iter()
181		.map(|prover| Box::new(prover).finish())
182		.collect::<Result<Vec<_>, _>>()?;
183
184	let mut writer = transcript.message();
185	for multilinear_evals in &multilinear_evals {
186		writer.write_scalar_slice(multilinear_evals);
187	}
188
189	if EvaluationOrder::HighToLow == evaluation_order {
190		challenges.reverse();
191	}
192
193	let output = BatchSumcheckOutput {
194		challenges,
195		multilinear_evals,
196	};
197
198	Ok(output)
199}