binius_core/protocols/sumcheck/prove/
regular_sumcheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, ops::Range};
4
5use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{ComputationBackend, SumcheckEvaluator};
7use binius_math::{
8	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
9	MultilinearPoly, RowsBatchRef,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::bail;
13use itertools::izip;
14use stackalloc::stackalloc_with_default;
15use tracing::instrument;
16
17use crate::{
18	polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
19	protocols::sumcheck::{
20		common::{
21			get_nontrivial_evaluation_points, interpolation_domains_for_composition_degrees,
22			CompositeSumClaim, RoundCoeffs,
23		},
24		error::Error,
25		prove::{MultilinearInput, ProverState, SumcheckInterpolator, SumcheckProver},
26	},
27};
28
29pub fn validate_witness<'a, F, P, M, Composition>(
30	multilinears: &[M],
31	sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, &'a Composition>>,
32) -> Result<(), Error>
33where
34	F: Field,
35	P: PackedField<Scalar = F>,
36	M: MultilinearPoly<P> + Send + Sync,
37	Composition: CompositionPoly<P> + 'a,
38{
39	let n_vars = multilinears
40		.first()
41		.map(|multilinear| multilinear.n_vars())
42		.unwrap_or_default();
43	for multilinear in multilinears {
44		if multilinear.n_vars() != n_vars {
45			bail!(Error::NumberOfVariablesMismatch);
46		}
47	}
48
49	let multilinears = multilinears.iter().collect::<Vec<_>>();
50
51	for (i, claim) in sum_claims.into_iter().enumerate() {
52		let CompositeSumClaim {
53			composition,
54			sum: expected_sum,
55			..
56		} = claim;
57		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
58		let sum = (0..(1 << n_vars))
59			.into_par_iter()
60			.map(|j| witness.evaluate_on_hypercube(j))
61			.try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
62
63		if sum != expected_sum {
64			bail!(Error::SumcheckNaiveValidationFailure {
65				composition_index: i,
66			});
67		}
68	}
69	Ok(())
70}
71
72pub struct RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
73where
74	FDomain: Field,
75	P: PackedField,
76	M: MultilinearPoly<P> + Send + Sync,
77	Backend: ComputationBackend,
78{
79	n_vars: usize,
80	state: ProverState<'a, FDomain, P, M, Backend>,
81	compositions: Vec<Composition>,
82	domains: Vec<InterpolationDomain<FDomain>>,
83}
84
85impl<'a, F, FDomain, P, Composition, M, Backend>
86	RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>
87where
88	F: Field,
89	FDomain: Field,
90	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
91	Composition: CompositionPoly<P>,
92	M: MultilinearPoly<P> + Send + Sync,
93	Backend: ComputationBackend,
94{
95	#[instrument(skip_all, level = "debug", name = "RegularSumcheckProver::new")]
96	pub fn new(
97		evaluation_order: EvaluationOrder,
98		multilinears: Vec<M>,
99		composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
100		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
101		switchover_fn: impl Fn(usize) -> usize,
102		backend: &'a Backend,
103	) -> Result<Self, Error> {
104		let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
105
106		#[cfg(feature = "debug_validate_sumcheck")]
107		{
108			let composite_claims = composite_claims
109				.iter()
110				.map(|x| CompositeSumClaim {
111					sum: x.sum,
112					composition: &x.composition,
113				})
114				.collect::<Vec<_>>();
115			validate_witness(&multilinears, composite_claims)?;
116		}
117
118		for claim in &composite_claims {
119			if claim.composition.n_vars() != multilinears.len() {
120				bail!(Error::InvalidComposition {
121					actual: claim.composition.n_vars(),
122					expected: multilinears.len(),
123				});
124			}
125		}
126
127		let claimed_sums = composite_claims
128			.iter()
129			.map(|composite_claim| composite_claim.sum)
130			.collect();
131
132		let domains = interpolation_domains_for_composition_degrees(
133			evaluation_domain_factory,
134			composite_claims
135				.iter()
136				.map(|composite_claim| composite_claim.composition.degree()),
137		)?;
138
139		let compositions = composite_claims
140			.into_iter()
141			.map(|claim| claim.composition)
142			.collect();
143
144		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
145
146		let multilinears_input = multilinears
147			.into_iter()
148			.map(|multilinear| MultilinearInput {
149				multilinear,
150				zero_scalars_suffix: 0,
151			})
152			.collect();
153
154		let state = ProverState::new(
155			evaluation_order,
156			multilinears_input,
157			claimed_sums,
158			nontrivial_evaluation_points,
159			switchover_fn,
160			backend,
161		)?;
162		let n_vars = state.n_vars();
163
164		Ok(Self {
165			n_vars,
166			state,
167			compositions,
168			domains,
169		})
170	}
171}
172
173impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
174	for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend>
175where
176	F: TowerField + ExtensionField<FDomain>,
177	FDomain: Field,
178	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
179	Composition: CompositionPoly<P>,
180	M: MultilinearPoly<P> + Send + Sync,
181	Backend: ComputationBackend,
182{
183	fn n_vars(&self) -> usize {
184		self.n_vars
185	}
186
187	fn evaluation_order(&self) -> EvaluationOrder {
188		self.state.evaluation_order()
189	}
190
191	#[instrument("RegularSumcheckProver::fold", skip_all, level = "debug")]
192	fn fold(&mut self, challenge: F) -> Result<(), Error> {
193		self.state.fold(challenge)?;
194		Ok(())
195	}
196
197	#[instrument("RegularSumcheckProver::execute", skip_all, level = "debug")]
198	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
199		let evaluators = izip!(&self.compositions, &self.domains)
200			.map(|(composition, interpolation_domain)| {
201				let composition_at_infinity =
202					ArithCircuitPoly::new(composition.expression().leading_term());
203
204				RegularSumcheckEvaluator {
205					composition,
206					composition_at_infinity,
207					interpolation_domain,
208					_marker: PhantomData,
209				}
210			})
211			.collect::<Vec<_>>();
212
213		let round_evals = self.state.calculate_round_evals(&evaluators)?;
214		self.state
215			.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, round_evals)
216	}
217
218	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
219		self.state.finish()
220	}
221}
222
223struct RegularSumcheckEvaluator<'a, P, FDomain, Composition>
224where
225	P: PackedField,
226	FDomain: Field,
227{
228	composition: &'a Composition,
229	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
230	interpolation_domain: &'a InterpolationDomain<FDomain>,
231	_marker: PhantomData<P>,
232}
233
234impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
235	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
236where
237	F: TowerField + ExtensionField<FDomain>,
238	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
239	FDomain: Field,
240	Composition: CompositionPoly<P>,
241{
242	fn eval_point_indices(&self) -> Range<usize> {
243		// NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the
244		// current_round_sum - $r(1)$.
245		1..self.composition.degree() + 1
246	}
247
248	fn process_subcube_at_eval_point(
249		&self,
250		_subcube_vars: usize,
251		_subcube_index: usize,
252		is_infinity_point: bool,
253		batch_query: &RowsBatchRef<P>,
254	) -> P {
255		let row_len = batch_query.row_len();
256
257		stackalloc_with_default(row_len, |evals| {
258			if is_infinity_point {
259				self.composition_at_infinity
260					.batch_evaluate(batch_query, evals)
261					.expect("correct by query construction invariant");
262			} else {
263				self.composition
264					.batch_evaluate(batch_query, evals)
265					.expect("correct by query construction invariant");
266			}
267
268			evals.iter().copied().sum()
269		})
270	}
271
272	fn composition(&self) -> &Composition {
273		self.composition
274	}
275
276	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
277		None
278	}
279}
280
281impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
282	for RegularSumcheckEvaluator<'_, P, FDomain, Composition>
283where
284	F: Field,
285	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
286	FDomain: Field,
287{
288	fn round_evals_to_coeffs(
289		&self,
290		last_round_sum: F,
291		mut round_evals: Vec<F>,
292	) -> Result<Vec<F>, PolynomialError> {
293		// Given $r(1), \ldots, r(d+1)$, letting $s$ be the current round's claimed sum,
294		// we can compute $r(0)$ using the identity $r(0) = s - r(1)$
295		round_evals.insert(0, last_round_sum - round_evals[0]);
296
297		if round_evals.len() > 3 {
298			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
299			// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
300			// Putting "special" evaluation points at the beginning of domain allows benefitting from
301			// faster/skipped interpolation even in case of mixed degree compositions .
302			let infinity_round_eval = round_evals.remove(2);
303			round_evals.push(infinity_round_eval);
304		}
305
306		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
307		Ok(coeffs)
308	}
309}