binius_core/protocols/gkr_gpa/
prove.rs

// Copyright 2024-2025 Irreducible Inc.

use binius_field::{Field, PackedExtension, PackedField, TowerField};
use binius_hal::ComputationBackend;
use binius_math::{extrapolate_line_scalar, EvaluationDomainFactory, EvaluationOrder};
use binius_utils::{
	bail,
	sorting::{stable_sort, unsort},
};
use itertools::izip;
use tracing::instrument;

use super::{
	gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
	Error, GrandProductClaim, GrandProductWitness,
};
use crate::{
	composition::{BivariateProduct, IndexComposition},
	fiat_shamir::{CanSample, Challenger},
	protocols::sumcheck::{
		prove::{eq_ind::EqIndSumcheckProverBuilder, front_loaded, SumcheckProver},
		BatchSumcheckOutput, CompositeSumClaim,
	},
	transcript::ProverTranscript,
};

/// Proves the batch reduction that turns each GrandProductClaim into a LayerClaim on the
/// original multilinear.
///
/// REQUIRES:
/// * witnesses and claims are of the same length
/// * the ith witness corresponds to the ith claim
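///
/// The snippet below is a minimal calling sketch, not a test from this crate: the witnesses,
/// claims, transcript, domain factory, and backend are assumed to be constructed elsewhere,
/// and the build_* helpers are placeholders.
///
/// ```ignore
/// let witnesses: Vec<GrandProductWitness<P>> = build_witnesses();
/// let claims: Vec<GrandProductClaim<F>> = build_claims(&witnesses);
///
/// let GrandProductBatchProveOutput { final_layer_claims } = batch_prove(
/// 	EvaluationOrder::HighToLow,
/// 	witnesses,
/// 	&claims,
/// 	domain_factory,
/// 	&mut transcript,
/// 	&backend,
/// )?;
/// // final_layer_claims[i] is the reduced evaluation claim for the ith input claim.
/// ```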
#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
	evaluation_order: EvaluationOrder,
	witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
	claims: &[GrandProductClaim<F>],
	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
	transcript: &mut ProverTranscript<Challenger_>,
	backend: &Backend,
) -> Result<GrandProductBatchProveOutput<F>, Error>
where
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	FDomain: Field,
	Challenger_: Challenger,
	Backend: ComputationBackend,
{
	// Ensure witnesses and claims are of the same length, then zip them together.
	// For each witness-claim pair, create a GrandProductProverState.
	let witnesses = witnesses.into_iter().collect::<Vec<_>>();

	if witnesses.len() != claims.len() {
		bail!(Error::MismatchedWitnessClaimLength);
	}

	// Create a vector of GrandProductProverStates
	let prover_states = izip!(witnesses, claims)
		.map(|(witness, claim)| GrandProductProverState::new(claim, witness))
		.collect::<Result<Vec<_>, _>>()?;

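	// Sort the prover states so that those with the most remaining layers come first. Each
	// round pops one layer from every active state, so the states that run out of layers
	// accumulate at the tail, where `process_finished_provers` can drain them.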
	let (original_indices, mut sorted_prover_states) =
		stable_sort(prover_states, |state| state.remaining_layers.len(), true);

	let mut reverse_sorted_final_layer_claims = Vec::with_capacity(claims.len());
	let mut eval_point = Vec::new();

	loop {
		// Step 1: Process finished provers
		process_finished_provers(
			&mut sorted_prover_states,
			&mut reverse_sorted_final_layer_claims,
			&eval_point,
		)?;

		if sorted_prover_states.is_empty() {
			break;
		}

		// Create the batch layer proof for the reduction from the kth layer to the (k+1)th layer.

		// Step 2: Create sumcheck batch proof
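		//
		// For each active prover state, writing r = eval_point and k = r.len(), the staged
		// sumcheck proves the claim
		//     layer_k(r) = sum_{x in {0,1}^k} eq(r, x) * M_0(x) * M_1(x),
		// where M_0 and M_1 are the low and high halves of the next layer, i.e. the
		// restrictions of layer_{k+1} with its last variable set to 0 and 1 respectively.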
		let BatchSumcheckOutput {
			challenges,
			multilinear_evals,
		} = {
			let eq_ind_sumcheck_prover = GrandProductProverState::stage_sumcheck_provers(
				evaluation_order,
				&mut sorted_prover_states,
				evaluation_domain_factory.clone(),
				&eval_point,
				backend,
			)?;

			let batch_sumcheck_prover =
				front_loaded::BatchProver::new(vec![eq_ind_sumcheck_prover], transcript)?;

			let mut batch_output = batch_sumcheck_prover.run(transcript)?;

			if evaluation_order == EvaluationOrder::HighToLow {
				batch_output.challenges.reverse();
			}

			batch_output
		};

		// Step 3: Sample a challenge for the next layer
		let gpa_challenge = transcript.sample();

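		// The sumcheck challenges fix the variables shared with the next layer, and
		// `gpa_challenge` fixes the variable along which that layer was split, so the updated
		// `eval_point` addresses the next (larger) layer.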
		eval_point = challenges;
		eval_point.push(gpa_challenge);

		// Step 4: Use the sumcheck evaluations to update each prover's layer claim
		debug_assert_eq!(multilinear_evals.len(), 1);
		let multilinear_evals = multilinear_evals
			.first()
			.expect("exactly one prover in a batch");
		for (state, evals) in izip!(&mut sorted_prover_states, multilinear_evals.chunks_exact(2)) {
			state.update_layer_eval(evals[0], evals[1], gpa_challenge);
		}
	}
	process_finished_provers(
		&mut sorted_prover_states,
		&mut reverse_sorted_final_layer_claims,
		&eval_point,
	)?;

	debug_assert!(sorted_prover_states.is_empty());
	debug_assert_eq!(reverse_sorted_final_layer_claims.len(), claims.len());

	reverse_sorted_final_layer_claims.reverse();
	let sorted_final_layer_claims = reverse_sorted_final_layer_claims;

	let final_layer_claims = unsort(original_indices, sorted_final_layer_claims);
	Ok(GrandProductBatchProveOutput { final_layer_claims })
}

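/// Drains the prover states that have no remaining layers (kept at the tail of the descending
/// sort) and records their final layer claims at the shared evaluation point. Claims are pushed
/// in reverse order so that a single reverse() at the end restores the sorted order.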
fn process_finished_provers<F, P>(
	sorted_prover_states: &mut Vec<GrandProductProverState<P>>,
	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
	eval_point: &[F],
) -> Result<(), Error>
where
	F: TowerField,
	P: PackedField<Scalar = F>,
{
	let first_finished =
		sorted_prover_states.partition_point(|state| !state.remaining_layers.is_empty());

	for state in sorted_prover_states.drain(first_finished..).rev() {
		reverse_sorted_final_layer_claims.push(state.finalize(eval_point)?);
	}

	Ok(())
}

/// GPA protocol state for a single witness
///
/// Coordinates the proving of a grand product claim before and after
/// the sumcheck-based layer reductions.
#[derive(Debug)]
struct GrandProductProverState<P>
where
	P: PackedField<Scalar: TowerField>,
{
	// Remaining layers of the product circuit, ordered from largest to smallest.
	// Each step removes the last layer.
	remaining_layers: Vec<Vec<P>>,
	// The current eval claim (on a shared eval point).
	layer_eval: P::Scalar,
}

impl<F, P> GrandProductProverState<P>
where
	F: TowerField,
	P: PackedField<Scalar = F>,
{
	/// Create a new GrandProductProverState
	fn new(claim: &GrandProductClaim<F>, witness: GrandProductWitness<P>) -> Result<Self, Error> {
		if claim.n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
			bail!(Error::ProverClaimWitnessMismatch);
		}

		let mut remaining_layers = witness.into_circuit_layers();
		debug_assert_eq!(remaining_layers.len(), claim.n_vars + 1);
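		// The last layer is the single grand product value itself; its evaluation claim is
		// exactly claim.product, so it is dropped here rather than reduced via sumcheck.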
		let _ = remaining_layers
			.pop()
			.expect("remaining_layers cannot be empty");

		// Initialize the layer claim to the claimed grand product.
		let layer_eval = claim.product;

		// Return the new GrandProductProverState.
		Ok(Self {
			remaining_layers,
			layer_eval,
		})
	}

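	/// Stages a single batched eq-ind sumcheck prover over the next layer of every active state.
	///
	/// Pops one layer per state and splits it into two half-multilinears (the restrictions of
	/// that layer with its last variable set to 0 and 1). Each pair is constrained by a
	/// BivariateProduct composition whose claimed sum is the state's current layer evaluation;
	/// trailing all-ones scalars are handed to the prover as constant suffixes rather than
	/// materialized.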
	#[allow(clippy::type_complexity)]
	#[instrument(skip_all, level = "debug")]
	fn stage_sumcheck_provers<'a, FDomain, Backend>(
		evaluation_order: EvaluationOrder,
		states: &mut [Self],
		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
		eq_ind_challenges: &[P::Scalar],
		backend: &'a Backend,
	) -> Result<impl SumcheckProver<P::Scalar> + 'a, Error>
	where
		FDomain: Field,
		P: PackedExtension<FDomain>,
		Backend: ComputationBackend,
	{
		let n_vars = eq_ind_challenges.len();
		let n_claims = states.len();
		let n_multilinears = n_claims * 2;

		let mut composite_claims = Vec::with_capacity(n_claims);
		let mut multilinears = Vec::with_capacity(n_multilinears);
		let mut const_suffixes = Vec::with_capacity(n_multilinears);

		for (i, state) in states.iter_mut().enumerate() {
			let indices = [2 * i, 2 * i + 1];

			let composite_claim = CompositeSumClaim {
				sum: state.layer_eval,
				composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
			};

			composite_claims.push(composite_claim);

			let layer = state
				.remaining_layers
				.pop()
				.expect("not staging more than n_vars times");

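			// Split the popped layer (logically 2^(n_vars + 1) scalars; a trailing run of ones
			// is not stored) into low and high halves of 2^n_vars scalars each. Three cases:
			// * the stored prefix fits entirely inside the low half: the high half is all ones
			//   and is left empty, to be covered by its constant suffix;
			// * the low half spans whole packed elements: split the vector at the midpoint and
			//   let the high half keep whatever prefix remains;
			// * a half is smaller than one packed element: extract both halves from the single
			//   stored packed element scalar by scalar.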
			let multilinear_pair =
				if n_vars >= P::LOG_WIDTH && layer.len() < 1 << (n_vars - P::LOG_WIDTH) {
					[layer, vec![]]
				} else if n_vars >= P::LOG_WIDTH {
					let mut evals_0 = layer;
					let evals_1 = evals_0.split_off(1 << (n_vars - P::LOG_WIDTH));
					[evals_0, evals_1]
				} else {
					let mut evals_0 = P::zero();
					let mut evals_1 = P::zero();
					let only_packed = layer.first().copied().unwrap_or_else(P::one);

					for i in 0..1 << n_vars {
						evals_0.set(i, only_packed.get(i));
						evals_1.set(i, only_packed.get(i | 1 << n_vars));
					}

					[vec![evals_0], vec![evals_1]]
				};

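			// Scalars beyond each stored half are all ones (the multiplicative identity of the
			// product circuit), so they are reported to the sumcheck prover as a constant
			// F::ONE suffix instead of being materialized.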
			for multilinear in multilinear_pair {
				let suffix_len = (1usize << n_vars).saturating_sub(multilinear.len() * P::WIDTH);
				const_suffixes.push((F::ONE, suffix_len));
				multilinears.push(multilinear);
			}
		}

		let prover = EqIndSumcheckProverBuilder::without_switchover(n_vars, multilinears, backend)
			.with_const_suffixes(&const_suffixes)?
			.build(
				evaluation_order,
				eq_ind_challenges,
				composite_claims,
				evaluation_domain_factory,
			)?;

		Ok(prover)
	}

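	/// Updates the layer claim after a round of sumcheck.
	///
	/// `zero_eval` and `one_eval` are the sumcheck evaluations of the two halves of the next
	/// layer at the sumcheck challenge point; interpolating the line between them at
	/// `gpa_challenge` yields that layer's evaluation with the split variable fixed to
	/// `gpa_challenge`.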
	fn update_layer_eval(&mut self, zero_eval: F, one_eval: F, gpa_challenge: F) {
		self.layer_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
	}

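	/// Consumes the state and returns the final LayerClaim on the original multilinear.
	///
	/// Errors if any circuit layers remain to be reduced.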
	fn finalize(self, eval_point: &[F]) -> Result<LayerClaim<F>, Error> {
		if !self.remaining_layers.is_empty() {
			bail!(Error::PrematureFinalize);
		}

		Ok(LayerClaim {
			eval_point: eval_point.to_vec(),
			eval: self.layer_eval,
		})
	}
}