// binius_core/protocols/gkr_gpa/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	packed::packed_from_fn_with_offset, Field, PackedExtension, PackedField, TowerField,
5};
6use binius_hal::ComputationBackend;
7use binius_math::{
8	extrapolate_line_scalar, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
9	MultilinearExtension, MultilinearPoly,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::{
13	bail,
14	sorting::{stable_sort, unsort},
15};
16use tracing::instrument;
17
18use super::{
19	gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
20	packed_field_storage::PackedFieldStorage,
21	Error, GrandProductClaim, GrandProductWitness,
22};
23use crate::{
24	composition::{BivariateProduct, IndexComposition},
25	fiat_shamir::{CanSample, Challenger},
26	protocols::sumcheck::{
27		self, equal_n_vars_check, immediate_switchover_heuristic,
28		prove::eq_ind::{eq_ind_expand, EqIndSumcheckProver, EqIndSumcheckProverBuilder},
29		CompositeSumClaim, Error as SumcheckError,
30	},
31	transcript::ProverTranscript,
32};
33
/// Proves batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim
///
/// Works layer by layer from the single-element output layer toward the input layer.
/// Provers are sorted descending by input size so that provers with fewer variables can
/// be finalized as soon as the layer index reaches their input size; all still-active
/// provers share a single batched sumcheck per layer reduction.
///
/// REQUIRES:
/// * witnesses and claims are of the same length
/// * The ith witness corresponds to the ith claim
#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
	evaluation_order: EvaluationOrder,
	witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
	claims: &[GrandProductClaim<F>],
	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
	transcript: &mut ProverTranscript<Challenger_>,
	backend: &Backend,
) -> Result<GrandProductBatchProveOutput<F>, Error>
where
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	FDomain: Field,
	Challenger_: Challenger,
	Backend: ComputationBackend,
{
	//  Ensure witnesses and claims are of the same length, zip them together
	// 	For each witness-claim pair, create GrandProductProver
	let witness_vec = witnesses.into_iter().collect::<Vec<_>>();

	let n_claims = claims.len();
	if n_claims == 0 {
		return Ok(GrandProductBatchProveOutput::default());
	}
	if witness_vec.len() != n_claims {
		bail!(Error::MismatchedWitnessClaimLength);
	}

	// Create a vector of GrandProductProverStates
	let provers_vec = witness_vec
		.iter()
		.zip(claims)
		.map(|(witness, claim)| GrandProductProverState::new(claim, witness, backend))
		.collect::<Result<Vec<_>, _>>()?;

	// Sort descending by input size; keep the permutation so the final claims can be
	// returned in the caller's original order.
	let (original_indices, mut sorted_provers) =
		stable_sort(provers_vec, |prover| prover.input_vars(), true);

	let max_n_vars = sorted_provers
		.first()
		.expect("sorted_provers is not empty by invariant")
		.input_vars();

	// Claims are pushed as provers finish, smallest input first — hence "reverse sorted".
	let mut reverse_sorted_final_layer_claims = Vec::with_capacity(n_claims);

	for layer_no in 0..max_n_vars {
		// Step 1: Process finished provers
		process_finished_provers(
			layer_no,
			&mut sorted_provers,
			&mut reverse_sorted_final_layer_claims,
		)?;

		// Now we must create the batch layer proof for the kth to k+1th layer reduction

		// Step 2: Create sumcheck batch proof
		let batch_sumcheck_output = {
			let gpa_sumcheck_prover = GrandProductProverState::stage_gpa_sumcheck_provers(
				evaluation_order,
				&sorted_provers,
				evaluation_domain_factory.clone(),
			)?;

			sumcheck::batch_prove(vec![gpa_sumcheck_prover], transcript)?
		};

		// Step 3: Sample a challenge for the next layer
		let gpa_challenge = transcript.sample();

		// Step 4: Finalize each prover to update its internal current_layer_claim
		// Each prover contributed two multilinears (the two halves of its next layer),
		// so its evaluations sit at positions 2i and 2i + 1 of the sumcheck output.
		for (i, prover) in sorted_provers.iter_mut().enumerate() {
			prover.finalize_batch_layer_proof(
				batch_sumcheck_output.multilinear_evals[0][2 * i],
				batch_sumcheck_output.multilinear_evals[0][2 * i + 1],
				batch_sumcheck_output.challenges.clone(),
				gpa_challenge,
			)?;
		}
	}
	// Provers with exactly max_n_vars input variables finish after the last reduction.
	process_finished_provers(
		max_n_vars,
		&mut sorted_provers,
		&mut reverse_sorted_final_layer_claims,
	)?;

	debug_assert!(sorted_provers.is_empty());
	debug_assert_eq!(reverse_sorted_final_layer_claims.len(), n_claims);

	reverse_sorted_final_layer_claims.reverse();
	let sorted_final_layer_claim = reverse_sorted_final_layer_claims;

	// Undo the sort so the ith output claim corresponds to the ith input claim.
	let final_layer_claims = unsort(original_indices, sorted_final_layer_claim);

	Ok(GrandProductBatchProveOutput { final_layer_claims })
}
134
135fn process_finished_provers<F, P, Backend>(
136	layer_no: usize,
137	sorted_provers: &mut Vec<GrandProductProverState<'_, F, P, Backend>>,
138	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
139) -> Result<(), Error>
140where
141	F: TowerField,
142	P: PackedField<Scalar = F>,
143	Backend: ComputationBackend,
144{
145	while let Some(prover) = sorted_provers.last() {
146		if prover.input_vars() != layer_no {
147			break;
148		}
149		debug_assert!(layer_no > 0);
150		let finished_prover = sorted_provers.pop().expect("not empty");
151		let final_layer_claim = finished_prover.finalize()?;
152		reverse_sorted_final_layer_claims.push(final_layer_claim);
153	}
154
155	Ok(())
156}
157
/// GPA protocol prover state
///
/// Coordinates the proving of a grand product claim before and after
/// the sumcheck-based layer reductions.
#[derive(Debug)]
struct GrandProductProverState<'a, F, P, Backend>
where
	F: Field + From<P::Scalar>,
	P: PackedField,
	P::Scalar: Field + From<F>,
	Backend: ComputationBackend,
{
	// Number of variables of the input (largest) layer of the product circuit
	n_vars: usize,
	// Layers of the product circuit as multilinear polynomials
	// The ith element is the ith layer of the product circuit
	layers: Vec<MLEDirectAdapter<P, PackedFieldStorage<'a, P>>>,
	// The ith element consists of a tuple of the
	// first and second halves of the (i+1)th layer of the product circuit
	next_layer_halves: Vec<[MLEDirectAdapter<P, PackedFieldStorage<'a, P>>; 2]>,
	// The current claim about a layer multilinear of the product circuit
	// (eval point grows by one challenge per completed layer reduction)
	current_layer_claim: LayerClaim<F>,

	// Backend used for eq-indicator expansion when staging sumcheck provers
	backend: Backend,
}
182
impl<'a, F, P, Backend> GrandProductProverState<'a, F, P, Backend>
where
	F: TowerField + From<P::Scalar>,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	/// Create a new GrandProductProverState
	///
	/// Validates that the witness matches the claim (same number of variables and the
	/// same grand product value), then adapts each layer of the witness's product
	/// circuit into a multilinear polynomial.
	fn new(
		claim: &GrandProductClaim<F>,
		witness: &'a GrandProductWitness<P>,
		backend: Backend,
	) -> Result<Self, Error> {
		let n_vars = claim.n_vars;
		if n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
			bail!(Error::ProverClaimWitnessMismatch);
		}

		// Build multilinear polynomials from circuit evaluations
		let n_layers = n_vars + 1;
		// Halves of layers 1..=n_vars: the two sumcheck multilinears per reduction.
		let next_layer_halves = (1..n_layers)
			.map(|i| {
				let (left_evals, right_evals) = witness.ith_layer_eval_halves(i)?;
				let left = MultilinearExtension::try_from(left_evals)?;
				let right = MultilinearExtension::try_from(right_evals)?;
				Ok([left, right].map(MLEDirectAdapter::from))
			})
			.collect::<Result<Vec<_>, Error>>()?;

		let layers = (0..n_layers)
			.map(|i| {
				let ith_layer_evals = witness.ith_layer_evals(i)?;
				// A layer with 2^i <= P::WIDTH scalars fits in one packed element;
				// copy it into inline storage exposing only the first 2^i scalars.
				let ith_layer_evals = if P::LOG_WIDTH < i {
					PackedFieldStorage::from(ith_layer_evals)
				} else {
					debug_assert_eq!(ith_layer_evals.len(), 1);
					PackedFieldStorage::new_inline(ith_layer_evals[0].iter().take(1 << i))
						.expect("length is a power of 2")
				};

				let mle = MultilinearExtension::try_from(ith_layer_evals)?;
				Ok(mle.into())
			})
			.collect::<Result<Vec<_>, Error>>()?;

		debug_assert_eq!(next_layer_halves.len(), n_vars);
		debug_assert_eq!(layers.len(), n_vars + 1);

		// Initialize Layer Claim
		// Layer 0 holds a single value — the grand product — claimed at the empty point.
		let layer_claim = LayerClaim {
			eval_point: vec![],
			eval: claim.product,
		};

		// Return new GrandProductProver and the common product
		Ok(Self {
			n_vars,
			next_layer_halves,
			layers,
			current_layer_claim: layer_claim,
			backend,
		})
	}

	/// Number of variables of this prover's input (largest) layer.
	const fn input_vars(&self) -> usize {
		self.n_vars
	}

	/// Index of the layer the current claim refers to — one challenge has been
	/// appended to the eval point per completed layer reduction.
	fn current_layer_no(&self) -> usize {
		self.current_layer_claim.eval_point.len()
	}

	/// Stage a single batched eq-ind sumcheck prover covering all `provers`.
	///
	/// Each prover contributes one bivariate-product composite claim over its two
	/// next-layer halves; the eq-indicator challenges are taken from the first
	/// prover's current eval point (provers at the same layer share it).
	///
	/// REQUIRES: `provers` is non-empty and all provers are at the same layer.
	#[allow(clippy::type_complexity)]
	#[instrument(skip_all, level = "debug")]
	fn stage_gpa_sumcheck_provers<FDomain>(
		evaluation_order: EvaluationOrder,
		provers: &[Self],
		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
	) -> Result<
		EqIndSumcheckProver<
			FDomain,
			P,
			IndexComposition<BivariateProduct, 2>,
			impl MultilinearPoly<P> + Send + Sync + 'a,
			Backend,
		>,
		Error,
	>
	where
		FDomain: Field,
		P: PackedExtension<FDomain>,
	{
		// test same layer
		// Callers (batch_prove) never stage with an empty prover list.
		let Some(first_prover) = provers.first() else {
			unreachable!();
		};

		// construct witness
		let n_claims = provers.len();
		let n_multilinears = provers.len() * 2;
		let current_layer_no = first_prover.current_layer_no();

		let mut composite_claims = Vec::with_capacity(n_claims);
		let mut multilinears = Vec::with_capacity(n_multilinears);

		for (i, prover) in provers.iter().enumerate() {
			// The ith claim is the product of multilinears 2i and 2i + 1.
			let indices = [2 * i, 2 * i + 1];

			let composite_claim = CompositeSumClaim {
				sum: prover.current_layer_claim.eval,
				composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
			};

			composite_claims.push(composite_claim);
			multilinears.extend(prover.next_layer_halves[current_layer_no].clone());
		}

		let eq_ind_challenges = &first_prover.current_layer_claim.eval_point;
		let n_vars = eq_ind_challenges.len();

		// Current-layer multilinears serve as advice for the first-round evals at 1.
		let first_layer_mle_advice = provers
			.iter()
			.map(|prover| prover.layers[current_layer_no].clone())
			.collect::<Vec<_>>();

		let eq_ind_partial_evals =
			eq_ind_expand(evaluation_order, n_vars, eq_ind_challenges, &first_prover.backend)?;

		let first_round_eval_1s = first_round_eval_1s_from_first_layer_mle_advice(
			evaluation_order,
			n_vars,
			&first_layer_mle_advice,
			&eq_ind_partial_evals,
		)?;

		let prover = EqIndSumcheckProverBuilder::new(&first_prover.backend)
			.with_first_round_eval_1s(&first_round_eval_1s)
			.with_eq_ind_partial_evals(eq_ind_partial_evals)
			.build(
				evaluation_order,
				multilinears,
				eq_ind_challenges,
				composite_claims,
				evaluation_domain_factory,
				// We use GPA protocol only for big fields, which is why switchover is trivial
				immediate_switchover_heuristic,
			)?;

		Ok(prover)
	}

	/// Update the current claim after a layer reduction: interpolate the two half
	/// evaluations at `gpa_challenge` and append that challenge to the sumcheck
	/// challenges to form the next layer's evaluation point.
	///
	/// Errors with `TooManyRounds` if all layers have already been reduced.
	fn finalize_batch_layer_proof(
		&mut self,
		zero_eval: F,
		one_eval: F,
		sumcheck_challenge: Vec<F>,
		gpa_challenge: F,
	) -> Result<(), Error> {
		if self.current_layer_no() >= self.input_vars() {
			bail!(Error::TooManyRounds);
		}
		// Line through (0, zero_eval) and (1, one_eval), evaluated at gpa_challenge.
		let new_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
		let mut layer_challenge = sumcheck_challenge;
		layer_challenge.push(gpa_challenge);

		self.current_layer_claim = LayerClaim {
			eval_point: layer_challenge,
			eval: new_eval,
		};

		Ok(())
	}

	/// Consume the prover and return the claim on its input layer.
	///
	/// Errors with `PrematureFinalize` unless every layer reduction has been performed.
	fn finalize(self) -> Result<LayerClaim<F>, Error> {
		if self.current_layer_no() != self.input_vars() {
			bail!(Error::PrematureFinalize);
		}

		let final_layer_claim = LayerClaim {
			eval_point: self.current_layer_claim.eval_point,
			eval: self.current_layer_claim.eval,
		};
		Ok(final_layer_claim)
	}
}
367
/// Computes each composite claim's first-round sumcheck evaluation at 1 from layer advice.
///
/// For every advice multilinear, sums `eq_ind * advice` over the half of the hypercube
/// where the first-round variable equals 1: the odd indices (`j << 1 | 1`) for
/// low-to-high evaluation order, the upper half (`j | high_to_low_offset`) for
/// high-to-low.
///
/// # Errors
/// * if the advice multilinears do not all have exactly `n_vars` variables
/// * if `eq_ind_partial_evals` does not have the packed length for half the hypercube
pub fn first_round_eval_1s_from_first_layer_mle_advice<P, M>(
	evaluation_order: EvaluationOrder,
	n_vars: usize,
	first_layer_mle_advice: &[M],
	eq_ind_partial_evals: &[P],
) -> Result<Vec<P::Scalar>, Error>
where
	P: PackedField,
	M: MultilinearPoly<P> + Sync,
{
	let advice_n_vars = equal_n_vars_check(first_layer_mle_advice)?;

	if n_vars != advice_n_vars {
		bail!(Error::IncorrectFirstLayerAdviceLength);
	}

	// Partial evals cover 2^(n_vars - 1) scalars, i.e. half the hypercube.
	if eq_ind_partial_evals.len() != 1 << n_vars.saturating_sub(P::LOG_WIDTH + 1) {
		bail!(SumcheckError::IncorrectEqIndPartialEvalsSize);
	}

	let high_to_low_offset = 1 << n_vars.saturating_sub(1);
	let first_round_eval_1s = first_layer_mle_advice
		.into_par_iter()
		.map(|advice_mle| {
			let packed_sum = eq_ind_partial_evals
				.par_iter()
				.enumerate()
				.map(|(i, &eq_ind)| {
					eq_ind
						* packed_from_fn_with_offset::<P>(i, |j| {
							// Map lane j to the hypercube index whose round variable is 1.
							let index = match evaluation_order {
								EvaluationOrder::LowToHigh => j << 1 | 1,
								EvaluationOrder::HighToLow => j | high_to_low_offset,
							};
							// Out-of-range indices contribute zero to the sum.
							advice_mle
								.evaluate_on_hypercube(index)
								.unwrap_or(P::Scalar::ZERO)
						})
				})
				.sum::<P>();
			// Reduce the packed lanes to a scalar; `take` drops padding lanes when the
			// hypercube is smaller than the packed width.
			packed_sum.iter().take(1 << n_vars).sum()
		})
		.collect();

	Ok(first_round_eval_1s)
}