binius_core/protocols/sumcheck/prove/
regular_sumcheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{ComputationBackend, SumcheckEvaluator};
7use binius_math::{
8	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain, MultilinearPoly,
9};
10use binius_maybe_rayon::prelude::*;
11use binius_utils::bail;
12use itertools::izip;
13use stackalloc::stackalloc_with_default;
14use tracing::instrument;
15
16use super::{batch_prove::SumcheckProver, prover_state::ProverState};
17use crate::{
18	polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
19	protocols::sumcheck::{
20		common::{get_nontrivial_evaluation_points, CompositeSumClaim, RoundCoeffs},
21		error::Error,
22		prove::prover_state::SumcheckInterpolator,
23	},
24};
25
26pub fn validate_witness<'a, F, P, M, Composition>(
27	multilinears: &[M],
28	sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
29) -> Result<(), Error>
30where
31	F: Field,
32	P: PackedField<Scalar = F>,
33	M: MultilinearPoly<P> + Send + Sync,
34	Composition: CompositionPoly<P> + 'a,
35{
36	let n_vars = multilinears
37		.first()
38		.map(|multilinear| multilinear.n_vars())
39		.unwrap_or_default();
40	for multilinear in multilinears {
41		if multilinear.n_vars() != n_vars {
42			bail!(Error::NumberOfVariablesMismatch);
43		}
44	}
45
46	let multilinears = multilinears.iter().collect::<Vec<_>>();
47
48	for (i, claim) in sum_claims.into_iter().enumerate() {
49		let CompositeSumClaim {
50			composition,
51			sum: expected_sum,
52			..
53		} = claim;
54		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
55		let sum = (0..(1 << n_vars))
56			.into_par_iter()
57			.map(|j| witness.evaluate_on_hypercube(j))
58			.try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
59
60		if sum != expected_sum {
61			bail!(Error::SumcheckNaiveValidationFailure {
62				composition_index: i,
63			});
64		}
65	}
66	Ok(())
67}
68
69pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
70where
71	FDomain: Field,
72	P: PackedField,
73	M: MultilinearPoly<P> + Send + Sync,
74	Backend: ComputationBackend,
75{
76	n_vars: usize,
77	state: ProverState<'a, FDomain, P, M, Backend>,
78	compositions: Vec<Composition>,
79	domains: Vec<InterpolationDomain<FDomain>>,
80}
81
82impl<'a, F, FDomain, P, Composition, M, Backend>
83	RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
84where
85	F: Field,
86	FDomain: Field,
87	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
88	Composition: CompositionPoly<P>,
89	M: MultilinearPoly<P> + Send + Sync,
90	Backend: ComputationBackend,
91{
92	#[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
93	pub fn new(
94		evaluation_order: EvaluationOrder,
95		multilinears: Vec<M>,
96		composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
97		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
98		switchover_fn: impl Fn(usize) -> usize,
99		backend: &'a Backend,
100	) -> Result<Self, Error> {
101		let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
102
103		#[cfg(feature = "debug_validate_sumcheck")]
104		{
105			let composite_claims = composite_claims
106				.iter()
107				.map(|x| CompositeSumClaim {
108					sum: x.sum,
109					composition: &x.composition,
110				})
111				.collect::<Vec<_>>();
112			validate_witness(&multilinears, composite_claims)?;
113		}
114
115		for claim in &composite_claims {
116			if claim.composition.n_vars() != multilinears.len() {
117				bail!(Error::InvalidComposition {
118					actual: claim.composition.n_vars(),
119					expected: multilinears.len(),
120				});
121			}
122		}
123
124		let claimed_sums = composite_claims
125			.iter()
126			.map(|composite_claim| composite_claim.sum)
127			.collect();
128
129		let domains = composite_claims
130			.iter()
131			.map(|composite_claim| {
132				let degree = composite_claim.composition.degree();
133				let domain =
134					evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
135				Ok(domain.into())
136			})
137			.collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
138			.map_err(Error::MathError)?;
139
140		let compositions = composite_claims
141			.into_iter()
142			.map(|claim| claim.composition)
143			.collect();
144
145		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
146
147		let state = ProverState::new(
148			evaluation_order,
149			multilinears,
150			claimed_sums,
151			nontrivial_evaluation_points,
152			switchover_fn,
153			backend,
154		)?;
155		let n_vars = state.n_vars();
156
157		Ok(Self {
158			n_vars,
159			state,
160			compositions,
161			domains,
162		})
163	}
164}
165
166impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
167	for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
168where
169	F: TowerField + ExtensionField<FDomain>,
170	FDomain: Field,
171	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
172	Composition: CompositionPoly<P>,
173	M: MultilinearPoly<P> + Send + Sync,
174	Backend: ComputationBackend,
175{
176	fn n_vars(&self) -> usize {
177		self.n_vars
178	}
179
180	fn evaluation_order(&self) -> EvaluationOrder {
181		self.state.evaluation_order()
182	}
183
184	#[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
185	fn fold(&mut self, challenge: F) -> Result<(), Error> {
186		self.state.fold(challenge)?;
187		Ok(())
188	}
189
190	#[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
191	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
192		let evaluators = izip!(&self.compositions, &self.domains)
193			.map(|(composition, interpolation_domain)| {
194				let composition_at_infinity =
195					ArithCircuitPoly::new(composition.expression().leading_term());
196
197				RegularSumcheckEvaluator {
198					composition,
199					composition_at_infinity,
200					interpolation_domain,
201					_marker: PhantomData,
202				}
203			})
204			.collect::<Vec<_>>();
205
206		let evals = self.state.calculate_round_evals(&evaluators)?;
207		self.state
208			.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)
209	}
210
211	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
212		self.state.finish()
213	}
214}
215
216struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
217where
218	P: PackedField,
219	FDomain: Field,
220{
221	composition: &'a Composition,
222	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
223	interpolation_domain: &'a InterpolationDomain<FDomain>,
224	_marker: PhantomData<P>,
225}
226
227impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
228	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
229where
230	F: TowerField + ExtensionField<FDomain>,
231	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
232	FDomain: Field,
233	Composition: CompositionPoly<P>,
234{
235	fn eval_point_indices(&self) -> Range<usize> {
236		// NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the
237		// current_round_sum - $r(1)$.
238		1..self.composition.degree() + 1
239	}
240
241	fn process_subcube_at_eval_point(
242		&self,
243		_subcube_vars: usize,
244		_subcube_index: usize,
245		is_infinity_point: bool,
246		batch_query: &[&[P]],
247	) -> P {
248		let row_len = batch_query.first().map_or(0, |row| row.len());
249
250		stackalloc_with_default(row_len, |evals| {
251			if is_infinity_point {
252				self.composition_at_infinity
253					.batch_evaluate(batch_query, evals)
254					.expect("correct by query construction invariant");
255			} else {
256				self.composition
257					.batch_evaluate(batch_query, evals)
258					.expect("correct by query construction invariant");
259			}
260
261			evals.iter().copied().sum()
262		})
263	}
264
265	fn composition(&self) -> &Composition {
266		self.composition
267	}
268
269	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
270		None
271	}
272}
273
274impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
275	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
276where
277	F: Field,
278	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
279	FDomain: Field,
280{
281	fn round_evals_to_coeffs(
282		&self,
283		last_round_sum: F,
284		mut round_evals: Vec<F>,
285	) -> Result<Vec<F>, PolynomialError> {
286		// Given $r(1), \ldots, r(d+1)$, letting $s$ be the current round's claimed sum,
287		// we can compute $r(0)$ using the identity $r(0) = s - r(1)$
288		round_evals.insert(0, last_round_sum - round_evals[0]);
289
290		if round_evals.len() > 3 {
291			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
292			// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
293			// Putting "special" evaluation points at the beginning of domain allows benefitting from
294			// faster/skipped interpolation even in case of mixed degree compositions .
295			let infinity_round_eval = round_evals.remove(2);
296			round_evals.push(infinity_round_eval);
297		}
298
299		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
300		Ok(coeffs)
301	}
302}