binius_core/protocols/sumcheck/common.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6	ExtensionField, Field, PackedField,
7	util::{inner_product_unchecked, powers},
8};
9use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12
13use super::error::Error;
14
15/// A claim about the sum of the values of a multilinear composite polynomial over the boolean
16/// hypercube.
17///
18/// This struct contains a composition polynomial and a claimed sum and implicitly refers to a
19/// sequence of multilinears that are composed. This is typically embedded within a
20/// [`SumcheckClaim`], which contains more metadata about the multilinears (eg. the number of
21/// variables they are defined over).
22#[derive(Debug, Clone, Getters, CopyGetters)]
23pub struct CompositeSumClaim<F: Field, Composition> {
24	pub composition: Composition,
25	pub sum: F,
26}
27
28/// A group of claims about the sum of the values of multilinear composite polynomials over the
29/// boolean hypercube.
30///
31/// All polynomials in the group of claims are compositions of the same sequence of multilinear
32/// polynomials. By defining [`SumcheckClaim`] in this way, the sumcheck protocol can implement
33/// efficient batch proving and verification and reduce to a set of multilinear evaluations of the
34/// same polynomials. In other words, this grouping deduplicates prover work and proof data that
35/// would be redundant in a more naive implementation.
36#[derive(Debug, Clone, CopyGetters)]
37pub struct SumcheckClaim<F: Field, C> {
38	#[getset(get_copy = "pub")]
39	n_vars: usize,
40	#[getset(get_copy = "pub")]
41	n_multilinears: usize,
42	composite_sums: Vec<CompositeSumClaim<F, C>>,
43}
44
45impl<F: Field, Composition> SumcheckClaim<F, Composition>
46where
47	Composition: CompositionPoly<F>,
48{
49	/// Constructs a new sumcheck claim.
50	///
51	/// ## Throws
52	///
53	/// * [`Error::InvalidComposition`] if any of the composition polynomials in the composite
54	///   claims vector do not have their number of variables equal to `n_multilinears`
55	pub fn new(
56		n_vars: usize,
57		n_multilinears: usize,
58		composite_sums: Vec<CompositeSumClaim<F, Composition>>,
59	) -> Result<Self, Error> {
60		for CompositeSumClaim { composition, .. } in &composite_sums {
61			if composition.n_vars() != n_multilinears {
62				bail!(Error::InvalidComposition {
63					actual: composition.n_vars(),
64					expected: n_multilinears,
65				});
66			}
67		}
68		Ok(Self {
69			n_vars,
70			n_multilinears,
71			composite_sums,
72		})
73	}
74
75	/// Returns the maximum individual degree of all composite polynomials.
76	pub fn max_individual_degree(&self) -> usize {
77		self.composite_sums
78			.iter()
79			.map(|composite_sum| composite_sum.composition.degree())
80			.max()
81			.unwrap_or(0)
82	}
83
84	pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
85		&self.composite_sums
86	}
87}
88
89/// A univariate polynomial in monomial basis.
90///
91/// The coefficient at position `i` in the inner vector corresponds to the term $X^i$.
92#[derive(Debug, Default, Clone, PartialEq, Eq)]
93pub struct RoundCoeffs<F: Field>(pub Vec<F>);
94
95impl<F: Field> RoundCoeffs<F> {
96	/// Representation in an isomorphic field
97	pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
98		RoundCoeffs(self.0.into_iter().map(Into::into).collect())
99	}
100
101	/// Truncate one coefficient from the polynomial to a more compact round proof.
102	pub fn truncate(mut self) -> RoundProof<F> {
103		self.0.pop();
104		RoundProof(self)
105	}
106}
107
108impl<F: Field> Add<&Self> for RoundCoeffs<F> {
109	type Output = Self;
110
111	fn add(mut self, rhs: &Self) -> Self::Output {
112		self += rhs;
113		self
114	}
115}
116
117impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
118	fn add_assign(&mut self, rhs: &Self) {
119		if self.0.len() < rhs.0.len() {
120			self.0.resize(rhs.0.len(), F::ZERO);
121		}
122
123		for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
124			*lhs_i += rhs_i;
125		}
126	}
127}
128
129impl<F: Field> Mul<F> for RoundCoeffs<F> {
130	type Output = Self;
131
132	fn mul(mut self, rhs: F) -> Self::Output {
133		self *= rhs;
134		self
135	}
136}
137
138impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
139	fn mul_assign(&mut self, rhs: F) {
140		for coeff in &mut self.0 {
141			*coeff *= rhs;
142		}
143	}
144}
145
146/// A sumcheck round proof is a univariate polynomial in monomial basis with the coefficient of the
147/// highest-degree term truncated off.
148///
149/// Since the verifier knows the claimed sum of the polynomial values at the points 0 and 1, the
150/// high-degree term coefficient can be easily recovered. Truncating the coefficient off saves a
151/// small amount of proof data.
152#[derive(Debug, Default, Clone, PartialEq, Eq)]
153pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
154
155impl<F: Field> RoundProof<F> {
156	/// Recovers all univariate polynomial coefficients from the compressed round proof.
157	///
158	/// The prover has sent coefficients for the purported ith round polynomial
159	/// $r_i(X) = \sum_{j=0}^d a_j * X^j$.
160	/// However, the prover has not sent the highest degree coefficient $a_d$.
161	/// The verifier will need to recover this missing coefficient.
162	///
163	/// Let $s$ denote the current round's claimed sum.
164	/// The verifier expects the round polynomial $r_i$ to satisfy the identity
165	/// $s = r_i(0) + r_i(1)$.
166	/// Using
167	///     $r_i(0) = a_0$
168	///     $r_i(1) = \sum_{j=0}^d a_j$
169	/// There is a unique $a_d$ that allows $r_i$ to satisfy the above identity.
170	/// Specifically
171	///     $a_d = s - a_0 - \sum_{j=0}^{d-1} a_j$
172	///
173	/// Not sending the whole round polynomial is an optimization.
174	/// In the unoptimized version of the protocol, the verifier will halt and reject
175	/// if given a round polynomial that does not satisfy the above identity.
176	pub fn recover(self, sum: F) -> RoundCoeffs<F> {
177		let Self(RoundCoeffs(mut coeffs)) = self;
178		let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
179		let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
180		coeffs.push(last_coeff);
181		RoundCoeffs(coeffs)
182	}
183
184	/// The truncated polynomial coefficients.
185	pub fn coeffs(&self) -> &[F] {
186		&self.0.0
187	}
188
189	/// Representation in an isomorphic field
190	pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
191		RoundProof(self.0.isomorphic())
192	}
193}
194
195/// A sumcheck batch proof.
196#[derive(Debug, Default, Clone, PartialEq, Eq)]
197pub struct Proof<F: Field> {
198	/// The round proofs for each round.
199	pub rounds: Vec<RoundProof<F>>,
200	/// The claimed evaluations of all multilinears at the point defined by the sumcheck verifier
201	/// challenges.
202	///
203	/// The structure is a vector of vectors of field elements. Each entry of the outer vector
204	/// corresponds to one [`SumcheckClaim`] in a batch. Each inner vector contains the evaluations
205	/// of the multilinears referenced by that claim.
206	pub multilinear_evals: Vec<Vec<F>>,
207}
208
209/// Output of the batched sumcheck reduction
210#[derive(Debug, PartialEq, Eq)]
211pub struct BatchSumcheckOutput<F: Field> {
212	/// Sumcheck challenges - an evaluation point for the reduced claim.
213	pub challenges: Vec<F>,
214	/// Values of each multilinear (per claim, in descending `n_vars` order) at a suffix
215	/// of `challenges` of appropriate length.
216	pub multilinear_evals: Vec<Vec<F>>,
217}
218
219impl<F: Field> BatchSumcheckOutput<F> {
220	pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
221		BatchSumcheckOutput {
222			challenges: self.challenges.into_iter().map(Into::into).collect(),
223			multilinear_evals: self
224				.multilinear_evals
225				.into_iter()
226				.map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
227				.collect(),
228		}
229	}
230}
231
/// Constructs a switchover function that maps an extension degree `d` to
/// $\max(\lfloor\log_2 d\rfloor + k, 0) - 1$, saturating at `0`.
///
/// The intended meaning is the round at which the folded multilinear is at least $2^k$ times
/// smaller (in bytes) than the original. Whenever $\lfloor\log_2 d\rfloor + k \le 1$, the
/// returned round is `0`.
///
/// # Panics
///
/// The returned closure panics if called with an extension degree of `0`, because
/// `usize::ilog2` panics on zero.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
	move |extension_degree: usize| {
		let log_degree = extension_degree.ilog2() as isize;
		// Clamp at zero before converting back to an unsigned round count.
		let shifted = (log_degree + k).max(0) as usize;
		shifted.saturating_sub(1)
	}
}
239
/// Switchover heuristic that always answers round `0`, so folding starts in the very first round.
///
/// The extension degree argument is ignored.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
	0
}
244
245/// Check that all multilinears in a slice are of the same size.
246pub fn equal_n_vars_check<'a, P, M>(
247	multilinears: impl IntoIterator<Item = &'a M>,
248) -> Result<usize, Error>
249where
250	P: PackedField,
251	M: MultilinearPoly<P> + 'a,
252{
253	let mut multilinears = multilinears.into_iter();
254	let n_vars = multilinears
255		.next()
256		.map(|multilinear| multilinear.n_vars())
257		.unwrap_or_default();
258	for multilinear in multilinears {
259		if multilinear.n_vars() != n_vars {
260			bail!(Error::NumberOfVariablesMismatch);
261		}
262	}
263	Ok(n_vars)
264}
265
266/// Check that evaluations of all multilinears can actually be embedded in the scalar
267/// type of small field `PBase`.
268///
269/// Returns binary logarithm of the embedding degree.
270pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
271where
272	F: Field + ExtensionField<FBase>,
273	FBase: Field,
274	P: PackedField<Scalar = F>,
275	M: MultilinearPoly<P>,
276{
277	for multilinear in multilinears {
278		if multilinear.log_extension_degree() < F::LOG_DEGREE {
279			bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
280		}
281	}
282
283	Ok(())
284}
285
286/// Multiply a sequence of field elements by the consecutive powers of `batch_coeff`
287pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
288	// Multiplying by batch_coeff is important for security!
289	batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
290}
291
292/// Create interpolation domains for a sequence of composition degrees.
293pub fn interpolation_domains_for_composition_degrees<FDomain>(
294	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
295	degrees: impl IntoIterator<Item = usize>,
296) -> Result<Vec<InterpolationDomain<FDomain>>, Error>
297where
298	FDomain: Field,
299{
300	degrees
301		.into_iter()
302		.map(|degree| Ok(evaluation_domain_factory.create(degree + 1)?.into()))
303		.collect()
304}
305
306/// Validate the sumcheck evaluation domains to conform to the shape expected by the
307/// `SumcheckRoundCalculator`:
308///   1) First three points are zero, one, and Karatsuba infinity (for degrees above 1)
309///   2) All finite evaluation point slices are proper prefixes of the largest evaluation domain
310pub fn get_nontrivial_evaluation_points<F: Field>(
311	domains: &[InterpolationDomain<F>],
312) -> Result<Vec<F>, Error> {
313	let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
314		return Ok(Vec::new());
315	};
316
317	#[allow(clippy::get_first)]
318	if !domains.iter().all(|domain| {
319		(domain.size() <= 2 || domain.with_infinity())
320			&& domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
321			&& domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
322	}) {
323		bail!(Error::IncorrectSumcheckEvaluationDomain);
324	}
325
326	let finite_points = largest_domain.finite_points();
327
328	if domains
329		.iter()
330		.any(|domain| !finite_points.starts_with(domain.finite_points()))
331	{
332		bail!(Error::NonProperPrefixEvaluationDomain);
333	}
334
335	let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
336	Ok(nontrivial_evaluation_points)
337}
338
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b;

	use super::*;

	type F = BinaryField64b;

	#[test]
	fn test_round_coeffs_truncate_non_empty() {
		let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
		let truncated = coeffs.truncate();
		assert_eq!(truncated.0.0, vec![F::from(1), F::from(2)]);
	}

	#[test]
	fn test_round_coeffs_truncate_empty() {
		let coeffs = RoundCoeffs::<F>(vec![]);
		let truncated = coeffs.truncate();
		assert!(truncated.0.0.is_empty());
	}

	#[test]
	fn test_round_proof_recover_inverts_truncate() {
		let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
		// The claimed round sum is r(0) + r(1) = a_0 + (a_0 + a_1 + a_2).
		let sum = coeffs.0[0] + coeffs.0.iter().sum::<F>();
		let recovered = coeffs.clone().truncate().recover(sum);
		assert_eq!(recovered, coeffs);
	}

	#[test]
	fn test_round_coeffs_add_assign_extends_shorter_lhs() {
		let mut lhs = RoundCoeffs(vec![F::from(1)]);
		let rhs = RoundCoeffs(vec![F::from(2), F::from(4)]);
		lhs += &rhs;
		assert_eq!(lhs, RoundCoeffs(vec![F::from(1) + F::from(2), F::from(4)]));
	}

	#[test]
	fn test_round_coeffs_add_assign_keeps_longer_lhs() {
		let mut lhs = RoundCoeffs(vec![F::from(1), F::from(2)]);
		let rhs = RoundCoeffs(vec![F::from(4)]);
		lhs += &rhs;
		assert_eq!(lhs, RoundCoeffs(vec![F::from(1) + F::from(4), F::from(2)]));
	}
}