binius_core/protocols/gkr_gpa/
prove.rs1use binius_field::{Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{extrapolate_line_scalar, EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::{
7 bail,
8 sorting::{stable_sort, unsort},
9};
10use itertools::izip;
11use tracing::instrument;
12
13use super::{
14 gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
15 Error, GrandProductClaim, GrandProductWitness,
16};
17use crate::{
18 composition::{BivariateProduct, IndexComposition},
19 fiat_shamir::{CanSample, Challenger},
20 protocols::sumcheck::{
21 prove::{eq_ind::EqIndSumcheckProverBuilder, front_loaded, SumcheckProver},
22 BatchSumcheckOutput, CompositeSumClaim,
23 },
24 transcript::ProverTranscript,
25};
26
27#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
33pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
34 evaluation_order: EvaluationOrder,
35 witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
36 claims: &[GrandProductClaim<F>],
37 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
38 transcript: &mut ProverTranscript<Challenger_>,
39 backend: &Backend,
40) -> Result<GrandProductBatchProveOutput<F>, Error>
41where
42 F: TowerField,
43 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
44 FDomain: Field,
45 Challenger_: Challenger,
46 Backend: ComputationBackend,
47{
48 let witnesses = witnesses.into_iter().collect::<Vec<_>>();
51
52 if witnesses.len() != claims.len() {
53 bail!(Error::MismatchedWitnessClaimLength);
54 }
55
56 let prover_states = izip!(witnesses, claims)
58 .map(|(witness, claim)| GrandProductProverState::new(claim, witness))
59 .collect::<Result<Vec<_>, _>>()?;
60
61 let (original_indices, mut sorted_prover_states) =
62 stable_sort(prover_states, |state| state.remaining_layers.len(), true);
63
64 let mut reverse_sorted_final_layer_claims = Vec::with_capacity(claims.len());
65 let mut eval_point = Vec::new();
66
67 loop {
68 process_finished_provers(
70 &mut sorted_prover_states,
71 &mut reverse_sorted_final_layer_claims,
72 &eval_point,
73 )?;
74
75 if sorted_prover_states.is_empty() {
76 break;
77 }
78
79 let BatchSumcheckOutput {
83 challenges,
84 multilinear_evals,
85 } = {
86 let eq_ind_sumcheck_prover = GrandProductProverState::stage_sumcheck_provers(
87 evaluation_order,
88 &mut sorted_prover_states,
89 evaluation_domain_factory.clone(),
90 &eval_point,
91 backend,
92 )?;
93
94 let batch_sumcheck_prover =
95 front_loaded::BatchProver::new(vec![eq_ind_sumcheck_prover], transcript)?;
96
97 let mut batch_output = batch_sumcheck_prover.run(transcript)?;
98
99 if evaluation_order == EvaluationOrder::HighToLow {
100 batch_output.challenges.reverse();
101 }
102
103 batch_output
104 };
105
106 let gpa_challenge = transcript.sample();
108
109 eval_point.copy_from_slice(&challenges);
110 eval_point.push(gpa_challenge);
111
112 debug_assert_eq!(multilinear_evals.len(), 1);
114 let multilinear_evals = multilinear_evals
115 .first()
116 .expect("exactly one prover in a batch");
117 for (state, evals) in izip!(&mut sorted_prover_states, multilinear_evals.chunks_exact(2)) {
118 state.update_layer_eval(evals[0], evals[1], gpa_challenge);
119 }
120 }
121 process_finished_provers(
122 &mut sorted_prover_states,
123 &mut reverse_sorted_final_layer_claims,
124 &eval_point,
125 )?;
126
127 debug_assert!(sorted_prover_states.is_empty());
128 debug_assert_eq!(reverse_sorted_final_layer_claims.len(), claims.len());
129
130 reverse_sorted_final_layer_claims.reverse();
131 let sorted_final_layer_claims = reverse_sorted_final_layer_claims;
132
133 let final_layer_claims = unsort(original_indices, sorted_final_layer_claims);
134 Ok(GrandProductBatchProveOutput { final_layer_claims })
135}
136
137fn process_finished_provers<F, P>(
138 sorted_prover_states: &mut Vec<GrandProductProverState<P>>,
139 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
140 eval_point: &[F],
141) -> Result<(), Error>
142where
143 F: TowerField,
144 P: PackedField<Scalar = F>,
145{
146 let first_finished =
147 sorted_prover_states.partition_point(|state| !state.remaining_layers.is_empty());
148
149 for state in sorted_prover_states.drain(first_finished..).rev() {
150 reverse_sorted_final_layer_claims.push(state.finalize(eval_point)?);
151 }
152
153 Ok(())
154}
155
156#[derive(Debug)]
161struct GrandProductProverState<P>
162where
163 P: PackedField<Scalar: TowerField>,
164{
165 remaining_layers: Vec<Vec<P>>,
168 layer_eval: P::Scalar,
170}
171
172impl<F, P> GrandProductProverState<P>
173where
174 F: TowerField,
175 P: PackedField<Scalar = F>,
176{
177 fn new(claim: &GrandProductClaim<F>, witness: GrandProductWitness<P>) -> Result<Self, Error> {
179 if claim.n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
180 bail!(Error::ProverClaimWitnessMismatch);
181 }
182
183 let mut remaining_layers = witness.into_circuit_layers();
184 debug_assert_eq!(remaining_layers.len(), claim.n_vars + 1);
185 let _ = remaining_layers
186 .pop()
187 .expect("remaining_layers cannot be empty");
188
189 let layer_eval = claim.product;
191
192 Ok(Self {
194 remaining_layers,
195 layer_eval,
196 })
197 }
198
199 #[allow(clippy::type_complexity)]
200 #[instrument(skip_all, level = "debug")]
201 fn stage_sumcheck_provers<'a, FDomain, Backend>(
202 evaluation_order: EvaluationOrder,
203 states: &mut [Self],
204 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
205 eq_ind_challenges: &[P::Scalar],
206 backend: &'a Backend,
207 ) -> Result<impl SumcheckProver<P::Scalar> + 'a, Error>
208 where
209 FDomain: Field,
210 P: PackedExtension<FDomain>,
211 Backend: ComputationBackend,
212 {
213 let n_vars = eq_ind_challenges.len();
214 let n_claims = states.len();
215 let n_multilinears = n_claims * 2;
216
217 let mut composite_claims = Vec::with_capacity(n_claims);
218 let mut multilinears = Vec::with_capacity(n_multilinears);
219 let mut const_suffixes = Vec::with_capacity(n_multilinears);
220
221 for (i, state) in states.iter_mut().enumerate() {
222 let indices = [2 * i, 2 * i + 1];
223
224 let composite_claim = CompositeSumClaim {
225 sum: state.layer_eval,
226 composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
227 };
228
229 composite_claims.push(composite_claim);
230
231 let layer = state
232 .remaining_layers
233 .pop()
234 .expect("not staging more than n_vars times");
235
236 let multilinear_pair =
237 if n_vars >= P::LOG_WIDTH && layer.len() < 1 << (n_vars - P::LOG_WIDTH) {
238 [layer, vec![]]
239 } else if n_vars >= P::LOG_WIDTH {
240 let mut evals_0 = layer;
241 let evals_1 = evals_0.split_off(1 << (n_vars - P::LOG_WIDTH));
242 [evals_0, evals_1]
243 } else {
244 let mut evals_0 = P::zero();
245 let mut evals_1 = P::zero();
246 let only_packed = layer.first().copied().unwrap_or_else(P::one);
247
248 for i in 0..1 << n_vars {
249 evals_0.set(i, only_packed.get(i));
250 evals_1.set(i, only_packed.get(i | 1 << n_vars));
251 }
252
253 [vec![evals_0], vec![evals_1]]
254 };
255
256 for multilinear in multilinear_pair {
257 let suffix_len = (1usize << n_vars).saturating_sub(multilinear.len() * P::WIDTH);
258 const_suffixes.push((F::ONE, suffix_len));
259 multilinears.push(multilinear);
260 }
261 }
262
263 let prover = EqIndSumcheckProverBuilder::without_switchover(n_vars, multilinears, backend)
264 .with_const_suffixes(&const_suffixes)?
265 .build(
266 evaluation_order,
267 eq_ind_challenges,
268 composite_claims,
269 evaluation_domain_factory,
270 )?;
271
272 Ok(prover)
273 }
274
275 fn update_layer_eval(&mut self, zero_eval: F, one_eval: F, gpa_challenge: F) {
276 self.layer_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
277 }
278
279 fn finalize(self, eval_point: &[F]) -> Result<LayerClaim<F>, Error> {
280 if !self.remaining_layers.is_empty() {
281 bail!(Error::PrematureFinalize);
282 }
283
284 Ok(LayerClaim {
285 eval_point: eval_point.to_vec(),
286 eval: self.layer_eval,
287 })
288 }
289}