binius_core/protocols/gkr_gpa/
prove.rs

// Copyright 2024-2025 Irreducible Inc.

use binius_field::{Field, PackedExtension, PackedField, TowerField};
use binius_hal::ComputationBackend;
use binius_math::{EvaluationDomainFactory, EvaluationOrder, extrapolate_line_scalar};
use binius_utils::{
	bail,
	sorting::{stable_sort, unsort},
};
use itertools::izip;
use tracing::instrument;

use super::{
	Error, GrandProductClaim, GrandProductWitness,
	gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
};
use crate::{
	composition::{BivariateProduct, IndexComposition},
	fiat_shamir::{CanSample, Challenger},
	protocols::sumcheck::{
		BatchSumcheckOutput, CompositeSumClaim,
		prove::{SumcheckProver, eq_ind::EqIndSumcheckProverBuilder, front_loaded},
	},
	transcript::ProverTranscript,
};

/// Proves a batch reduction turning each GrandProductClaim into a LayerClaim on the
/// original multilinear.
///
/// REQUIRES:
/// * witnesses and claims are of the same length
/// * The ith witness corresponds to the ith claim
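///
/// # Usage (sketch)
///
/// A minimal calling sketch; the witness/claim pairs, domain factory, transcript, and
/// backend are assumed to have been constructed by the caller (the names below are
/// illustrative, not part of this module's API):
///
/// ```ignore
/// let GrandProductBatchProveOutput { final_layer_claims } = batch_prove(
///     EvaluationOrder::HighToLow,
///     witnesses,          // impl IntoIterator<Item = GrandProductWitness<P>>
///     &claims,            // &[GrandProductClaim<F>]
///     domain_factory,     // impl EvaluationDomainFactory<FDomain>
///     &mut transcript,    // &mut ProverTranscript<impl Challenger>
///     &backend,           // &impl ComputationBackend
/// )?;
/// // final_layer_claims[i] is the LayerClaim for claims[i], in the original order.
/// ```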
#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
	evaluation_order: EvaluationOrder,
	witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
	claims: &[GrandProductClaim<F>],
	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
	transcript: &mut ProverTranscript<Challenger_>,
	backend: &Backend,
) -> Result<GrandProductBatchProveOutput<F>, Error>
where
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	FDomain: Field,
	Challenger_: Challenger,
	Backend: ComputationBackend,
{
	// Ensure witnesses and claims are of the same length, zip them together
	// For each witness-claim pair, create GrandProductProverState
	let witnesses = witnesses.into_iter().collect::<Vec<_>>();

	if witnesses.len() != claims.len() {
		bail!(Error::MismatchedWitnessClaimLength);
	}

	// Create a vector of GrandProductProverStates
	let prover_states = izip!(witnesses, claims)
		.map(|(witness, claim)| GrandProductProverState::new(claim, witness))
		.collect::<Result<Vec<_>, _>>()?;

	let (original_indices, mut sorted_prover_states) =
		stable_sort(prover_states, |state| state.remaining_layers.len(), true);

	let mut reverse_sorted_final_layer_claims = Vec::with_capacity(claims.len());
	let mut eval_point = Vec::new();

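	// Each iteration of the loop below reduces every active claim by one layer of its
	// product circuit. Writing M_k for the k-variable multilinear of layer k, a claim
	// M_k(r) = s is reduced via the eq-ind sumcheck identity
	//
	//     M_k(r) = sum_x eq(r, x) * A_{k+1}(x) * B_{k+1}(x),
	//
	// where A_{k+1} and B_{k+1} are the first and second halves of layer k+1. The
	// sampled gpa_challenge then folds the two resulting half evaluations into a single
	// claim on layer k+1. All provers share the growing eval_point, so claims over fewer
	// variables finish, and are drained, in earlier iterations.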
	loop {
		// Step 1: Process finished provers
		process_finished_provers(
			&mut sorted_prover_states,
			&mut reverse_sorted_final_layer_claims,
			&eval_point,
		)?;

		if sorted_prover_states.is_empty() {
			break;
		}

		// Now we must create the batch layer proof for the k-th to (k+1)-th layer reduction

		// Step 2: Create sumcheck batch proof
		let BatchSumcheckOutput {
			challenges,
			multilinear_evals,
		} = {
			let _layer_span = tracing::info_span!(
				"[task] GKR GPA Layer Sumcheck",
				phase = "exp",
				perfetto_category = "task.main"
			)
			.entered();

			let eq_ind_sumcheck_prover = GrandProductProverState::stage_sumcheck_provers(
				evaluation_order,
				&mut sorted_prover_states,
				evaluation_domain_factory.clone(),
				&eval_point,
				backend,
			)?;

			let batch_sumcheck_prover =
				front_loaded::BatchProver::new(vec![eq_ind_sumcheck_prover], transcript)?;

			let mut batch_output = batch_sumcheck_prover.run(transcript)?;

			if evaluation_order == EvaluationOrder::HighToLow {
				batch_output.challenges.reverse();
			}

			batch_output
		};

		// Step 3: Sample a challenge for the next layer
		let gpa_challenge = transcript.sample();

		eval_point = challenges;
		eval_point.push(gpa_challenge);

		// Step 4: Finalize each prover to update its internal layer_eval
		debug_assert_eq!(multilinear_evals.len(), 1);
		let multilinear_evals = multilinear_evals
			.first()
			.expect("exactly one prover in a batch");
		for (state, evals) in izip!(&mut sorted_prover_states, multilinear_evals.chunks_exact(2)) {
			state.update_layer_eval(evals[0], evals[1], gpa_challenge);
		}
	}
	process_finished_provers(
		&mut sorted_prover_states,
		&mut reverse_sorted_final_layer_claims,
		&eval_point,
	)?;

	debug_assert!(sorted_prover_states.is_empty());
	debug_assert_eq!(reverse_sorted_final_layer_claims.len(), claims.len());

	reverse_sorted_final_layer_claims.reverse();
	let sorted_final_layer_claims = reverse_sorted_final_layer_claims;

	let final_layer_claims = unsort(original_indices, sorted_final_layer_claims);
	Ok(GrandProductBatchProveOutput { final_layer_claims })
}

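/// Drains the prover states whose layers are exhausted (they sit at the tail of the
/// vector, which is kept sorted by remaining layer count) and converts each one into a
/// final `LayerClaim` at the current shared evaluation point.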
fn process_finished_provers<F, P>(
	sorted_prover_states: &mut Vec<GrandProductProverState<P>>,
	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
	eval_point: &[F],
) -> Result<(), Error>
where
	F: TowerField,
	P: PackedField<Scalar = F>,
{
	let first_finished =
		sorted_prover_states.partition_point(|state| !state.remaining_layers.is_empty());

	for state in sorted_prover_states.drain(first_finished..).rev() {
		reverse_sorted_final_layer_claims.push(state.finalize(eval_point)?);
	}

	Ok(())
}

/// GPA protocol state for a single witness
///
/// Coordinates the proving of a grand product claim before and after
/// the sumcheck-based layer reductions.
#[derive(Debug)]
struct GrandProductProverState<P>
where
	P: PackedField<Scalar: TowerField>,
{
	// Remaining layers of the product circuit, ordered from largest to smallest.
	// Each step removes the last layer.
	remaining_layers: Vec<Vec<P>>,
	// The current eval claim (on a shared eval point).
	layer_eval: P::Scalar,
}

impl<F, P> GrandProductProverState<P>
where
	F: TowerField,
	P: PackedField<Scalar = F>,
{
	/// Create a new GrandProductProverState
	fn new(claim: &GrandProductClaim<F>, witness: GrandProductWitness<P>) -> Result<Self, Error> {
		if claim.n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
			bail!(Error::ProverClaimWitnessMismatch);
		}

		let mut remaining_layers = witness.into_circuit_layers();
		debug_assert_eq!(remaining_layers.len(), claim.n_vars + 1);
		let _ = remaining_layers
			.pop()
			.expect("remaining_layers cannot be empty");

		// Initialize the layer claim with the grand product itself
		let layer_eval = claim.product;

		// Return the new GrandProductProverState
		Ok(Self {
			remaining_layers,
			layer_eval,
		})
	}

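	/// Pops the next layer from every active state and stages a single eq-ind sumcheck
	/// prover batching all of their claims.
	///
	/// Each popped layer is split into its first and second halves `A_i` and `B_i`,
	/// contributing the composite claim
	/// `state.layer_eval = sum_x eq(eq_ind_challenges, x) * A_i(x) * B_i(x)`.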
	#[allow(clippy::type_complexity)]
	#[instrument(skip_all, level = "debug")]
	fn stage_sumcheck_provers<'a, FDomain, Backend>(
		evaluation_order: EvaluationOrder,
		states: &mut [Self],
		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
		eq_ind_challenges: &[P::Scalar],
		backend: &'a Backend,
	) -> Result<impl SumcheckProver<P::Scalar> + 'a, Error>
	where
		FDomain: Field,
		P: PackedExtension<FDomain>,
		Backend: ComputationBackend,
	{
		let n_vars = eq_ind_challenges.len();
		let n_claims = states.len();
		let n_multilinears = n_claims * 2;

		let mut composite_claims = Vec::with_capacity(n_claims);
		let mut multilinears = Vec::with_capacity(n_multilinears);
		let mut const_suffixes = Vec::with_capacity(n_multilinears);

		for (i, state) in states.iter_mut().enumerate() {
			let indices = [2 * i, 2 * i + 1];

			let composite_claim = CompositeSumClaim {
				sum: state.layer_eval,
				composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
			};

			composite_claims.push(composite_claim);

			let layer = state
				.remaining_layers
				.pop()
				.expect("not staging more than n_vars times");

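			// Split the popped layer into the two halves to be multiplied. Three cases:
			// the stored prefix already fits in the first half, so the second half is
			// all ones and represented by an empty vec; the layer spans both halves, so
			// split it at the midpoint; or the whole layer fits in a single packed
			// element, so both halves are repacked scalar by scalar.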
			let multilinear_pair =
				if n_vars >= P::LOG_WIDTH && layer.len() < 1 << (n_vars - P::LOG_WIDTH) {
					[layer, vec![]]
				} else if n_vars >= P::LOG_WIDTH {
					let mut evals_0 = layer;
					let evals_1 = evals_0.split_off(1 << (n_vars - P::LOG_WIDTH));
					[evals_0, evals_1]
				} else {
					let mut evals_0 = P::zero();
					let mut evals_1 = P::zero();
					let only_packed = layer.first().copied().unwrap_or_else(P::one);

					for i in 0..1 << n_vars {
						evals_0.set(i, only_packed.get(i));
						evals_1.set(i, only_packed.get(i | 1 << n_vars));
					}

					[vec![evals_0], vec![evals_1]]
				};

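			// Truncated layers are implicitly padded with ones (the multiplicative
			// identity), so declare a constant F::ONE suffix for any tail not stored.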
			for multilinear in multilinear_pair {
				let suffix_len = (1usize << n_vars).saturating_sub(multilinear.len() * P::WIDTH);
				const_suffixes.push((F::ONE, suffix_len));
				multilinears.push(multilinear);
			}
		}

		let prover = EqIndSumcheckProverBuilder::without_switchover(n_vars, multilinears, backend)
			.with_const_suffixes(&const_suffixes)?
			.build(
				evaluation_order,
				eq_ind_challenges,
				composite_claims,
				evaluation_domain_factory,
			)?;

		Ok(prover)
	}

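	/// Folds the two half-layer evaluations into the claim on the next layer by linear
	/// interpolation: the new eval is the line through (0, `zero_eval`) and
	/// (1, `one_eval`) evaluated at `gpa_challenge`.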
	fn update_layer_eval(&mut self, zero_eval: F, one_eval: F, gpa_challenge: F) {
		self.layer_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
	}

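	/// Consumes the exhausted state, returning its final layer claim at the shared
	/// evaluation point. Errors if any layers remain unprocessed.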
	fn finalize(self, eval_point: &[F]) -> Result<LayerClaim<F>, Error> {
		if !self.remaining_layers.is_empty() {
			bail!(Error::PrematureFinalize);
		}

		Ok(LayerClaim {
			eval_point: eval_point.to_vec(),
			eval: self.layer_eval,
		})
	}
}