binius_core/protocols/gkr_gpa/
prove.rs1use binius_field::{Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder, extrapolate_line_scalar};
6use binius_utils::{
7 bail,
8 sorting::{stable_sort, unsort},
9};
10use itertools::izip;
11use tracing::instrument;
12
13use super::{
14 Error, GrandProductClaim, GrandProductWitness,
15 gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
16};
17use crate::{
18 composition::{BivariateProduct, IndexComposition},
19 fiat_shamir::{CanSample, Challenger},
20 protocols::sumcheck::{
21 BatchSumcheckOutput, CompositeSumClaim,
22 prove::{SumcheckProver, eq_ind::EqIndSumcheckProverBuilder, front_loaded},
23 },
24 transcript::ProverTranscript,
25};
26
27#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
33pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
34 evaluation_order: EvaluationOrder,
35 witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
36 claims: &[GrandProductClaim<F>],
37 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
38 transcript: &mut ProverTranscript<Challenger_>,
39 backend: &Backend,
40) -> Result<GrandProductBatchProveOutput<F>, Error>
41where
42 F: TowerField,
43 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
44 FDomain: Field,
45 Challenger_: Challenger,
46 Backend: ComputationBackend,
47{
48 let witnesses = witnesses.into_iter().collect::<Vec<_>>();
51
52 if witnesses.len() != claims.len() {
53 bail!(Error::MismatchedWitnessClaimLength);
54 }
55
56 let prover_states = izip!(witnesses, claims)
58 .map(|(witness, claim)| GrandProductProverState::new(claim, witness))
59 .collect::<Result<Vec<_>, _>>()?;
60
61 let (original_indices, mut sorted_prover_states) =
62 stable_sort(prover_states, |state| state.remaining_layers.len(), true);
63
64 let mut reverse_sorted_final_layer_claims = Vec::with_capacity(claims.len());
65 let mut eval_point = Vec::new();
66
67 loop {
68 process_finished_provers(
70 &mut sorted_prover_states,
71 &mut reverse_sorted_final_layer_claims,
72 &eval_point,
73 )?;
74
75 if sorted_prover_states.is_empty() {
76 break;
77 }
78
79 let BatchSumcheckOutput {
83 challenges,
84 multilinear_evals,
85 } = {
86 let _layer_span = tracing::info_span!(
87 "[task] GKR GPA Layer Sumcheck",
88 phase = "exp",
89 perfetto_category = "task.main"
90 )
91 .entered();
92
93 let eq_ind_sumcheck_prover = GrandProductProverState::stage_sumcheck_provers(
94 evaluation_order,
95 &mut sorted_prover_states,
96 evaluation_domain_factory.clone(),
97 &eval_point,
98 backend,
99 )?;
100
101 let batch_sumcheck_prover =
102 front_loaded::BatchProver::new(vec![eq_ind_sumcheck_prover], transcript)?;
103
104 let mut batch_output = batch_sumcheck_prover.run(transcript)?;
105
106 if evaluation_order == EvaluationOrder::HighToLow {
107 batch_output.challenges.reverse();
108 }
109
110 batch_output
111 };
112
113 let gpa_challenge = transcript.sample();
115
116 eval_point.copy_from_slice(&challenges);
117 eval_point.push(gpa_challenge);
118
119 debug_assert_eq!(multilinear_evals.len(), 1);
121 let multilinear_evals = multilinear_evals
122 .first()
123 .expect("exactly one prover in a batch");
124 for (state, evals) in izip!(&mut sorted_prover_states, multilinear_evals.chunks_exact(2)) {
125 state.update_layer_eval(evals[0], evals[1], gpa_challenge);
126 }
127 }
128 process_finished_provers(
129 &mut sorted_prover_states,
130 &mut reverse_sorted_final_layer_claims,
131 &eval_point,
132 )?;
133
134 debug_assert!(sorted_prover_states.is_empty());
135 debug_assert_eq!(reverse_sorted_final_layer_claims.len(), claims.len());
136
137 reverse_sorted_final_layer_claims.reverse();
138 let sorted_final_layer_claims = reverse_sorted_final_layer_claims;
139
140 let final_layer_claims = unsort(original_indices, sorted_final_layer_claims);
141 Ok(GrandProductBatchProveOutput { final_layer_claims })
142}
143
144fn process_finished_provers<F, P>(
145 sorted_prover_states: &mut Vec<GrandProductProverState<P>>,
146 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
147 eval_point: &[F],
148) -> Result<(), Error>
149where
150 F: TowerField,
151 P: PackedField<Scalar = F>,
152{
153 let first_finished =
154 sorted_prover_states.partition_point(|state| !state.remaining_layers.is_empty());
155
156 for state in sorted_prover_states.drain(first_finished..).rev() {
157 reverse_sorted_final_layer_claims.push(state.finalize(eval_point)?);
158 }
159
160 Ok(())
161}
162
163#[derive(Debug)]
168struct GrandProductProverState<P>
169where
170 P: PackedField<Scalar: TowerField>,
171{
172 remaining_layers: Vec<Vec<P>>,
175 layer_eval: P::Scalar,
177}
178
179impl<F, P> GrandProductProverState<P>
180where
181 F: TowerField,
182 P: PackedField<Scalar = F>,
183{
184 fn new(claim: &GrandProductClaim<F>, witness: GrandProductWitness<P>) -> Result<Self, Error> {
186 if claim.n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
187 bail!(Error::ProverClaimWitnessMismatch);
188 }
189
190 let mut remaining_layers = witness.into_circuit_layers();
191 debug_assert_eq!(remaining_layers.len(), claim.n_vars + 1);
192 let _ = remaining_layers
193 .pop()
194 .expect("remaining_layers cannot be empty");
195
196 let layer_eval = claim.product;
198
199 Ok(Self {
201 remaining_layers,
202 layer_eval,
203 })
204 }
205
206 #[allow(clippy::type_complexity)]
207 #[instrument(skip_all, level = "debug")]
208 fn stage_sumcheck_provers<'a, FDomain, Backend>(
209 evaluation_order: EvaluationOrder,
210 states: &mut [Self],
211 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
212 eq_ind_challenges: &[P::Scalar],
213 backend: &'a Backend,
214 ) -> Result<impl SumcheckProver<P::Scalar> + 'a, Error>
215 where
216 FDomain: Field,
217 P: PackedExtension<FDomain>,
218 Backend: ComputationBackend,
219 {
220 let n_vars = eq_ind_challenges.len();
221 let n_claims = states.len();
222 let n_multilinears = n_claims * 2;
223
224 let mut composite_claims = Vec::with_capacity(n_claims);
225 let mut multilinears = Vec::with_capacity(n_multilinears);
226 let mut const_suffixes = Vec::with_capacity(n_multilinears);
227
228 for (i, state) in states.iter_mut().enumerate() {
229 let indices = [2 * i, 2 * i + 1];
230
231 let composite_claim = CompositeSumClaim {
232 sum: state.layer_eval,
233 composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
234 };
235
236 composite_claims.push(composite_claim);
237
238 let layer = state
239 .remaining_layers
240 .pop()
241 .expect("not staging more than n_vars times");
242
243 let multilinear_pair =
244 if n_vars >= P::LOG_WIDTH && layer.len() < 1 << (n_vars - P::LOG_WIDTH) {
245 [layer, vec![]]
246 } else if n_vars >= P::LOG_WIDTH {
247 let mut evals_0 = layer;
248 let evals_1 = evals_0.split_off(1 << (n_vars - P::LOG_WIDTH));
249 [evals_0, evals_1]
250 } else {
251 let mut evals_0 = P::zero();
252 let mut evals_1 = P::zero();
253 let only_packed = layer.first().copied().unwrap_or_else(P::one);
254
255 for i in 0..1 << n_vars {
256 evals_0.set(i, only_packed.get(i));
257 evals_1.set(i, only_packed.get(i | 1 << n_vars));
258 }
259
260 [vec![evals_0], vec![evals_1]]
261 };
262
263 for multilinear in multilinear_pair {
264 let suffix_len = (1usize << n_vars).saturating_sub(multilinear.len() * P::WIDTH);
265 const_suffixes.push((F::ONE, suffix_len));
266 multilinears.push(multilinear);
267 }
268 }
269
270 let prover = EqIndSumcheckProverBuilder::without_switchover(n_vars, multilinears, backend)
271 .with_const_suffixes(&const_suffixes)?
272 .build(
273 evaluation_order,
274 eq_ind_challenges,
275 composite_claims,
276 evaluation_domain_factory,
277 )?;
278
279 Ok(prover)
280 }
281
282 fn update_layer_eval(&mut self, zero_eval: F, one_eval: F, gpa_challenge: F) {
283 self.layer_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
284 }
285
286 fn finalize(self, eval_point: &[F]) -> Result<LayerClaim<F>, Error> {
287 if !self.remaining_layers.is_empty() {
288 bail!(Error::PrematureFinalize);
289 }
290
291 Ok(LayerClaim {
292 eval_point: eval_point.to_vec(),
293 eval: self.layer_eval,
294 })
295 }
296}