binius_core/protocols/sumcheck/
common.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6	util::{inner_product_unchecked, powers},
7	ExtensionField, Field, PackedField,
8};
9use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12
13use super::error::Error;
14
15/// A claim about the sum of the values of a multilinear composite polynomial over the boolean
16/// hypercube.
17///
18/// This struct contains a composition polynomial and a claimed sum and implicitly refers to a
19/// sequence of multilinears that are composed. This is typically embedded within a
20/// [`SumcheckClaim`], which contains more metadata about the multilinears (eg. the number of
21/// variables they are defined over).
22#[derive(Debug, Clone, Getters, CopyGetters)]
23pub struct CompositeSumClaim<F: Field, Composition> {
24	pub composition: Composition,
25	pub sum: F,
26}
27
28/// A group of claims about the sum of the values of multilinear composite polynomials over the
29/// boolean hypercube.
30///
31/// All polynomials in the group of claims are compositions of the same sequence of multilinear
32/// polynomials. By defining [`SumcheckClaim`] in this way, the sumcheck protocol can implement
33/// efficient batch proving and verification and reduce to a set of multilinear evaluations of the
34/// same polynomials. In other words, this grouping deduplicates prover work and proof data that
35/// would be redundant in a more naive implementation.
36#[derive(Debug, Clone, CopyGetters)]
37pub struct SumcheckClaim<F: Field, C> {
38	#[getset(get_copy = "pub")]
39	n_vars: usize,
40	#[getset(get_copy = "pub")]
41	n_multilinears: usize,
42	composite_sums: Vec<CompositeSumClaim<F, C>>,
43}
44
45impl<F: Field, Composition> SumcheckClaim<F, Composition>
46where
47	Composition: CompositionPoly<F>,
48{
49	/// Constructs a new sumcheck claim.
50	///
51	/// ## Throws
52	///
53	/// * [`Error::InvalidComposition`] if any of the composition polynomials in the composite
54	///   claims vector do not have their number of variables equal to `n_multilinears`
55	pub fn new(
56		n_vars: usize,
57		n_multilinears: usize,
58		composite_sums: Vec<CompositeSumClaim<F, Composition>>,
59	) -> Result<Self, Error> {
60		for CompositeSumClaim {
61			ref composition, ..
62		} in &composite_sums
63		{
64			if composition.n_vars() != n_multilinears {
65				bail!(Error::InvalidComposition {
66					actual: composition.n_vars(),
67					expected: n_multilinears,
68				});
69			}
70		}
71		Ok(Self {
72			n_vars,
73			n_multilinears,
74			composite_sums,
75		})
76	}
77
78	/// Returns the maximum individual degree of all composite polynomials.
79	pub fn max_individual_degree(&self) -> usize {
80		self.composite_sums
81			.iter()
82			.map(|composite_sum| composite_sum.composition.degree())
83			.max()
84			.unwrap_or(0)
85	}
86
87	pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
88		&self.composite_sums
89	}
90}
91
92/// A univariate polynomial in monomial basis.
93///
94/// The coefficient at position `i` in the inner vector corresponds to the term $X^i$.
95#[derive(Debug, Default, Clone, PartialEq, Eq)]
96pub struct RoundCoeffs<F: Field>(pub Vec<F>);
97
98impl<F: Field> RoundCoeffs<F> {
99	/// Representation in an isomorphic field
100	pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
101		RoundCoeffs(self.0.into_iter().map(Into::into).collect())
102	}
103
104	/// Truncate one coefficient from the polynomial to a more compact round proof.
105	pub fn truncate(mut self) -> RoundProof<F> {
106		self.0.pop();
107		RoundProof(self)
108	}
109}
110
111impl<F: Field> Add<&Self> for RoundCoeffs<F> {
112	type Output = Self;
113
114	fn add(mut self, rhs: &Self) -> Self::Output {
115		self += rhs;
116		self
117	}
118}
119
120impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
121	fn add_assign(&mut self, rhs: &Self) {
122		if self.0.len() < rhs.0.len() {
123			self.0.resize(rhs.0.len(), F::ZERO);
124		}
125
126		for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
127			*lhs_i += rhs_i;
128		}
129	}
130}
131
132impl<F: Field> Mul<F> for RoundCoeffs<F> {
133	type Output = Self;
134
135	fn mul(mut self, rhs: F) -> Self::Output {
136		self *= rhs;
137		self
138	}
139}
140
141impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
142	fn mul_assign(&mut self, rhs: F) {
143		for coeff in &mut self.0 {
144			*coeff *= rhs;
145		}
146	}
147}
148
149/// A sumcheck round proof is a univariate polynomial in monomial basis with the coefficient of the
150/// highest-degree term truncated off.
151///
152/// Since the verifier knows the claimed sum of the polynomial values at the points 0 and 1, the
153/// high-degree term coefficient can be easily recovered. Truncating the coefficient off saves a
154/// small amount of proof data.
155#[derive(Debug, Default, Clone, PartialEq, Eq)]
156pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
157
158impl<F: Field> RoundProof<F> {
159	/// Recovers all univariate polynomial coefficients from the compressed round proof.
160	///
161	/// The prover has sent coefficients for the purported ith round polynomial
162	/// $r_i(X) = \sum_{j=0}^d a_j * X^j$.
163	/// However, the prover has not sent the highest degree coefficient $a_d$.
164	/// The verifier will need to recover this missing coefficient.
165	///
166	/// Let $s$ denote the current round's claimed sum.
167	/// The verifier expects the round polynomial $r_i$ to satisfy the identity
168	/// $s = r_i(0) + r_i(1)$.
169	/// Using
170	///     $r_i(0) = a_0$
171	///     $r_i(1) = \sum_{j=0}^d a_j$
172	/// There is a unique $a_d$ that allows $r_i$ to satisfy the above identity.
173	/// Specifically
174	///     $a_d = s - a_0 - \sum_{j=0}^{d-1} a_j$
175	///
176	/// Not sending the whole round polynomial is an optimization.
177	/// In the unoptimized version of the protocol, the verifier will halt and reject
178	/// if given a round polynomial that does not satisfy the above identity.
179	pub fn recover(self, sum: F) -> RoundCoeffs<F> {
180		let Self(RoundCoeffs(mut coeffs)) = self;
181		let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
182		let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
183		coeffs.push(last_coeff);
184		RoundCoeffs(coeffs)
185	}
186
187	/// The truncated polynomial coefficients.
188	pub fn coeffs(&self) -> &[F] {
189		&self.0 .0
190	}
191
192	/// Representation in an isomorphic field
193	pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
194		RoundProof(self.0.isomorphic())
195	}
196}
197
198/// A sumcheck batch proof.
199#[derive(Debug, Default, Clone, PartialEq, Eq)]
200pub struct Proof<F: Field> {
201	/// The round proofs for each round.
202	pub rounds: Vec<RoundProof<F>>,
203	/// The claimed evaluations of all multilinears at the point defined by the sumcheck verifier
204	/// challenges.
205	///
206	/// The structure is a vector of vectors of field elements. Each entry of the outer vector
207	/// corresponds to one [`SumcheckClaim`] in a batch. Each inner vector contains the evaluations
208	/// of the multilinears referenced by that claim.
209	pub multilinear_evals: Vec<Vec<F>>,
210}
211
212#[derive(Debug, PartialEq, Eq)]
213pub struct BatchSumcheckOutput<F: Field> {
214	pub challenges: Vec<F>,
215	pub multilinear_evals: Vec<Vec<F>>,
216}
217
218impl<F: Field> BatchSumcheckOutput<F> {
219	pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
220		BatchSumcheckOutput {
221			challenges: self.challenges.into_iter().map(Into::into).collect(),
222			multilinear_evals: self
223				.multilinear_evals
224				.into_iter()
225				.map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
226				.collect(),
227		}
228	}
229}
230
/// Constructs a switchover heuristic that picks a round from the extension degree and `k`.
///
/// The returned function maps an extension degree `d` to `max(⌊log2(d)⌋ + k, 0) - 1`,
/// saturating at zero. The intent is to switch over around the round where a folded multilinear
/// becomes at least `2^k` times smaller (in bytes) than the original.
///
/// # Panics
///
/// The returned function panics when called with an extension degree of `0`, because
/// `usize::ilog2` is undefined there.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
	move |extension_degree: usize| {
		// `saturating_add` keeps a very large `k` from overflowing `isize` (a panic in debug
		// builds, a wrap to a bogus round in release builds).
		((extension_degree.ilog2() as isize).saturating_add(k).max(0) as usize).saturating_sub(1)
	}
}
238
/// Switchover heuristic that starts folding in the very first round, whatever the extension
/// degree.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
	usize::MIN
}
243
244/// Check that all multilinears in a slice are of the same size.
245pub fn equal_n_vars_check<'a, P, M>(
246	multilinears: impl IntoIterator<Item = &'a M>,
247) -> Result<usize, Error>
248where
249	P: PackedField,
250	M: MultilinearPoly<P> + 'a,
251{
252	let mut multilinears = multilinears.into_iter();
253	let n_vars = multilinears
254		.next()
255		.map(|multilinear| multilinear.n_vars())
256		.unwrap_or_default();
257	for multilinear in multilinears {
258		if multilinear.n_vars() != n_vars {
259			bail!(Error::NumberOfVariablesMismatch);
260		}
261	}
262	Ok(n_vars)
263}
264
265/// Check that evaluations of all multilinears can actually be embedded in the scalar
266/// type of small field `PBase`.
267///
268/// Returns binary logarithm of the embedding degree.
269pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
270where
271	F: Field + ExtensionField<FBase>,
272	FBase: Field,
273	P: PackedField<Scalar = F>,
274	M: MultilinearPoly<P>,
275{
276	for multilinear in multilinears {
277		if multilinear.log_extension_degree() < F::LOG_DEGREE {
278			bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
279		}
280	}
281
282	Ok(())
283}
284
285/// Multiply a sequence of field elements by the consecutive powers of `batch_coeff`
286pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
287	// Multiplying by batch_coeff is important for security!
288	batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
289}
290
291/// Create interpolation domains for a sequence of composition degrees.
292pub fn interpolation_domains_for_composition_degrees<FDomain>(
293	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
294	degrees: impl IntoIterator<Item = usize>,
295) -> Result<Vec<InterpolationDomain<FDomain>>, Error>
296where
297	FDomain: Field,
298{
299	degrees
300		.into_iter()
301		.map(|degree| Ok(evaluation_domain_factory.create(degree + 1)?.into()))
302		.collect()
303}
304
305/// Validate the sumcheck evaluation domains to conform to the shape expected by the
306/// `SumcheckRoundCalculator`:
307///   1) First three points are zero, one, and Karatsuba infinity (for degrees above 1)
308///   2) All finite evaluation point slices are proper prefixes of the largest evaluation domain
309pub fn get_nontrivial_evaluation_points<F: Field>(
310	domains: &[InterpolationDomain<F>],
311) -> Result<Vec<F>, Error> {
312	let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
313		return Ok(Vec::new());
314	};
315
316	#[allow(clippy::get_first)]
317	if !domains.iter().all(|domain| {
318		(domain.size() <= 2 || domain.with_infinity())
319			&& domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
320			&& domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
321	}) {
322		bail!(Error::IncorrectSumcheckEvaluationDomain);
323	}
324
325	let finite_points = largest_domain.finite_points();
326
327	if domains
328		.iter()
329		.any(|domain| !finite_points.starts_with(domain.finite_points()))
330	{
331		bail!(Error::NonProperPrefixEvaluationDomain);
332	}
333
334	let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
335	Ok(nontrivial_evaluation_points)
336}
337
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b;

	use super::*;

	type F = BinaryField64b;

	#[test]
	fn test_round_coeffs_truncate_non_empty() {
		let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
		let truncated = coeffs.truncate();
		assert_eq!(truncated.0 .0, vec![F::from(1), F::from(2)]);
	}

	#[test]
	fn test_round_coeffs_truncate_empty() {
		let coeffs = RoundCoeffs::<F>(vec![]);
		let truncated = coeffs.truncate();
		assert!(truncated.0 .0.is_empty());
	}

	/// Truncating and then recovering with the true round sum must give back the polynomial.
	#[test]
	fn test_round_proof_recover_round_trip() {
		let coeffs = RoundCoeffs(vec![F::from(5), F::from(7), F::from(11), F::from(13)]);
		// The claimed sum is r(0) + r(1) = a_0 + (a_0 + a_1 + ... + a_d).
		let sum = coeffs.0[0] + coeffs.0.iter().sum::<F>();
		let recovered = coeffs.clone().truncate().recover(sum);
		assert_eq!(recovered, coeffs);
	}

	/// Addition zero-extends the shorter operand, whichever side it is on.
	#[test]
	fn test_round_coeffs_add_zero_extends() {
		let short = RoundCoeffs(vec![F::from(1), F::from(2)]);
		let long = RoundCoeffs(vec![F::from(3), F::from(4), F::from(5)]);
		let expected =
			RoundCoeffs(vec![F::from(1) + F::from(3), F::from(2) + F::from(4), F::from(5)]);
		assert_eq!(short.clone() + &long, expected);
		assert_eq!(long + &short, expected);
	}

	#[test]
	fn test_round_coeffs_mul_scales_every_coefficient() {
		let scalar = F::from(3);
		let coeffs = RoundCoeffs(vec![F::from(2), F::from(9)]);
		let expected = RoundCoeffs(vec![F::from(2) * scalar, F::from(9) * scalar]);
		assert_eq!(coeffs * scalar, expected);
	}
}