1use std::ops::Range;
4
5use binius_field::{
6 packed::packed_from_fn_with_offset, util::eq, ExtensionField, Field, PackedExtension,
7 PackedField, TowerField,
8};
9use binius_hal::{ComputationBackend, SumcheckEvaluator};
10use binius_math::{
11 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain, MultilinearPoly,
12};
13use binius_maybe_rayon::prelude::*;
14use binius_utils::bail;
15use itertools::izip;
16use stackalloc::stackalloc_with_default;
17use tracing::{debug_span, instrument};
18
19use super::error::Error;
20use crate::{
21 polynomial::{ArithCircuitPoly, Error as PolynomialError},
22 protocols::sumcheck::{
23 get_nontrivial_evaluation_points, immediate_switchover_heuristic,
24 prove::{common, prover_state::ProverState, SumcheckInterpolator, SumcheckProver},
25 CompositeSumClaim, Error as SumcheckError, RoundCoeffs,
26 },
27};
28
29#[derive(Debug)]
30pub struct GPAProver<'a, FDomain, P, Composition, M, Backend>
31where
32 FDomain: Field,
33 P: PackedField,
34 M: MultilinearPoly<P> + Send + Sync,
35 Backend: ComputationBackend,
36{
37 n_vars: usize,
38 state: ProverState<'a, FDomain, P, M, Backend>,
39 eq_ind_eval: P::Scalar,
40 partial_eq_ind_evals: Backend::Vec<P>,
41 gpa_round_challenges: Vec<P::Scalar>,
42 compositions: Vec<Composition>,
43 domains: Vec<InterpolationDomain<FDomain>>,
44 first_round_eval_1s: Option<Vec<P::Scalar>>,
45}
46
47impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Composition, M, Backend>
48where
49 F: TowerField + ExtensionField<FDomain>,
50 FDomain: Field,
51 P: PackedExtension<FDomain, Scalar = F>,
52 Composition: CompositionPoly<P>,
53 M: MultilinearPoly<P> + Send + Sync,
54 Backend: ComputationBackend,
55{
56 #[instrument(skip_all, level = "debug", name = "GPAProver::new")]
57 pub fn new(
58 evaluation_order: EvaluationOrder,
59 multilinears: Vec<M>,
60 first_layer_mle_advice: Option<Vec<M>>,
61 composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
62 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
63 gpa_round_challenges: &[F],
64 backend: &'a Backend,
65 ) -> Result<Self, Error> {
66 let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
67
68 for claim in &composite_claims {
69 if claim.composition.n_vars() != multilinears.len() {
70 bail!(Error::InvalidComposition {
71 expected_n_vars: multilinears.len(),
72 });
73 }
74 }
75
76 if let Some(first_layer_mle_advice) = &first_layer_mle_advice {
77 if first_layer_mle_advice.len() != composite_claims.len() {
78 bail!(Error::IncorrectFirstLayerAdviceLength);
79 }
80 }
81
82 let claimed_sums = composite_claims
83 .iter()
84 .map(|composite_claim| composite_claim.sum)
85 .collect();
86
87 let domains = composite_claims
88 .par_iter()
89 .map(|composite_claim| {
90 let degree = composite_claim.composition.degree();
91 let domain =
92 evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
93 Ok(domain.into())
94 })
95 .collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
96 .map_err(Error::MathError)?;
97
98 let compositions = composite_claims
99 .into_iter()
100 .map(|claim| claim.composition)
101 .collect();
102
103 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
104
105 let state = ProverState::new(
106 evaluation_order,
107 multilinears,
108 claimed_sums,
109 nontrivial_evaluation_points,
110 immediate_switchover_heuristic,
112 backend,
113 )?;
114 let n_vars = state.n_vars();
115
116 if gpa_round_challenges.len() != n_vars {
117 return Err(Error::IncorrectGPARoundChallengesLength);
118 }
119
120 let gpa_round_challenges = gpa_round_challenges.to_vec();
121
122 let partial_eq_ind_evals = backend
123 .tensor_product_full_query(match evaluation_order {
124 EvaluationOrder::LowToHigh => &gpa_round_challenges[n_vars.min(1)..],
125 EvaluationOrder::HighToLow => &gpa_round_challenges[..n_vars.saturating_sub(1)],
126 })
127 .map_err(SumcheckError::from)?;
128
129 let first_round_eval_1s = debug_span!("first_round_eval_1s").in_scope(|| {
130 let high_to_low_offset = 1 << n_vars.saturating_sub(1);
132 first_layer_mle_advice.map(|first_layer_mle_advice| {
133 first_layer_mle_advice
134 .into_par_iter()
135 .map(|poly_mle| {
136 let packed_sum = partial_eq_ind_evals
137 .par_iter()
138 .enumerate()
139 .map(|(i, &eq_ind)| {
140 eq_ind
141 * packed_from_fn_with_offset::<P>(i, |j| {
142 let index = match evaluation_order {
143 EvaluationOrder::LowToHigh => j << 1 | 1,
144 EvaluationOrder::HighToLow => j | high_to_low_offset,
145 };
146 poly_mle.evaluate_on_hypercube(index).unwrap_or(F::ZERO)
147 })
148 })
149 .sum::<P>();
150 packed_sum.iter().take(1 << n_vars).sum()
151 })
152 .collect::<Vec<_>>()
153 })
154 });
155
156 Ok(Self {
157 n_vars,
158 state,
159 eq_ind_eval: F::ONE,
160 partial_eq_ind_evals,
161 gpa_round_challenges,
162 compositions,
163 domains,
164 first_round_eval_1s,
165 })
166 }
167
168 fn gpa_round_challenge(&self) -> F {
169 match self.state.evaluation_order() {
170 EvaluationOrder::LowToHigh => self.gpa_round_challenges[self.round()],
171 EvaluationOrder::HighToLow => {
172 self.gpa_round_challenges[self.gpa_round_challenges.len() - 1 - self.round()]
173 }
174 }
175 }
176
177 fn update_eq_ind_eval(&mut self, challenge: F) {
178 self.eq_ind_eval *= eq(self.gpa_round_challenge(), challenge);
180 }
181
182 #[instrument(skip_all, name = "GPAProver::fold_partial_eq_ind", level = "trace")]
183 fn fold_partial_eq_ind(&mut self) {
184 common::fold_partial_eq_ind::<P, Backend>(
185 self.state.evaluation_order(),
186 self.n_rounds_remaining(),
187 &mut self.partial_eq_ind_evals,
188 );
189 }
190
191 fn round(&self) -> usize {
192 self.n_vars - self.n_rounds_remaining()
193 }
194
195 fn n_rounds_remaining(&self) -> usize {
196 self.state.n_vars()
197 }
198}
199
200impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
201 for GPAProver<'_, FDomain, P, Composition, M, Backend>
202where
203 F: TowerField + ExtensionField<FDomain>,
204 FDomain: Field,
205 P: PackedExtension<FDomain, Scalar = F>,
206 Composition: CompositionPoly<P>,
207 M: MultilinearPoly<P> + Send + Sync,
208 Backend: ComputationBackend,
209{
210 fn n_vars(&self) -> usize {
211 self.n_vars
212 }
213
214 fn evaluation_order(&self) -> EvaluationOrder {
215 self.state.evaluation_order()
216 }
217
218 #[instrument(skip_all, name = "GPAProver::execute", level = "debug")]
219 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, SumcheckError> {
220 let round = self.round();
221 let alpha = self.gpa_round_challenge();
222
223 let evaluators = izip!(&self.compositions, &self.domains)
224 .enumerate()
225 .map(|(index, (composition, interpolation_domain))| {
226 let first_round_eval_1 = self
227 .first_round_eval_1s
228 .as_ref()
229 .map(|first_round_eval_1s| first_round_eval_1s[index])
230 .filter(|_| round == 0);
231
232 let composition_at_infinity =
233 ArithCircuitPoly::new(composition.expression().leading_term());
234
235 GPAEvaluator {
236 composition,
237 composition_at_infinity,
238 interpolation_domain,
239 first_round_eval_1,
240 partial_eq_ind_evals: &self.partial_eq_ind_evals,
241 gpa_round_challenge: alpha,
242 }
243 })
244 .collect::<Vec<_>>();
245
246 let evals = self.state.calculate_round_evals(&evaluators)?;
247 let coeffs =
248 self.state
249 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?;
250
251 let constant_scalar = F::ONE - alpha;
258 let linear_scalar = alpha.double() - F::ONE;
259
260 let coeffs_scaled_by_constant_term = coeffs.clone() * constant_scalar;
261 let mut coeffs_scaled_by_linear_term = coeffs * linear_scalar;
262 coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); let sumcheck_coeffs = coeffs_scaled_by_constant_term + &coeffs_scaled_by_linear_term;
265 Ok(sumcheck_coeffs * self.eq_ind_eval)
266 }
267
268 #[instrument(skip_all, name = "GPAProver::fold", level = "debug")]
269 fn fold(&mut self, challenge: F) -> Result<(), SumcheckError> {
270 self.update_eq_ind_eval(challenge);
271 let n_rounds_remaining = self.n_rounds_remaining();
272 let evaluation_order = self.state.evaluation_order();
273 binius_maybe_rayon::join(
274 || self.state.fold(challenge),
275 || {
276 common::fold_partial_eq_ind::<P, Backend>(
277 evaluation_order,
278 n_rounds_remaining - 1,
279 &mut self.partial_eq_ind_evals,
280 )
281 },
282 )
283 .0?;
284 Ok(())
285 }
286
287 fn finish(self: Box<Self>) -> Result<Vec<F>, SumcheckError> {
288 let mut evals = self.state.finish()?;
289 evals.push(self.eq_ind_eval);
290 Ok(evals)
291 }
292}
293
294struct GPAEvaluator<'a, P, FDomain, Composition>
295where
296 P: PackedField,
297 FDomain: Field,
298{
299 composition: &'a Composition,
300 composition_at_infinity: ArithCircuitPoly<P::Scalar>,
301 interpolation_domain: &'a InterpolationDomain<FDomain>,
302 partial_eq_ind_evals: &'a [P],
303 first_round_eval_1: Option<P::Scalar>,
304 gpa_round_challenge: P::Scalar,
305}
306
307impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
308 for GPAEvaluator<'_, P, FDomain, Composition>
309where
310 F: TowerField + ExtensionField<FDomain>,
311 P: PackedExtension<FDomain, Scalar = F>,
312 FDomain: Field,
313 Composition: CompositionPoly<P>,
314{
315 fn eval_point_indices(&self) -> Range<usize> {
316 let start_index = if self.first_round_eval_1.is_some() {
323 2
324 } else {
325 1
326 };
327 start_index..self.composition.degree() + 1
328 }
329
330 fn process_subcube_at_eval_point(
331 &self,
332 subcube_vars: usize,
333 subcube_index: usize,
334 is_infinity_point: bool,
335 batch_query: &[&[P]],
336 ) -> P {
337 let row_len = batch_query.first().map_or(0, |row| row.len());
338
339 stackalloc_with_default(row_len, |evals| {
340 if is_infinity_point {
341 self.composition_at_infinity
342 .batch_evaluate(batch_query, evals)
343 .expect("correct by query construction invariant");
344 } else {
345 self.composition
346 .batch_evaluate(batch_query, evals)
347 .expect("correct by query construction invariant");
348 };
349
350 let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
351 for (i, eval) in evals.iter_mut().enumerate() {
352 *eval *= self.partial_eq_ind_evals[subcube_start + i];
353 }
354
355 evals.iter().copied().sum::<P>()
356 })
357 }
358
359 fn composition(&self) -> &Composition {
360 self.composition
361 }
362
363 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
364 Some(self.partial_eq_ind_evals)
365 }
366}
367
368impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
369 for GPAEvaluator<'_, P, FDomain, Composition>
370where
371 F: Field,
372 P: PackedExtension<FDomain, Scalar = F>,
373 FDomain: Field,
374 Composition: CompositionPoly<P>,
375{
376 #[instrument(
377 skip_all,
378 name = "GPAFirstRoundEvaluator::round_evals_to_coeffs",
379 level = "debug"
380 )]
381 fn round_evals_to_coeffs(
382 &self,
383 last_round_sum: F,
384 mut round_evals: Vec<F>,
385 ) -> Result<Vec<F>, PolynomialError> {
386 if let Some(first_round_eval_1) = self.first_round_eval_1 {
387 round_evals.insert(0, first_round_eval_1);
388 }
389
390 let alpha = self.gpa_round_challenge;
391 let alpha_bar = F::ONE - alpha;
392 let one_evaluation = round_evals[0];
393 let zero_evaluation_numerator = last_round_sum - one_evaluation * alpha;
394 let zero_evaluation_denominator_inv = alpha_bar.invert().unwrap_or(F::ZERO);
395 let zero_evaluation = zero_evaluation_numerator * zero_evaluation_denominator_inv;
396
397 round_evals.insert(0, zero_evaluation);
398
399 if round_evals.len() > 3 {
400 let infinity_round_eval = round_evals.remove(2);
405 round_evals.push(infinity_round_eval);
406 }
407
408 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
409 Ok(coeffs)
410 }
411}