binius_core/protocols/gkr_gpa/gpa_sumcheck/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::Range;
4
5use binius_field::{
6	packed::packed_from_fn_with_offset, util::eq, ExtensionField, Field, PackedExtension,
7	PackedField, TowerField,
8};
9use binius_hal::{ComputationBackend, SumcheckEvaluator};
10use binius_math::{
11	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain, MultilinearPoly,
12};
13use binius_maybe_rayon::prelude::*;
14use binius_utils::bail;
15use itertools::izip;
16use stackalloc::stackalloc_with_default;
17use tracing::{debug_span, instrument};
18
19use super::error::Error;
20use crate::{
21	polynomial::{ArithCircuitPoly, Error as PolynomialError},
22	protocols::sumcheck::{
23		get_nontrivial_evaluation_points, immediate_switchover_heuristic,
24		prove::{common, prover_state::ProverState, SumcheckInterpolator, SumcheckProver},
25		CompositeSumClaim, Error as SumcheckError, RoundCoeffs,
26	},
27};
28
29#[derive(Debug)]
30pub struct GPAProver<'a, FDomain, P, Composition, M, Backend>
31where
32	FDomain: Field,
33	P: PackedField,
34	M: MultilinearPoly<P> + Send + Sync,
35	Backend: ComputationBackend,
36{
37	n_vars: usize,
38	state: ProverState<'a, FDomain, P, M, Backend>,
39	eq_ind_eval: P::Scalar,
40	partial_eq_ind_evals: Backend::Vec<P>,
41	gpa_round_challenges: Vec<P::Scalar>,
42	compositions: Vec<Composition>,
43	domains: Vec<InterpolationDomain<FDomain>>,
44	first_round_eval_1s: Option<Vec<P::Scalar>>,
45}
46
47impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Composition, M, Backend>
48where
49	F: TowerField + ExtensionField<FDomain>,
50	FDomain: Field,
51	P: PackedExtension<FDomain, Scalar = F>,
52	Composition: CompositionPoly<P>,
53	M: MultilinearPoly<P> + Send + Sync,
54	Backend: ComputationBackend,
55{
56	#[instrument(skip_all, level = "debug", name = "GPAProver::new")]
57	pub fn new(
58		evaluation_order: EvaluationOrder,
59		multilinears: Vec<M>,
60		first_layer_mle_advice: Option<Vec<M>>,
61		composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
62		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
63		gpa_round_challenges: &[F],
64		backend: &'a Backend,
65	) -> Result<Self, Error> {
66		let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
67
68		for claim in &composite_claims {
69			if claim.composition.n_vars() != multilinears.len() {
70				bail!(Error::InvalidComposition {
71					expected_n_vars: multilinears.len(),
72				});
73			}
74		}
75
76		if let Some(first_layer_mle_advice) = &first_layer_mle_advice {
77			if first_layer_mle_advice.len() != composite_claims.len() {
78				bail!(Error::IncorrectFirstLayerAdviceLength);
79			}
80		}
81
82		let claimed_sums = composite_claims
83			.iter()
84			.map(|composite_claim| composite_claim.sum)
85			.collect();
86
87		let domains = composite_claims
88			.par_iter()
89			.map(|composite_claim| {
90				let degree = composite_claim.composition.degree();
91				let domain =
92					evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
93				Ok(domain.into())
94			})
95			.collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
96			.map_err(Error::MathError)?;
97
98		let compositions = composite_claims
99			.into_iter()
100			.map(|claim| claim.composition)
101			.collect();
102
103		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
104
105		let state = ProverState::new(
106			evaluation_order,
107			multilinears,
108			claimed_sums,
109			nontrivial_evaluation_points,
110			// We use GPA protocol only for big fields, which is why switchover is trivial
111			immediate_switchover_heuristic,
112			backend,
113		)?;
114		let n_vars = state.n_vars();
115
116		if gpa_round_challenges.len() != n_vars {
117			return Err(Error::IncorrectGPARoundChallengesLength);
118		}
119
120		let gpa_round_challenges = gpa_round_challenges.to_vec();
121
122		let partial_eq_ind_evals = backend
123			.tensor_product_full_query(match evaluation_order {
124				EvaluationOrder::LowToHigh => &gpa_round_challenges[n_vars.min(1)..],
125				EvaluationOrder::HighToLow => &gpa_round_challenges[..n_vars.saturating_sub(1)],
126			})
127			.map_err(SumcheckError::from)?;
128
129		let first_round_eval_1s = debug_span!("first_round_eval_1s").in_scope(|| {
130			// This block takes non-trivial amount of time, therefore, instrumenting it is needed.
131			let high_to_low_offset = 1 << n_vars.saturating_sub(1);
132			first_layer_mle_advice.map(|first_layer_mle_advice| {
133				first_layer_mle_advice
134					.into_par_iter()
135					.map(|poly_mle| {
136						let packed_sum = partial_eq_ind_evals
137							.par_iter()
138							.enumerate()
139							.map(|(i, &eq_ind)| {
140								eq_ind
141									* packed_from_fn_with_offset::<P>(i, |j| {
142										let index = match evaluation_order {
143											EvaluationOrder::LowToHigh => j << 1 | 1,
144											EvaluationOrder::HighToLow => j | high_to_low_offset,
145										};
146										poly_mle.evaluate_on_hypercube(index).unwrap_or(F::ZERO)
147									})
148							})
149							.sum::<P>();
150						packed_sum.iter().take(1 << n_vars).sum()
151					})
152					.collect::<Vec<_>>()
153			})
154		});
155
156		Ok(Self {
157			n_vars,
158			state,
159			eq_ind_eval: F::ONE,
160			partial_eq_ind_evals,
161			gpa_round_challenges,
162			compositions,
163			domains,
164			first_round_eval_1s,
165		})
166	}
167
168	fn gpa_round_challenge(&self) -> F {
169		match self.state.evaluation_order() {
170			EvaluationOrder::LowToHigh => self.gpa_round_challenges[self.round()],
171			EvaluationOrder::HighToLow => {
172				self.gpa_round_challenges[self.gpa_round_challenges.len() - 1 - self.round()]
173			}
174		}
175	}
176
177	fn update_eq_ind_eval(&mut self, challenge: F) {
178		// Update the running eq ind evaluation.
179		self.eq_ind_eval *= eq(self.gpa_round_challenge(), challenge);
180	}
181
182	#[instrument(skip_all, name = "GPAProver::fold_partial_eq_ind", level = "trace")]
183	fn fold_partial_eq_ind(&mut self) {
184		common::fold_partial_eq_ind::<P, Backend>(
185			self.state.evaluation_order(),
186			self.n_rounds_remaining(),
187			&mut self.partial_eq_ind_evals,
188		);
189	}
190
191	fn round(&self) -> usize {
192		self.n_vars - self.n_rounds_remaining()
193	}
194
195	fn n_rounds_remaining(&self) -> usize {
196		self.state.n_vars()
197	}
198}
199
200impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
201	for GPAProver<'_, FDomain, P, Composition, M, Backend>
202where
203	F: TowerField + ExtensionField<FDomain>,
204	FDomain: Field,
205	P: PackedExtension<FDomain, Scalar = F>,
206	Composition: CompositionPoly<P>,
207	M: MultilinearPoly<P> + Send + Sync,
208	Backend: ComputationBackend,
209{
210	fn n_vars(&self) -> usize {
211		self.n_vars
212	}
213
214	fn evaluation_order(&self) -> EvaluationOrder {
215		self.state.evaluation_order()
216	}
217
218	#[instrument(skip_all, name = "GPAProver::execute", level = "debug")]
219	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, SumcheckError> {
220		let round = self.round();
221		let alpha = self.gpa_round_challenge();
222
223		let evaluators = izip!(&self.compositions, &self.domains)
224			.enumerate()
225			.map(|(index, (composition, interpolation_domain))| {
226				let first_round_eval_1 = self
227					.first_round_eval_1s
228					.as_ref()
229					.map(|first_round_eval_1s| first_round_eval_1s[index])
230					.filter(|_| round == 0);
231
232				let composition_at_infinity =
233					ArithCircuitPoly::new(composition.expression().leading_term());
234
235				GPAEvaluator {
236					composition,
237					composition_at_infinity,
238					interpolation_domain,
239					first_round_eval_1,
240					partial_eq_ind_evals: &self.partial_eq_ind_evals,
241					gpa_round_challenge: alpha,
242				}
243			})
244			.collect::<Vec<_>>();
245
246		let evals = self.state.calculate_round_evals(&evaluators)?;
247		let coeffs =
248			self.state
249				.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?;
250
251		// Convert v' polynomial into v polynomial
252
253		// eq(X, α) = (1 − α) + (2 α − 1) X
254		// NB: In binary fields, this expression can be simplified to 1 + α + challenge. However,
255		// we opt to keep this prover generic over all fields. These two multiplications per round
256		// have negligible performance impact.
257		let constant_scalar = F::ONE - alpha;
258		let linear_scalar = alpha.double() - F::ONE;
259
260		let coeffs_scaled_by_constant_term = coeffs.clone() * constant_scalar;
261		let mut coeffs_scaled_by_linear_term = coeffs * linear_scalar;
262		coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); // Multiply polynomial by X
263
264		let sumcheck_coeffs = coeffs_scaled_by_constant_term + &coeffs_scaled_by_linear_term;
265		Ok(sumcheck_coeffs * self.eq_ind_eval)
266	}
267
268	#[instrument(skip_all, name = "GPAProver::fold", level = "debug")]
269	fn fold(&mut self, challenge: F) -> Result<(), SumcheckError> {
270		self.update_eq_ind_eval(challenge);
271		let n_rounds_remaining = self.n_rounds_remaining();
272		let evaluation_order = self.state.evaluation_order();
273		binius_maybe_rayon::join(
274			|| self.state.fold(challenge),
275			|| {
276				common::fold_partial_eq_ind::<P, Backend>(
277					evaluation_order,
278					n_rounds_remaining - 1,
279					&mut self.partial_eq_ind_evals,
280				)
281			},
282		)
283		.0?;
284		Ok(())
285	}
286
287	fn finish(self: Box<Self>) -> Result<Vec<F>, SumcheckError> {
288		let mut evals = self.state.finish()?;
289		evals.push(self.eq_ind_eval);
290		Ok(evals)
291	}
292}
293
294struct GPAEvaluator<'a, P, FDomain, Composition>
295where
296	P: PackedField,
297	FDomain: Field,
298{
299	composition: &'a Composition,
300	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
301	interpolation_domain: &'a InterpolationDomain<FDomain>,
302	partial_eq_ind_evals: &'a [P],
303	first_round_eval_1: Option<P::Scalar>,
304	gpa_round_challenge: P::Scalar,
305}
306
307impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
308	for GPAEvaluator<'_, P, FDomain, Composition>
309where
310	F: TowerField + ExtensionField<FDomain>,
311	P: PackedExtension<FDomain, Scalar = F>,
312	FDomain: Field,
313	Composition: CompositionPoly<P>,
314{
315	fn eval_point_indices(&self) -> Range<usize> {
316		// By definition of grand product GKR circuit, the composition evaluation is a multilinear
317		// extension representing the previous layer. Hence in first round we can use the previous
318		// layer as an advice instead of evaluating r(1).
319		// Also we can uniquely derive the degree d univariate round polynomial r from evaluations at
320		// X = 2, ..., d because we have an identity that relates r(0), r(1), and the current
321		// round's claimed sum.
322		let start_index = if self.first_round_eval_1.is_some() {
323			2
324		} else {
325			1
326		};
327		start_index..self.composition.degree() + 1
328	}
329
330	fn process_subcube_at_eval_point(
331		&self,
332		subcube_vars: usize,
333		subcube_index: usize,
334		is_infinity_point: bool,
335		batch_query: &[&[P]],
336	) -> P {
337		let row_len = batch_query.first().map_or(0, |row| row.len());
338
339		stackalloc_with_default(row_len, |evals| {
340			if is_infinity_point {
341				self.composition_at_infinity
342					.batch_evaluate(batch_query, evals)
343					.expect("correct by query construction invariant");
344			} else {
345				self.composition
346					.batch_evaluate(batch_query, evals)
347					.expect("correct by query construction invariant");
348			};
349
350			let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
351			for (i, eval) in evals.iter_mut().enumerate() {
352				*eval *= self.partial_eq_ind_evals[subcube_start + i];
353			}
354
355			evals.iter().copied().sum::<P>()
356		})
357	}
358
359	fn composition(&self) -> &Composition {
360		self.composition
361	}
362
363	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
364		Some(self.partial_eq_ind_evals)
365	}
366}
367
368impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
369	for GPAEvaluator<'_, P, FDomain, Composition>
370where
371	F: Field,
372	P: PackedExtension<FDomain, Scalar = F>,
373	FDomain: Field,
374	Composition: CompositionPoly<P>,
375{
376	#[instrument(
377		skip_all,
378		name = "GPAFirstRoundEvaluator::round_evals_to_coeffs",
379		level = "debug"
380	)]
381	fn round_evals_to_coeffs(
382		&self,
383		last_round_sum: F,
384		mut round_evals: Vec<F>,
385	) -> Result<Vec<F>, PolynomialError> {
386		if let Some(first_round_eval_1) = self.first_round_eval_1 {
387			round_evals.insert(0, first_round_eval_1);
388		}
389
390		let alpha = self.gpa_round_challenge;
391		let alpha_bar = F::ONE - alpha;
392		let one_evaluation = round_evals[0];
393		let zero_evaluation_numerator = last_round_sum - one_evaluation * alpha;
394		let zero_evaluation_denominator_inv = alpha_bar.invert().unwrap_or(F::ZERO);
395		let zero_evaluation = zero_evaluation_numerator * zero_evaluation_denominator_inv;
396
397		round_evals.insert(0, zero_evaluation);
398
399		if round_evals.len() > 3 {
400			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
401			// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
402			// Putting "special" evaluation points at the beginning of domain allows benefitting from
403			// faster/skipped interpolation even in case of mixed degree compositions .
404			let infinity_round_eval = round_evals.remove(2);
405			round_evals.push(infinity_round_eval);
406		}
407
408		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
409		Ok(coeffs)
410	}
411}