1use binius_field::{
4 packed::packed_from_fn_with_offset, Field, PackedExtension, PackedField, TowerField,
5};
6use binius_hal::ComputationBackend;
7use binius_math::{
8 extrapolate_line_scalar, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
9 MultilinearExtension, MultilinearPoly,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::{
13 bail,
14 sorting::{stable_sort, unsort},
15};
16use tracing::instrument;
17
18use super::{
19 gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
20 packed_field_storage::PackedFieldStorage,
21 Error, GrandProductClaim, GrandProductWitness,
22};
23use crate::{
24 composition::{BivariateProduct, IndexComposition},
25 fiat_shamir::{CanSample, Challenger},
26 protocols::sumcheck::{
27 self, equal_n_vars_check, immediate_switchover_heuristic,
28 prove::eq_ind::{eq_ind_expand, EqIndSumcheckProver, EqIndSumcheckProverBuilder},
29 CompositeSumClaim, Error as SumcheckError,
30 },
31 transcript::ProverTranscript,
32};
33
34#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
40pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
41 evaluation_order: EvaluationOrder,
42 witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
43 claims: &[GrandProductClaim<F>],
44 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
45 transcript: &mut ProverTranscript<Challenger_>,
46 backend: &Backend,
47) -> Result<GrandProductBatchProveOutput<F>, Error>
48where
49 F: TowerField,
50 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
51 FDomain: Field,
52 Challenger_: Challenger,
53 Backend: ComputationBackend,
54{
55 let witness_vec = witnesses.into_iter().collect::<Vec<_>>();
58
59 let n_claims = claims.len();
60 if n_claims == 0 {
61 return Ok(GrandProductBatchProveOutput::default());
62 }
63 if witness_vec.len() != n_claims {
64 bail!(Error::MismatchedWitnessClaimLength);
65 }
66
67 let provers_vec = witness_vec
69 .iter()
70 .zip(claims)
71 .map(|(witness, claim)| GrandProductProverState::new(claim, witness, backend))
72 .collect::<Result<Vec<_>, _>>()?;
73
74 let (original_indices, mut sorted_provers) =
75 stable_sort(provers_vec, |prover| prover.input_vars(), true);
76
77 let max_n_vars = sorted_provers
78 .first()
79 .expect("sorted_provers is not empty by invariant")
80 .input_vars();
81
82 let mut reverse_sorted_final_layer_claims = Vec::with_capacity(n_claims);
83
84 for layer_no in 0..max_n_vars {
85 process_finished_provers(
87 layer_no,
88 &mut sorted_provers,
89 &mut reverse_sorted_final_layer_claims,
90 )?;
91
92 let batch_sumcheck_output = {
96 let gpa_sumcheck_prover = GrandProductProverState::stage_gpa_sumcheck_provers(
97 evaluation_order,
98 &sorted_provers,
99 evaluation_domain_factory.clone(),
100 )?;
101
102 sumcheck::batch_prove(vec![gpa_sumcheck_prover], transcript)?
103 };
104
105 let gpa_challenge = transcript.sample();
107
108 for (i, prover) in sorted_provers.iter_mut().enumerate() {
110 prover.finalize_batch_layer_proof(
111 batch_sumcheck_output.multilinear_evals[0][2 * i],
112 batch_sumcheck_output.multilinear_evals[0][2 * i + 1],
113 batch_sumcheck_output.challenges.clone(),
114 gpa_challenge,
115 )?;
116 }
117 }
118 process_finished_provers(
119 max_n_vars,
120 &mut sorted_provers,
121 &mut reverse_sorted_final_layer_claims,
122 )?;
123
124 debug_assert!(sorted_provers.is_empty());
125 debug_assert_eq!(reverse_sorted_final_layer_claims.len(), n_claims);
126
127 reverse_sorted_final_layer_claims.reverse();
128 let sorted_final_layer_claim = reverse_sorted_final_layer_claims;
129
130 let final_layer_claims = unsort(original_indices, sorted_final_layer_claim);
131
132 Ok(GrandProductBatchProveOutput { final_layer_claims })
133}
134
135fn process_finished_provers<F, P, Backend>(
136 layer_no: usize,
137 sorted_provers: &mut Vec<GrandProductProverState<'_, F, P, Backend>>,
138 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
139) -> Result<(), Error>
140where
141 F: TowerField,
142 P: PackedField<Scalar = F>,
143 Backend: ComputationBackend,
144{
145 while let Some(prover) = sorted_provers.last() {
146 if prover.input_vars() != layer_no {
147 break;
148 }
149 debug_assert!(layer_no > 0);
150 let finished_prover = sorted_provers.pop().expect("not empty");
151 let final_layer_claim = finished_prover.finalize()?;
152 reverse_sorted_final_layer_claims.push(final_layer_claim);
153 }
154
155 Ok(())
156}
157
158#[derive(Debug)]
163struct GrandProductProverState<'a, F, P, Backend>
164where
165 F: Field + From<P::Scalar>,
166 P: PackedField,
167 P::Scalar: Field + From<F>,
168 Backend: ComputationBackend,
169{
170 n_vars: usize,
171 layers: Vec<MLEDirectAdapter<P, PackedFieldStorage<'a, P>>>,
174 next_layer_halves: Vec<[MLEDirectAdapter<P, PackedFieldStorage<'a, P>>; 2]>,
177 current_layer_claim: LayerClaim<F>,
179
180 backend: Backend,
181}
182
impl<'a, F, P, Backend> GrandProductProverState<'a, F, P, Backend>
where
	F: TowerField + From<P::Scalar>,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	/// Creates the prover state for one grand-product claim from its witness.
	///
	/// Builds a multilinear view of every witness layer (`0..=n_vars`) and of the two halves of
	/// every layer `1..=n_vars`. The initial layer claim has an empty evaluation point and value
	/// `claim.product`.
	///
	/// # Errors
	/// Returns `Error::ProverClaimWitnessMismatch` if the witness's variable count or grand-product
	/// evaluation disagrees with `claim`. Errors from witness-layer access and multilinear
	/// construction are propagated.
	fn new(
		claim: &GrandProductClaim<F>,
		witness: &'a GrandProductWitness<P>,
		backend: Backend,
	) -> Result<Self, Error> {
		let n_vars = claim.n_vars;
		if n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
			bail!(Error::ProverClaimWitnessMismatch);
		}

		// Layers are numbered 0 through n_vars (the input layer).
		let n_layers = n_vars + 1;
		// Entry `i - 1` holds the [left, right] halves of layer `i`, for `i` in 1..=n_vars.
		let next_layer_halves = (1..n_layers)
			.map(|i| {
				let (left_evals, right_evals) = witness.ith_layer_eval_halves(i)?;
				let left = MultilinearExtension::try_from(left_evals)?;
				let right = MultilinearExtension::try_from(right_evals)?;
				Ok([left, right].map(MLEDirectAdapter::from))
			})
			.collect::<Result<Vec<_>, Error>>()?;

		let layers = (0..n_layers)
			.map(|i| {
				let ith_layer_evals = witness.ith_layer_evals(i)?;
				// Layer `i` has 2^i scalars. When that exceeds one packed element, the packed
				// slice is converted as-is. Otherwise the whole layer sits in a single packed
				// element, so only its first 2^i scalars go into inline storage.
				let ith_layer_evals = if P::LOG_WIDTH < i {
					PackedFieldStorage::from(ith_layer_evals)
				} else {
					debug_assert_eq!(ith_layer_evals.len(), 1);
					PackedFieldStorage::new_inline(ith_layer_evals[0].iter().take(1 << i))
						.expect("length is a power of 2")
				};

				let mle = MultilinearExtension::try_from(ith_layer_evals)?;
				Ok(mle.into())
			})
			.collect::<Result<Vec<_>, Error>>()?;

		debug_assert_eq!(next_layer_halves.len(), n_vars);
		debug_assert_eq!(layers.len(), n_vars + 1);

		// The reduction starts at layer 0, whose evaluation at the empty point is the claimed
		// product.
		let layer_claim = LayerClaim {
			eval_point: vec![],
			eval: claim.product,
		};

		Ok(Self {
			n_vars,
			next_layer_halves,
			layers,
			current_layer_claim: layer_claim,
			backend,
		})
	}

	/// Variable count of the input layer, i.e. how many layers this prover must reduce.
	const fn input_vars(&self) -> usize {
		self.n_vars
	}

	/// Index of the layer the current claim refers to; equals the length of its evaluation point.
	fn current_layer_no(&self) -> usize {
		self.current_layer_claim.eval_point.len()
	}

	/// Builds one batched eq-indicator sumcheck prover covering every prover in `provers`.
	///
	/// For prover `i`, the halves of its next layer become multilinears `2 * i` and `2 * i + 1`.
	/// A composite claim multiplying those two (`BivariateProduct`) is added, with the prover's
	/// current layer evaluation as its sum. The method assumes all provers are at the same layer:
	/// the first prover's layer number and evaluation point (used as the eq-indicator challenges)
	/// apply to all of them, and its backend drives the computation. Each prover's current layer
	/// is passed as advice to precompute the first-round evaluations at 1.
	///
	/// # Panics
	/// Panics if `provers` is empty.
	// The opaque multilinear type in the return signature is long, hence the lint allowance.
	#[allow(clippy::type_complexity)]
	#[instrument(skip_all, level = "debug")]
	fn stage_gpa_sumcheck_provers<FDomain>(
		evaluation_order: EvaluationOrder,
		provers: &[Self],
		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
	) -> Result<
		EqIndSumcheckProver<
			FDomain,
			P,
			IndexComposition<BivariateProduct, 2>,
			impl MultilinearPoly<P> + Send + Sync + 'a,
			Backend,
		>,
		Error,
	>
	where
		FDomain: Field,
		P: PackedExtension<FDomain>,
	{
		let Some(first_prover) = provers.first() else {
			// `batch_prove` only calls this while the largest prover is still active.
			unreachable!();
		};

		let n_claims = provers.len();
		// Two multilinears (left and right half of the next layer) per prover.
		let n_multilinears = provers.len() * 2;
		let current_layer_no = first_prover.current_layer_no();

		let mut composite_claims = Vec::with_capacity(n_claims);
		let mut multilinears = Vec::with_capacity(n_multilinears);

		for (i, prover) in provers.iter().enumerate() {
			let indices = [2 * i, 2 * i + 1];

			let composite_claim = CompositeSumClaim {
				sum: prover.current_layer_claim.eval,
				composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
			};

			composite_claims.push(composite_claim);
			// `next_layer_halves[k]` holds the halves of layer `k + 1`.
			multilinears.extend(prover.next_layer_halves[current_layer_no].clone());
		}

		let eq_ind_challenges = &first_prover.current_layer_claim.eval_point;
		let n_vars = eq_ind_challenges.len();

		let first_layer_mle_advice = provers
			.iter()
			.map(|prover| prover.layers[current_layer_no].clone())
			.collect::<Vec<_>>();

		let eq_ind_partial_evals =
			eq_ind_expand(evaluation_order, n_vars, eq_ind_challenges, &first_prover.backend)?;

		let first_round_eval_1s = first_round_eval_1s_from_first_layer_mle_advice(
			evaluation_order,
			n_vars,
			&first_layer_mle_advice,
			&eq_ind_partial_evals,
		)?;

		let prover = EqIndSumcheckProverBuilder::new(&first_prover.backend)
			.with_first_round_eval_1s(&first_round_eval_1s)
			.with_eq_ind_partial_evals(eq_ind_partial_evals)
			.build(
				evaluation_order,
				multilinears,
				eq_ind_challenges,
				composite_claims,
				evaluation_domain_factory,
				immediate_switchover_heuristic,
			)?;

		Ok(prover)
	}

	/// Moves the current claim to the next layer after a batched sumcheck round.
	///
	/// `zero_eval` and `one_eval` are this prover's two sumcheck multilinear evaluations (left
	/// and right half of the next layer) at `sumcheck_challenge`. The new claim's value is
	/// `extrapolate_line_scalar(zero_eval, one_eval, gpa_challenge)`. Its evaluation point is
	/// `sumcheck_challenge` followed by `gpa_challenge`.
	///
	/// # Errors
	/// Returns `Error::TooManyRounds` if all `input_vars()` layers have already been reduced.
	fn finalize_batch_layer_proof(
		&mut self,
		zero_eval: F,
		one_eval: F,
		sumcheck_challenge: Vec<F>,
		gpa_challenge: F,
	) -> Result<(), Error> {
		if self.current_layer_no() >= self.input_vars() {
			bail!(Error::TooManyRounds);
		}
		let new_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
		let mut layer_challenge = sumcheck_challenge;
		layer_challenge.push(gpa_challenge);

		self.current_layer_claim = LayerClaim {
			eval_point: layer_challenge,
			eval: new_eval,
		};

		Ok(())
	}

	/// Consumes the prover and returns its claim on the input layer.
	///
	/// # Errors
	/// Returns `Error::PrematureFinalize` unless exactly `input_vars()` layers have been reduced.
	fn finalize(self) -> Result<LayerClaim<F>, Error> {
		if self.current_layer_no() != self.input_vars() {
			bail!(Error::PrematureFinalize);
		}

		let final_layer_claim = LayerClaim {
			eval_point: self.current_layer_claim.eval_point,
			eval: self.current_layer_claim.eval,
		};
		Ok(final_layer_claim)
	}
}
367
368pub fn first_round_eval_1s_from_first_layer_mle_advice<P, M>(
369 evaluation_order: EvaluationOrder,
370 n_vars: usize,
371 first_layer_mle_advice: &[M],
372 eq_ind_partial_evals: &[P],
373) -> Result<Vec<P::Scalar>, Error>
374where
375 P: PackedField,
376 M: MultilinearPoly<P> + Sync,
377{
378 let advice_n_vars = equal_n_vars_check(first_layer_mle_advice)?;
379
380 if n_vars != advice_n_vars {
381 bail!(Error::IncorrectFirstLayerAdviceLength);
382 }
383
384 if eq_ind_partial_evals.len() != 1 << n_vars.saturating_sub(P::LOG_WIDTH + 1) {
385 bail!(SumcheckError::IncorrectEqIndPartialEvalsSize);
386 }
387
388 let high_to_low_offset = 1 << n_vars.saturating_sub(1);
389 let first_round_eval_1s = first_layer_mle_advice
390 .into_par_iter()
391 .map(|advice_mle| {
392 let packed_sum = eq_ind_partial_evals
393 .par_iter()
394 .enumerate()
395 .map(|(i, &eq_ind)| {
396 eq_ind
397 * packed_from_fn_with_offset::<P>(i, |j| {
398 let index = match evaluation_order {
399 EvaluationOrder::LowToHigh => j << 1 | 1,
400 EvaluationOrder::HighToLow => j | high_to_low_offset,
401 };
402 advice_mle
403 .evaluate_on_hypercube(index)
404 .unwrap_or(P::Scalar::ZERO)
405 })
406 })
407 .sum::<P>();
408 packed_sum.iter().take(1 << n_vars).sum()
409 })
410 .collect();
411
412 Ok(first_round_eval_1s)
413}