binius_core/protocols/sumcheck/prove/regular_sumcheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{ComputationBackend, SumcheckEvaluator, SumcheckMultilinear};
7use binius_math::{
8	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
9	MultilinearPoly, RowsBatchRef,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::bail;
13use itertools::izip;
14use stackalloc::stackalloc_with_default;
15use tracing::instrument;
16
17use crate::{
18	polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
19	protocols::sumcheck::{
20		common::{
21			equal_n_vars_check, get_nontrivial_evaluation_points,
22			interpolation_domains_for_composition_degrees, CompositeSumClaim, RoundCoeffs,
23		},
24		error::Error,
25		prove::{ProverState, SumcheckInterpolator, SumcheckProver},
26	},
27};
28
29pub fn validate_witness<'a, F, P, M, Composition>(
30	multilinears: &[M],
31	sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
32) -> Result<(), Error>
33where
34	F: Field,
35	P: PackedField<Scalar = F>,
36	M: MultilinearPoly<P> + Send + Sync,
37	Composition: CompositionPoly<P> + 'a,
38{
39	let n_vars = multilinears
40		.first()
41		.map(|multilinear| multilinear.n_vars())
42		.unwrap_or_default();
43	for multilinear in multilinears {
44		if multilinear.n_vars() != n_vars {
45			bail!(Error::NumberOfVariablesMismatch);
46		}
47	}
48
49	let multilinears = multilinears.iter().collect::<Vec<_>>();
50
51	for (i, claim) in sum_claims.into_iter().enumerate() {
52		let CompositeSumClaim {
53			composition,
54			sum: expected_sum,
55			..
56		} = claim;
57		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
58		let sum = (0..(1 << n_vars))
59			.into_par_iter()
60			.map(|j| witness.evaluate_on_hypercube(j))
61			.try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
62
63		if sum != expected_sum {
64			bail!(Error::SumcheckNaiveValidationFailure {
65				composition_index: i,
66			});
67		}
68	}
69	Ok(())
70}
71
/// Sumcheck prover for a batch of composite sum claims that all share one set of
/// multilinear polynomials.
///
/// In each round, every composition is evaluated over the shared multilinears. The resulting
/// round evaluations are then interpolated on that composition's own domain.
pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
where
	FDomain: Field,
	P: PackedField,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	/// Number of variables shared by all multilinears (checked equal at construction).
	n_vars: usize,
	/// Prover state built from the multilinears, claimed sums, evaluation points and backend.
	state: ProverState<'a, FDomain, P, M, Backend>,
	/// One composition per claim, in claim order.
	compositions: Vec<Composition>,
	/// Interpolation domain for each entry of `compositions`, in the same order.
	domains: Vec<InterpolationDomain<FDomain>>,
}
84
85impl<'a, F, FDomain, P, Composition, M, Backend>
86	RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
87where
88	F: Field,
89	FDomain: Field,
90	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
91	Composition: CompositionPoly<P>,
92	M: MultilinearPoly<P> + Send + Sync,
93	Backend: ComputationBackend,
94{
95	#[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
96	pub fn new(
97		evaluation_order: EvaluationOrder,
98		multilinears: Vec<M>,
99		composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
100		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
101		switchover_fn: impl Fn(usize) -> usize,
102		backend: &'a Backend,
103	) -> Result<Self, Error> {
104		let n_vars = equal_n_vars_check(&multilinears)?;
105		let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
106
107		#[cfg(feature = "debug_validate_sumcheck")]
108		{
109			let composite_claims = composite_claims
110				.iter()
111				.map(|x| CompositeSumClaim {
112					sum: x.sum,
113					composition: &x.composition,
114				})
115				.collect::<Vec<_>>();
116			validate_witness(&multilinears, composite_claims)?;
117		}
118
119		for claim in &composite_claims {
120			if claim.composition.n_vars() != multilinears.len() {
121				bail!(Error::InvalidComposition {
122					actual: claim.composition.n_vars(),
123					expected: multilinears.len(),
124				});
125			}
126		}
127
128		let claimed_sums = composite_claims
129			.iter()
130			.map(|composite_claim| composite_claim.sum)
131			.collect();
132
133		let domains = interpolation_domains_for_composition_degrees(
134			evaluation_domain_factory,
135			composite_claims
136				.iter()
137				.map(|composite_claim| composite_claim.composition.degree()),
138		)?;
139
140		let compositions = composite_claims
141			.into_iter()
142			.map(|claim| claim.composition)
143			.collect();
144
145		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
146
147		let multilinears = multilinears
148			.into_iter()
149			.map(|multilinear| SumcheckMultilinear::transparent(multilinear, &switchover_fn))
150			.collect();
151
152		let state = ProverState::new(
153			evaluation_order,
154			n_vars,
155			multilinears,
156			claimed_sums,
157			nontrivial_evaluation_points,
158			backend,
159		)?;
160
161		Ok(Self {
162			n_vars,
163			state,
164			compositions,
165			domains,
166		})
167	}
168}
169
170impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
171	for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
172where
173	F: TowerField + ExtensionField<FDomain>,
174	FDomain: Field,
175	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
176	Composition: CompositionPoly<P>,
177	M: MultilinearPoly<P> + Send + Sync,
178	Backend: ComputationBackend,
179{
180	fn n_vars(&self) -> usize {
181		self.n_vars
182	}
183
184	fn evaluation_order(&self) -> EvaluationOrder {
185		self.state.evaluation_order()
186	}
187
188	#[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
189	fn fold(&mut self, challenge: F) -> Result<(), Error> {
190		self.state.fold(challenge)?;
191		Ok(())
192	}
193
194	#[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
195	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
196		let evaluators = izip!(&self.compositions, &self.domains)
197			.map(|(composition, interpolation_domain)| {
198				let composition_at_infinity =
199					ArithCircuitPoly::new(composition.expression().leading_term());
200
201				RegularSumcheckEvaluator {
202					composition,
203					composition_at_infinity,
204					interpolation_domain,
205					_marker: PhantomData,
206				}
207			})
208			.collect::<Vec<_>>();
209
210		let round_evals = self.state.calculate_round_evals(&evaluators)?;
211		self.state
212			.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, round_evals)
213	}
214
215	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
216		self.state.finish()
217	}
218}
219
/// Per-composition evaluator for one sumcheck round, built by
/// `RegularSumcheckProver::execute` for each (composition, domain) pair.
struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
where
	P: PackedField,
	FDomain: Field,
{
	/// The composition whose sum is being proven.
	composition: &'a Composition,
	/// Circuit built from the leading term of `composition`; used at the "infinity" point.
	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
	/// Domain used to interpolate the round evaluations into coefficients.
	interpolation_domain: &'a InterpolationDomain<FDomain>,
	/// Carries the packed field type `P` without storing a value of it.
	_marker: PhantomData<P>,
}
230
231impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
232	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
233where
234	F: TowerField + ExtensionField<FDomain>,
235	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
236	FDomain: Field,
237	Composition: CompositionPoly<P>,
238{
239	fn eval_point_indices(&self) -> Range<usize> {
240		// NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the
241		// current_round_sum - $r(1)$.
242		1..self.composition.degree() + 1
243	}
244
245	fn process_subcube_at_eval_point(
246		&self,
247		_subcube_vars: usize,
248		_subcube_index: usize,
249		is_infinity_point: bool,
250		batch_query: &RowsBatchRef<P>,
251	) -> P {
252		let row_len = batch_query.row_len();
253
254		stackalloc_with_default(row_len, |evals| {
255			if is_infinity_point {
256				self.composition_at_infinity
257					.batch_evaluate(batch_query, evals)
258					.expect("correct by query construction invariant");
259			} else {
260				self.composition
261					.batch_evaluate(batch_query, evals)
262					.expect("correct by query construction invariant");
263			}
264
265			evals.iter().copied().sum()
266		})
267	}
268
269	fn composition(&self) -> &Composition {
270		self.composition
271	}
272
273	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
274		None
275	}
276}
277
278impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
279	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
280where
281	F: Field,
282	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
283	FDomain: Field,
284{
285	fn round_evals_to_coeffs(
286		&self,
287		last_round_sum: F,
288		mut round_evals: Vec<F>,
289	) -> Result<Vec<F>, PolynomialError> {
290		// Given $r(1), \ldots, r(d+1)$, letting $s$ be the current round's claimed sum,
291		// we can compute $r(0)$ using the identity $r(0) = s - r(1)$
292		round_evals.insert(0, last_round_sum - round_evals[0]);
293
294		if round_evals.len() > 3 {
295			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
296			// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
297			// Putting "special" evaluation points at the beginning of domain allows benefitting from
298			// faster/skipped interpolation even in case of mixed degree compositions .
299			let infinity_round_eval = round_evals.remove(2);
300			round_evals.push(infinity_round_eval);
301		}
302
303		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
304		Ok(coeffs)
305	}
306}