binius_core/protocols/sumcheck/prove/regular_sumcheck.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, ops::Range};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
7use binius_hal::{ComputationBackend, SumcheckEvaluator, SumcheckMultilinear};
8use binius_math::{
9	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
10	MultilinearPoly, RowsBatchRef,
11};
12use binius_maybe_rayon::prelude::*;
13use binius_utils::bail;
14use itertools::izip;
15use stackalloc::stackalloc_with_default;
16use tracing::instrument;
17
18use crate::{
19	polynomial::{Error as PolynomialError, MultilinearComposite},
20	protocols::sumcheck::{
21		common::{
22			CompositeSumClaim, RoundCoeffs, equal_n_vars_check, get_nontrivial_evaluation_points,
23			interpolation_domains_for_composition_degrees,
24		},
25		error::Error,
26		prove::{ProverState, SumcheckInterpolator, SumcheckProver},
27	},
28};
29
30pub fn validate_witness<'a, F, P, M, Composition>(
31	multilinears: &[M],
32	sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
33) -> Result<(), Error>
34where
35	F: Field,
36	P: PackedField<Scalar = F>,
37	M: MultilinearPoly<P> + Send + Sync,
38	Composition: CompositionPoly<P> + 'a,
39{
40	let n_vars = multilinears
41		.first()
42		.map(|multilinear| multilinear.n_vars())
43		.unwrap_or_default();
44	for multilinear in multilinears {
45		if multilinear.n_vars() != n_vars {
46			bail!(Error::NumberOfVariablesMismatch);
47		}
48	}
49
50	let multilinears = multilinears.iter().collect::<Vec<_>>();
51
52	for (i, claim) in sum_claims.into_iter().enumerate() {
53		let CompositeSumClaim {
54			composition,
55			sum: expected_sum,
56			..
57		} = claim;
58		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
59		let sum = (0..(1 << n_vars))
60			.into_par_iter()
61			.map(|j| witness.evaluate_on_hypercube(j))
62			.try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
63
64		if sum != expected_sum {
65			bail!(Error::SumcheckNaiveValidationFailure {
66				composition_index: i,
67			});
68		}
69	}
70	Ok(())
71}
72
73pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
74where
75	FDomain: Field,
76	P: PackedField,
77	M: MultilinearPoly<P> + Send + Sync,
78	Backend: ComputationBackend,
79{
80	n_vars: usize,
81	state: ProverState<'a, FDomain, P, M, Backend>,
82	compositions: Vec<Composition>,
83	domains: Vec<InterpolationDomain<FDomain>>,
84}
85
86impl<'a, F, FDomain, P, Composition, M, Backend>
87	RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
88where
89	F: Field,
90	FDomain: Field,
91	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
92	Composition: CompositionPoly<P>,
93	M: MultilinearPoly<P> + Send + Sync,
94	Backend: ComputationBackend,
95{
96	#[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
97	pub fn new(
98		evaluation_order: EvaluationOrder,
99		multilinears: Vec<M>,
100		composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
101		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
102		switchover_fn: impl Fn(usize) -> usize,
103		backend: &'a Backend,
104	) -> Result<Self, Error> {
105		let n_vars = equal_n_vars_check(&multilinears)?;
106		let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
107
108		#[cfg(feature = "debug_validate_sumcheck")]
109		{
110			let composite_claims = composite_claims
111				.iter()
112				.map(|x| CompositeSumClaim {
113					sum: x.sum,
114					composition: &x.composition,
115				})
116				.collect::<Vec<_>>();
117			validate_witness(&multilinears, composite_claims)?;
118		}
119
120		for claim in &composite_claims {
121			if claim.composition.n_vars() != multilinears.len() {
122				bail!(Error::InvalidComposition {
123					actual: claim.composition.n_vars(),
124					expected: multilinears.len(),
125				});
126			}
127		}
128
129		let claimed_sums = composite_claims
130			.iter()
131			.map(|composite_claim| composite_claim.sum)
132			.collect();
133
134		let domains = interpolation_domains_for_composition_degrees(
135			evaluation_domain_factory,
136			composite_claims
137				.iter()
138				.map(|composite_claim| composite_claim.composition.degree()),
139		)?;
140
141		let compositions = composite_claims
142			.into_iter()
143			.map(|claim| claim.composition)
144			.collect();
145
146		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
147
148		let multilinears = multilinears
149			.into_iter()
150			.map(|multilinear| SumcheckMultilinear::transparent(multilinear, &switchover_fn))
151			.collect();
152
153		let state = ProverState::new(
154			evaluation_order,
155			n_vars,
156			multilinears,
157			claimed_sums,
158			nontrivial_evaluation_points,
159			backend,
160		)?;
161
162		Ok(Self {
163			n_vars,
164			state,
165			compositions,
166			domains,
167		})
168	}
169}
170
171impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
172	for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
173where
174	F: TowerField + ExtensionField<FDomain>,
175	FDomain: Field,
176	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
177	Composition: CompositionPoly<P>,
178	M: MultilinearPoly<P> + Send + Sync,
179	Backend: ComputationBackend,
180{
181	fn n_vars(&self) -> usize {
182		self.n_vars
183	}
184
185	fn evaluation_order(&self) -> EvaluationOrder {
186		self.state.evaluation_order()
187	}
188
189	#[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
190	fn fold(&mut self, challenge: F) -> Result<(), Error> {
191		self.state.fold(challenge)?;
192		Ok(())
193	}
194
195	#[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
196	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
197		let evaluators = izip!(&self.compositions, &self.domains)
198			.map(|(composition, interpolation_domain)| {
199				let composition_at_infinity =
200					ArithCircuitPoly::new(composition.expression().leading_term());
201
202				RegularSumcheckEvaluator {
203					composition,
204					composition_at_infinity,
205					interpolation_domain,
206					_marker: PhantomData,
207				}
208			})
209			.collect::<Vec<_>>();
210
211		let round_evals = self.state.calculate_round_evals(&evaluators)?;
212		self.state
213			.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, round_evals)
214	}
215
216	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
217		self.state.finish()
218	}
219}
220
221struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
222where
223	P: PackedField,
224	FDomain: Field,
225{
226	composition: &'a Composition,
227	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
228	interpolation_domain: &'a InterpolationDomain<FDomain>,
229	_marker: PhantomData<P>,
230}
231
232impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
233	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
234where
235	F: TowerField + ExtensionField<FDomain>,
236	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
237	FDomain: Field,
238	Composition: CompositionPoly<P>,
239{
240	fn eval_point_indices(&self) -> Range<usize> {
241		// NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the
242		// current_round_sum - $r(1)$.
243		1..self.composition.degree() + 1
244	}
245
246	fn process_subcube_at_eval_point(
247		&self,
248		_subcube_vars: usize,
249		_subcube_index: usize,
250		is_infinity_point: bool,
251		batch_query: &RowsBatchRef<P>,
252	) -> P {
253		let row_len = batch_query.row_len();
254
255		stackalloc_with_default(row_len, |evals| {
256			if is_infinity_point {
257				self.composition_at_infinity
258					.batch_evaluate(batch_query, evals)
259					.expect("correct by query construction invariant");
260			} else {
261				self.composition
262					.batch_evaluate(batch_query, evals)
263					.expect("correct by query construction invariant");
264			}
265
266			evals.iter().copied().sum()
267		})
268	}
269
270	fn composition(&self) -> &Composition {
271		self.composition
272	}
273
274	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
275		None
276	}
277}
278
279impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
280	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
281where
282	F: Field,
283	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
284	FDomain: Field,
285{
286	fn round_evals_to_coeffs(
287		&self,
288		last_round_sum: F,
289		mut round_evals: Vec<F>,
290	) -> Result<Vec<F>, PolynomialError> {
291		// Given $r(1), \ldots, r(d+1)$, letting $s$ be the current round's claimed sum,
292		// we can compute $r(0)$ using the identity $r(0) = s - r(1)$
293		round_evals.insert(0, last_round_sum - round_evals[0]);
294
295		if round_evals.len() > 3 {
296			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then
297			// subspace points. InterpolationDomain expects "infinity" at the last position, thus
298			// reordering is needed. Putting "special" evaluation points at the beginning of
299			// domain allows benefitting from faster/skipped interpolation even in case of mixed
300			// degree compositions .
301			let infinity_round_eval = round_evals.remove(2);
302			round_evals.push(infinity_round_eval);
303		}
304
305		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
306		Ok(coeffs)
307	}
308}