1use binius_field::{Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{
6 extrapolate_line_scalar, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
7 MultilinearExtension, MultilinearPoly,
8};
9use binius_utils::{
10 bail,
11 sorting::{stable_sort, unsort},
12};
13use tracing::instrument;
14
15use super::{
16 gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
17 gpa_sumcheck::prove::GPAProver,
18 packed_field_storage::PackedFieldStorage,
19 Error, GrandProductClaim, GrandProductWitness,
20};
21use crate::{
22 composition::{BivariateProduct, IndexComposition},
23 fiat_shamir::{CanSample, Challenger},
24 protocols::sumcheck::{self, CompositeSumClaim},
25 transcript::ProverTranscript,
26};
27
28#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
34pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
35 evaluation_order: EvaluationOrder,
36 witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
37 claims: &[GrandProductClaim<F>],
38 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
39 transcript: &mut ProverTranscript<Challenger_>,
40 backend: &Backend,
41) -> Result<GrandProductBatchProveOutput<F>, Error>
42where
43 F: TowerField,
44 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
45 FDomain: Field,
46 Challenger_: Challenger,
47 Backend: ComputationBackend,
48{
49 let witness_vec = witnesses.into_iter().collect::<Vec<_>>();
52
53 let n_claims = claims.len();
54 if n_claims == 0 {
55 return Ok(GrandProductBatchProveOutput::default());
56 }
57 if witness_vec.len() != n_claims {
58 bail!(Error::MismatchedWitnessClaimLength);
59 }
60
61 let provers_vec = witness_vec
63 .iter()
64 .zip(claims)
65 .map(|(witness, claim)| GrandProductProverState::new(claim, witness, backend))
66 .collect::<Result<Vec<_>, _>>()?;
67
68 let (original_indices, mut sorted_provers) =
69 stable_sort(provers_vec, |prover| prover.input_vars(), true);
70
71 let max_n_vars = sorted_provers
72 .first()
73 .expect("sorted_provers is not empty by invariant")
74 .input_vars();
75
76 let mut reverse_sorted_final_layer_claims = Vec::with_capacity(n_claims);
77
78 for layer_no in 0..max_n_vars {
79 process_finished_provers(
81 layer_no,
82 &mut sorted_provers,
83 &mut reverse_sorted_final_layer_claims,
84 )?;
85
86 let batch_sumcheck_output = {
90 let gpa_sumcheck_prover = GrandProductProverState::stage_gpa_sumcheck_provers(
91 evaluation_order,
92 &sorted_provers,
93 evaluation_domain_factory.clone(),
94 )?;
95
96 sumcheck::batch_prove(vec![gpa_sumcheck_prover], transcript)?
97 };
98
99 let gpa_challenge = transcript.sample();
101
102 for (i, prover) in sorted_provers.iter_mut().enumerate() {
104 prover.finalize_batch_layer_proof(
105 batch_sumcheck_output.multilinear_evals[0][2 * i],
106 batch_sumcheck_output.multilinear_evals[0][2 * i + 1],
107 batch_sumcheck_output.challenges.clone(),
108 gpa_challenge,
109 )?;
110 }
111 }
112 process_finished_provers(
113 max_n_vars,
114 &mut sorted_provers,
115 &mut reverse_sorted_final_layer_claims,
116 )?;
117
118 debug_assert!(sorted_provers.is_empty());
119 debug_assert_eq!(reverse_sorted_final_layer_claims.len(), n_claims);
120
121 reverse_sorted_final_layer_claims.reverse();
122 let sorted_final_layer_claim = reverse_sorted_final_layer_claims;
123
124 let final_layer_claims = unsort(original_indices, sorted_final_layer_claim);
125
126 Ok(GrandProductBatchProveOutput { final_layer_claims })
127}
128
129fn process_finished_provers<F, P, Backend>(
130 layer_no: usize,
131 sorted_provers: &mut Vec<GrandProductProverState<'_, F, P, Backend>>,
132 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
133) -> Result<(), Error>
134where
135 F: TowerField,
136 P: PackedField<Scalar = F>,
137 Backend: ComputationBackend,
138{
139 while let Some(prover) = sorted_provers.last() {
140 if prover.input_vars() != layer_no {
141 break;
142 }
143 debug_assert!(layer_no > 0);
144 let finished_prover = sorted_provers.pop().expect("not empty");
145 let final_layer_claim = finished_prover.finalize()?;
146 reverse_sorted_final_layer_claims.push(final_layer_claim);
147 }
148
149 Ok(())
150}
151
/// Per-claim state of the GKR grand-product prover.
///
/// Holds the multilinears for every layer of one witness, together with the claim currently
/// being reduced. The state advances one layer at a time via `finalize_batch_layer_proof`. It is
/// consumed by `finalize` once every layer has been processed.
#[derive(Debug)]
struct GrandProductProverState<'a, F, P, Backend>
where
    F: Field + From<P::Scalar>,
    P: PackedField,
    P::Scalar: Field + From<F>,
    Backend: ComputationBackend,
{
    /// Number of variables of the claim. This is also the number of layer reductions to perform.
    n_vars: usize,
    /// `layers[i]` wraps the evaluations of witness layer `i`, for `i` in `0..=n_vars`.
    layers: Vec<MLEDirectAdapter<P, PackedFieldStorage<'a, P>>>,
    /// `next_layer_halves[k]` wraps the `[left, right]` halves of witness layer `k + 1`.
    /// Entry `k` is used while reducing the claim on layer `k`.
    next_layer_halves: Vec<[MLEDirectAdapter<P, PackedFieldStorage<'a, P>>; 2]>,
    /// Claim on the current layer. The length of its evaluation point is the current layer index.
    current_layer_claim: LayerClaim<F>,

    /// Computation backend forwarded to the GPA sumcheck prover.
    backend: Backend,
}
176
177impl<'a, F, P, Backend> GrandProductProverState<'a, F, P, Backend>
178where
179 F: TowerField + From<P::Scalar>,
180 P: PackedField<Scalar = F>,
181 Backend: ComputationBackend,
182{
183 fn new(
185 claim: &GrandProductClaim<F>,
186 witness: &'a GrandProductWitness<P>,
187 backend: Backend,
188 ) -> Result<Self, Error> {
189 let n_vars = claim.n_vars;
190 if n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
191 bail!(Error::ProverClaimWitnessMismatch);
192 }
193
194 let n_layers = n_vars + 1;
196 let next_layer_halves = (1..n_layers)
197 .map(|i| {
198 let (left_evals, right_evals) = witness.ith_layer_eval_halves(i)?;
199 let left = MultilinearExtension::try_from(left_evals)?;
200 let right = MultilinearExtension::try_from(right_evals)?;
201 Ok([left, right].map(MLEDirectAdapter::from))
202 })
203 .collect::<Result<Vec<_>, Error>>()?;
204
205 let layers = (0..n_layers)
206 .map(|i| {
207 let ith_layer_evals = witness.ith_layer_evals(i)?;
208 let ith_layer_evals = if P::LOG_WIDTH < i {
209 PackedFieldStorage::from(ith_layer_evals)
210 } else {
211 debug_assert_eq!(ith_layer_evals.len(), 1);
212 PackedFieldStorage::new_inline(ith_layer_evals[0].iter().take(1 << i))
213 .expect("length is a power of 2")
214 };
215
216 let mle = MultilinearExtension::try_from(ith_layer_evals)?;
217 Ok(mle.into())
218 })
219 .collect::<Result<Vec<_>, Error>>()?;
220
221 debug_assert_eq!(next_layer_halves.len(), n_vars);
222 debug_assert_eq!(layers.len(), n_vars + 1);
223
224 let layer_claim = LayerClaim {
226 eval_point: vec![],
227 eval: claim.product,
228 };
229
230 Ok(Self {
232 n_vars,
233 next_layer_halves,
234 layers,
235 current_layer_claim: layer_claim,
236 backend,
237 })
238 }
239
240 const fn input_vars(&self) -> usize {
241 self.n_vars
242 }
243
244 fn current_layer_no(&self) -> usize {
245 self.current_layer_claim.eval_point.len()
246 }
247
248 #[allow(clippy::type_complexity)]
249 #[instrument(skip_all, level = "debug")]
250 fn stage_gpa_sumcheck_provers<FDomain>(
251 evaluation_order: EvaluationOrder,
252 provers: &[Self],
253 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
254 ) -> Result<
255 GPAProver<
256 FDomain,
257 P,
258 IndexComposition<BivariateProduct, 2>,
259 impl MultilinearPoly<P> + Send + Sync + 'a,
260 Backend,
261 >,
262 Error,
263 >
264 where
265 FDomain: Field,
266 P: PackedExtension<FDomain>,
267 {
268 let Some(first_prover) = provers.first() else {
270 unreachable!();
271 };
272
273 let n_claims = provers.len();
275 let n_multilinears = provers.len() * 2;
276 let current_layer_no = first_prover.current_layer_no();
277
278 let mut composite_claims = Vec::with_capacity(n_claims);
279 let mut multilinears = Vec::with_capacity(n_multilinears);
280
281 for (i, prover) in provers.iter().enumerate() {
282 let indices = [2 * i, 2 * i + 1];
283
284 let composite_claim = CompositeSumClaim {
285 sum: prover.current_layer_claim.eval,
286 composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
287 };
288
289 composite_claims.push(composite_claim);
290 multilinears.extend(prover.next_layer_halves[current_layer_no].clone());
291 }
292
293 let first_layer_mle_advice = provers
294 .iter()
295 .map(|prover| prover.layers[current_layer_no].clone())
296 .collect::<Vec<_>>();
297
298 Ok(GPAProver::new(
299 evaluation_order,
300 multilinears,
301 Some(first_layer_mle_advice),
302 composite_claims,
303 evaluation_domain_factory,
304 &first_prover.current_layer_claim.eval_point,
305 &first_prover.backend,
306 )?)
307 }
308
309 fn finalize_batch_layer_proof(
310 &mut self,
311 zero_eval: F,
312 one_eval: F,
313 sumcheck_challenge: Vec<F>,
314 gpa_challenge: F,
315 ) -> Result<(), Error> {
316 if self.current_layer_no() >= self.input_vars() {
317 bail!(Error::TooManyRounds);
318 }
319 let new_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
320 let mut layer_challenge = sumcheck_challenge;
321 layer_challenge.push(gpa_challenge);
322
323 self.current_layer_claim = LayerClaim {
324 eval_point: layer_challenge,
325 eval: new_eval,
326 };
327
328 Ok(())
329 }
330
331 fn finalize(self) -> Result<LayerClaim<F>, Error> {
332 if self.current_layer_no() != self.input_vars() {
333 bail!(Error::PrematureFinalize);
334 }
335
336 let final_layer_claim = LayerClaim {
337 eval_point: self.current_layer_claim.eval_point,
338 eval: self.current_layer_claim.eval,
339 };
340 Ok(final_layer_claim)
341 }
342}