binius_core/protocols/sumcheck/common.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6	util::{inner_product_unchecked, powers},
7	ExtensionField, Field, PackedField,
8};
9use binius_math::{CompositionPoly, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12use tracing::instrument;
13
14use super::error::Error;
15
16/// A claim about the sum of the values of a multilinear composite polynomial over the boolean
17/// hypercube.
18///
19/// This struct contains a composition polynomial and a claimed sum and implicitly refers to a
20/// sequence of multilinears that are composed. This is typically embedded within a
21/// [`SumcheckClaim`], which contains more metadata about the multilinears (eg. the number of
22/// variables they are defined over).
23#[derive(Debug, Clone, Getters, CopyGetters)]
24pub struct CompositeSumClaim<F: Field, Composition> {
25	pub composition: Composition,
26	pub sum: F,
27}
28
29/// A group of claims about the sum of the values of multilinear composite polynomials over the
30/// boolean hypercube.
31///
32/// All polynomials in the group of claims are compositions of the same sequence of multilinear
33/// polynomials. By defining [`SumcheckClaim`] in this way, the sumcheck protocol can implement
34/// efficient batch proving and verification and reduce to a set of multilinear evaluations of the
35/// same polynomials. In other words, this grouping deduplicates prover work and proof data that
36/// would be redundant in a more naive implementation.
37#[derive(Debug, Clone, CopyGetters)]
38pub struct SumcheckClaim<F: Field, C> {
39	#[getset(get_copy = "pub")]
40	n_vars: usize,
41	#[getset(get_copy = "pub")]
42	n_multilinears: usize,
43	composite_sums: Vec<CompositeSumClaim<F, C>>,
44}
45
46impl<F: Field, Composition> SumcheckClaim<F, Composition>
47where
48	Composition: CompositionPoly<F>,
49{
50	/// Constructs a new sumcheck claim.
51	///
52	/// ## Throws
53	///
54	/// * [`Error::InvalidComposition`] if any of the composition polynomials in the composite
55	///   claims vector do not have their number of variables equal to `n_multilinears`
56	pub fn new(
57		n_vars: usize,
58		n_multilinears: usize,
59		composite_sums: Vec<CompositeSumClaim<F, Composition>>,
60	) -> Result<Self, Error> {
61		for CompositeSumClaim {
62			ref composition, ..
63		} in &composite_sums
64		{
65			if composition.n_vars() != n_multilinears {
66				bail!(Error::InvalidComposition {
67					actual: composition.n_vars(),
68					expected: n_multilinears,
69				});
70			}
71		}
72		Ok(Self {
73			n_vars,
74			n_multilinears,
75			composite_sums,
76		})
77	}
78
79	/// Returns the maximum individual degree of all composite polynomials.
80	pub fn max_individual_degree(&self) -> usize {
81		self.composite_sums
82			.iter()
83			.map(|composite_sum| composite_sum.composition.degree())
84			.max()
85			.unwrap_or(0)
86	}
87
88	pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
89		&self.composite_sums
90	}
91}
92
93/// A univariate polynomial in monomial basis.
94///
95/// The coefficient at position `i` in the inner vector corresponds to the term $X^i$.
96#[derive(Debug, Default, Clone, PartialEq, Eq)]
97pub struct RoundCoeffs<F: Field>(pub Vec<F>);
98
99impl<F: Field> RoundCoeffs<F> {
100	/// Representation in an isomorphic field
101	pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
102		RoundCoeffs(self.0.into_iter().map(Into::into).collect())
103	}
104
105	/// Truncate one coefficient from the polynomial to a more compact round proof.
106	pub fn truncate(mut self) -> RoundProof<F> {
107		self.0.pop();
108		RoundProof(self)
109	}
110}
111
112impl<F: Field> Add<&Self> for RoundCoeffs<F> {
113	type Output = Self;
114
115	fn add(mut self, rhs: &Self) -> Self::Output {
116		self += rhs;
117		self
118	}
119}
120
121impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
122	fn add_assign(&mut self, rhs: &Self) {
123		if self.0.len() < rhs.0.len() {
124			self.0.resize(rhs.0.len(), F::ZERO);
125		}
126
127		for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
128			*lhs_i += rhs_i;
129		}
130	}
131}
132
133impl<F: Field> Mul<F> for RoundCoeffs<F> {
134	type Output = Self;
135
136	fn mul(mut self, rhs: F) -> Self::Output {
137		self *= rhs;
138		self
139	}
140}
141
142impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
143	fn mul_assign(&mut self, rhs: F) {
144		for coeff in &mut self.0 {
145			*coeff *= rhs;
146		}
147	}
148}
149
150/// A sumcheck round proof is a univariate polynomial in monomial basis with the coefficient of the
151/// highest-degree term truncated off.
152///
153/// Since the verifier knows the claimed sum of the polynomial values at the points 0 and 1, the
154/// high-degree term coefficient can be easily recovered. Truncating the coefficient off saves a
155/// small amount of proof data.
156#[derive(Debug, Default, Clone, PartialEq, Eq)]
157pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
158
159impl<F: Field> RoundProof<F> {
160	/// Recovers all univariate polynomial coefficients from the compressed round proof.
161	///
162	/// The prover has sent coefficients for the purported ith round polynomial
163	/// $r_i(X) = \sum_{j=0}^d a_j * X^j$.
164	/// However, the prover has not sent the highest degree coefficient $a_d$.
165	/// The verifier will need to recover this missing coefficient.
166	///
167	/// Let $s$ denote the current round's claimed sum.
168	/// The verifier expects the round polynomial $r_i$ to satisfy the identity
169	/// $s = r_i(0) + r_i(1)$.
170	/// Using
171	///     $r_i(0) = a_0$
172	///     $r_i(1) = \sum_{j=0}^d a_j$
173	/// There is a unique $a_d$ that allows $r_i$ to satisfy the above identity.
174	/// Specifically
175	///     $a_d = s - a_0 - \sum_{j=0}^{d-1} a_j$
176	///
177	/// Not sending the whole round polynomial is an optimization.
178	/// In the unoptimized version of the protocol, the verifier will halt and reject
179	/// if given a round polynomial that does not satisfy the above identity.
180	pub fn recover(self, sum: F) -> RoundCoeffs<F> {
181		let Self(RoundCoeffs(mut coeffs)) = self;
182		let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
183		let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
184		coeffs.push(last_coeff);
185		RoundCoeffs(coeffs)
186	}
187
188	/// The truncated polynomial coefficients.
189	pub fn coeffs(&self) -> &[F] {
190		&self.0 .0
191	}
192
193	/// Representation in an isomorphic field
194	pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
195		RoundProof(self.0.isomorphic())
196	}
197}
198
199/// A sumcheck batch proof.
200#[derive(Debug, Default, Clone, PartialEq, Eq)]
201pub struct Proof<F: Field> {
202	/// The round proofs for each round.
203	pub rounds: Vec<RoundProof<F>>,
204	/// The claimed evaluations of all multilinears at the point defined by the sumcheck verifier
205	/// challenges.
206	///
207	/// The structure is a vector of vectors of field elements. Each entry of the outer vector
208	/// corresponds to one [`SumcheckClaim`] in a batch. Each inner vector contains the evaluations
209	/// of the multilinears referenced by that claim.
210	pub multilinear_evals: Vec<Vec<F>>,
211}
212
213#[derive(Debug, PartialEq, Eq)]
214pub struct BatchSumcheckOutput<F: Field> {
215	pub challenges: Vec<F>,
216	pub multilinear_evals: Vec<Vec<F>>,
217}
218
219impl<F: Field> BatchSumcheckOutput<F> {
220	pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
221		BatchSumcheckOutput {
222			challenges: self.challenges.into_iter().map(Into::into).collect(),
223			multilinear_evals: self
224				.multilinear_evals
225				.into_iter()
226				.map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
227				.collect(),
228		}
229	}
230}
231
/// Constructs a switchover function that returns the round number at which a folded multilinear
/// is at least 2^k times smaller (in bytes) than the original.
///
/// Given `extension_degree`, the returned closure:
/// 1. computes `floor(log2(extension_degree)) + k`, saturating on overflow;
/// 2. clamps the result below at `0`;
/// 3. subtracts one, again saturating at `0`.
///
/// The result is therefore `0` whenever `floor(log2(extension_degree)) + k <= 1`.
///
/// # Panics
///
/// The returned closure panics when called with `extension_degree == 0`, because the integer
/// logarithm of zero is undefined.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
	move |extension_degree: usize| {
		let log_extension_degree = extension_degree.ilog2() as isize;
		// Saturate so an extreme `k` cannot overflow, then clamp at zero before converting back to
		// `usize` so a negative sum cannot wrap around.
		let round = log_extension_degree.saturating_add(k).max(0) as usize;
		round.saturating_sub(1)
	}
}
239
/// Switchover heuristic that ignores the extension degree and always selects round `0`, so
/// folding starts in the very first sumcheck round.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
	0
}
244
245/// Determine switchover rounds for a slice of multilinears.
246#[instrument(skip_all, level = "debug")]
247pub fn determine_switchovers<P, M>(
248	multilinears: &[M],
249	switchover_fn: impl Fn(usize) -> usize,
250) -> Vec<usize>
251where
252	P: PackedField,
253	M: MultilinearPoly<P>,
254{
255	// TODO: This can be computed in parallel.
256	multilinears
257		.iter()
258		.map(|multilinear| switchover_fn(1 << multilinear.log_extension_degree()))
259		.collect()
260}
261
262/// Check that all multilinears in a slice are of the same size.
263pub fn equal_n_vars_check<P, M>(multilinears: &[M]) -> Result<usize, Error>
264where
265	P: PackedField,
266	M: MultilinearPoly<P>,
267{
268	let n_vars = multilinears
269		.first()
270		.map(|multilinear| multilinear.n_vars())
271		.unwrap_or_default();
272	for multilinear in multilinears {
273		if multilinear.n_vars() != n_vars {
274			bail!(Error::NumberOfVariablesMismatch);
275		}
276	}
277	Ok(n_vars)
278}
279
280/// Check that evaluations of all multilinears can actually be embedded in the scalar
281/// type of small field `PBase`.
282///
283/// Returns binary logarithm of the embedding degree.
284pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
285where
286	F: Field + ExtensionField<FBase>,
287	FBase: Field,
288	P: PackedField<Scalar = F>,
289	M: MultilinearPoly<P>,
290{
291	for multilinear in multilinears {
292		if multilinear.log_extension_degree() < F::LOG_DEGREE {
293			bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
294		}
295	}
296
297	Ok(())
298}
299
300/// Multiply a sequence of field elements by the consecutive powers of `batch_coeff`
301pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
302	// Multiplying by batch_coeff is important for security!
303	batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
304}
305
306/// Validate the sumcheck evaluation domains to conform to the shape expected by the
307/// `SumcheckRoundCalculator`:
308///   1) First three points are zero, one, and Karatsuba infinity (for degrees above 1)
309///   2) All finite evaluation point slices are proper prefixes of the largest evaluation domain
310pub fn get_nontrivial_evaluation_points<F: Field>(
311	domains: &[InterpolationDomain<F>],
312) -> Result<Vec<F>, Error> {
313	let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
314		return Ok(Vec::new());
315	};
316
317	#[allow(clippy::get_first)]
318	if !domains.iter().all(|domain| {
319		(domain.size() <= 2 || domain.with_infinity())
320			&& domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
321			&& domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
322	}) {
323		bail!(Error::IncorrectSumcheckEvaluationDomain);
324	}
325
326	let finite_points = largest_domain.finite_points();
327
328	if domains
329		.iter()
330		.any(|domain| !finite_points.starts_with(domain.finite_points()))
331	{
332		bail!(Error::NonProperPrefixEvaluationDomain);
333	}
334
335	let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
336	Ok(nontrivial_evaluation_points)
337}
338
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b;

	use super::*;

	type F = BinaryField64b;

	#[test]
	fn test_round_coeffs_truncate_non_empty() {
		let full = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
		let RoundProof(RoundCoeffs(kept)) = full.truncate();
		// Only the highest-degree coefficient is dropped.
		assert_eq!(kept, [F::from(1), F::from(2)]);
	}

	#[test]
	fn test_round_coeffs_truncate_empty() {
		// Truncating an empty polynomial must not panic, and the result stays empty.
		let RoundProof(RoundCoeffs(kept)) = RoundCoeffs::<F>(Vec::new()).truncate();
		assert!(kept.is_empty());
	}
}