binius_core/protocols/gkr_gpa/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{
6	extrapolate_line_scalar, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
7	MultilinearExtension, MultilinearPoly,
8};
9use binius_utils::{
10	bail,
11	sorting::{stable_sort, unsort},
12};
13use tracing::instrument;
14
15use super::{
16	gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
17	gpa_sumcheck::prove::GPAProver,
18	packed_field_storage::PackedFieldStorage,
19	Error, GrandProductClaim, GrandProductWitness,
20};
21use crate::{
22	composition::{BivariateProduct, IndexComposition},
23	fiat_shamir::{CanSample, Challenger},
24	protocols::sumcheck::{self, CompositeSumClaim},
25	transcript::ProverTranscript,
26};
27
/// Proves batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim
///
/// Peels the product circuit one layer at a time, from the single-element output layer
/// down to the input layer. Each round batches all still-active provers into a single
/// GPA sumcheck, samples one folding challenge from the transcript, and updates every
/// prover's current layer claim. Provers whose input layer has been reached are
/// finalized and their final-layer claims collected.
///
/// REQUIRES:
/// * witnesses and claims are of the same length
/// * The ith witness corresponds to the ith claim
#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
pub fn batch_prove<F, P, FDomain, Challenger_, Backend>(
	evaluation_order: EvaluationOrder,
	witnesses: impl IntoIterator<Item = GrandProductWitness<P>>,
	claims: &[GrandProductClaim<F>],
	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
	transcript: &mut ProverTranscript<Challenger_>,
	backend: &Backend,
) -> Result<GrandProductBatchProveOutput<F>, Error>
where
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	FDomain: Field,
	Challenger_: Challenger,
	Backend: ComputationBackend,
{
	//  Ensure witnesses and claims are of the same length, zip them together
	// 	For each witness-claim pair, create GrandProductProver
	let witness_vec = witnesses.into_iter().collect::<Vec<_>>();

	let n_claims = claims.len();
	if n_claims == 0 {
		return Ok(GrandProductBatchProveOutput::default());
	}
	if witness_vec.len() != n_claims {
		bail!(Error::MismatchedWitnessClaimLength);
	}

	// Create a vector of GrandProductProverStates; `new` validates each
	// witness/claim pair, so a mismatch errors out here.
	let provers_vec = witness_vec
		.iter()
		.zip(claims)
		.map(|(witness, claim)| GrandProductProverState::new(claim, witness, backend))
		.collect::<Result<Vec<_>, _>>()?;

	// Sort provers by descending number of input variables so that finished
	// provers always accumulate at the tail and can be popped off cheaply.
	let (original_indices, mut sorted_provers) =
		stable_sort(provers_vec, |prover| prover.input_vars(), true);

	// Descending sort => the first prover has the maximum variable count,
	// which determines the total number of layer-reduction rounds.
	let max_n_vars = sorted_provers
		.first()
		.expect("sorted_provers is not empty by invariant")
		.input_vars();

	let mut reverse_sorted_final_layer_claims = Vec::with_capacity(n_claims);

	for layer_no in 0..max_n_vars {
		// Step 1: Process finished provers
		process_finished_provers(
			layer_no,
			&mut sorted_provers,
			&mut reverse_sorted_final_layer_claims,
		)?;

		// Now we must create the batch layer proof for the kth to k+1th layer reduction

		// Step 2: Create sumcheck batch proof
		let batch_sumcheck_output = {
			let gpa_sumcheck_prover = GrandProductProverState::stage_gpa_sumcheck_provers(
				evaluation_order,
				&sorted_provers,
				evaluation_domain_factory.clone(),
			)?;

			sumcheck::batch_prove(vec![gpa_sumcheck_prover], transcript)?
		};

		// Step 3: Sample a challenge for the next layer
		// NOTE: ordering matters for Fiat-Shamir soundness — the challenge must be
		// sampled only after the sumcheck proof has been written to the transcript.
		let gpa_challenge = transcript.sample();

		// Step 4: Finalize each prover to update its internal current_layer_claim.
		// Evals are laid out pairwise: prover i owns multilinears 2i (zero half)
		// and 2i+1 (one half) of the single batched sumcheck.
		for (i, prover) in sorted_provers.iter_mut().enumerate() {
			prover.finalize_batch_layer_proof(
				batch_sumcheck_output.multilinear_evals[0][2 * i],
				batch_sumcheck_output.multilinear_evals[0][2 * i + 1],
				batch_sumcheck_output.challenges.clone(),
				gpa_challenge,
			)?;
		}
	}
	// Drain the provers that finish on the very last round.
	process_finished_provers(
		max_n_vars,
		&mut sorted_provers,
		&mut reverse_sorted_final_layer_claims,
	)?;

	debug_assert!(sorted_provers.is_empty());
	debug_assert_eq!(reverse_sorted_final_layer_claims.len(), n_claims);

	// Claims were pushed in reverse sorted order (tail-first); restore the
	// sorted order, then undo the stable sort to match the caller's claim order.
	reverse_sorted_final_layer_claims.reverse();
	let sorted_final_layer_claim = reverse_sorted_final_layer_claims;

	let final_layer_claims = unsort(original_indices, sorted_final_layer_claim);

	Ok(GrandProductBatchProveOutput { final_layer_claims })
}
128
129fn process_finished_provers<F, P, Backend>(
130	layer_no: usize,
131	sorted_provers: &mut Vec<GrandProductProverState<'_, F, P, Backend>>,
132	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
133) -> Result<(), Error>
134where
135	F: TowerField,
136	P: PackedField<Scalar = F>,
137	Backend: ComputationBackend,
138{
139	while let Some(prover) = sorted_provers.last() {
140		if prover.input_vars() != layer_no {
141			break;
142		}
143		debug_assert!(layer_no > 0);
144		let finished_prover = sorted_provers.pop().expect("not empty");
145		let final_layer_claim = finished_prover.finalize()?;
146		reverse_sorted_final_layer_claims.push(final_layer_claim);
147	}
148
149	Ok(())
150}
151
/// GPA protocol prover state
///
/// Coordinates the proving of a grand product claim before and after
/// the sumcheck-based layer reductions.
#[derive(Debug)]
struct GrandProductProverState<'a, F, P, Backend>
where
	F: Field + From<P::Scalar>,
	P: PackedField,
	P::Scalar: Field + From<F>,
	Backend: ComputationBackend,
{
	// Number of variables of the input (bottom) layer multilinear.
	n_vars: usize,
	// Layers of the product circuit as multilinear polynomials
	// The ith element is the ith layer of the product circuit
	// (layer 0 is the single-product output; layer n_vars is the input).
	layers: Vec<MLEDirectAdapter<P, PackedFieldStorage<'a, P>>>,
	// The ith element consists of a tuple of the
	// first and second halves of the (i+1)th layer of the product circuit
	next_layer_halves: Vec<[MLEDirectAdapter<P, PackedFieldStorage<'a, P>>; 2]>,
	// The current claim about a layer multilinear of the product circuit;
	// its eval_point length doubles as the current layer number.
	current_layer_claim: LayerClaim<F>,

	backend: Backend,
}
176
177impl<'a, F, P, Backend> GrandProductProverState<'a, F, P, Backend>
178where
179	F: TowerField + From<P::Scalar>,
180	P: PackedField<Scalar = F>,
181	Backend: ComputationBackend,
182{
183	/// Create a new GrandProductProverState
184	fn new(
185		claim: &GrandProductClaim<F>,
186		witness: &'a GrandProductWitness<P>,
187		backend: Backend,
188	) -> Result<Self, Error> {
189		let n_vars = claim.n_vars;
190		if n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
191			bail!(Error::ProverClaimWitnessMismatch);
192		}
193
194		// Build multilinear polynomials from circuit evaluations
195		let n_layers = n_vars + 1;
196		let next_layer_halves = (1..n_layers)
197			.map(|i| {
198				let (left_evals, right_evals) = witness.ith_layer_eval_halves(i)?;
199				let left = MultilinearExtension::try_from(left_evals)?;
200				let right = MultilinearExtension::try_from(right_evals)?;
201				Ok([left, right].map(MLEDirectAdapter::from))
202			})
203			.collect::<Result<Vec<_>, Error>>()?;
204
205		let layers = (0..n_layers)
206			.map(|i| {
207				let ith_layer_evals = witness.ith_layer_evals(i)?;
208				let ith_layer_evals = if P::LOG_WIDTH < i {
209					PackedFieldStorage::from(ith_layer_evals)
210				} else {
211					debug_assert_eq!(ith_layer_evals.len(), 1);
212					PackedFieldStorage::new_inline(ith_layer_evals[0].iter().take(1 << i))
213						.expect("length is a power of 2")
214				};
215
216				let mle = MultilinearExtension::try_from(ith_layer_evals)?;
217				Ok(mle.into())
218			})
219			.collect::<Result<Vec<_>, Error>>()?;
220
221		debug_assert_eq!(next_layer_halves.len(), n_vars);
222		debug_assert_eq!(layers.len(), n_vars + 1);
223
224		// Initialize Layer Claim
225		let layer_claim = LayerClaim {
226			eval_point: vec![],
227			eval: claim.product,
228		};
229
230		// Return new GrandProductProver and the common product
231		Ok(Self {
232			n_vars,
233			next_layer_halves,
234			layers,
235			current_layer_claim: layer_claim,
236			backend,
237		})
238	}
239
240	const fn input_vars(&self) -> usize {
241		self.n_vars
242	}
243
244	fn current_layer_no(&self) -> usize {
245		self.current_layer_claim.eval_point.len()
246	}
247
248	#[allow(clippy::type_complexity)]
249	#[instrument(skip_all, level = "debug")]
250	fn stage_gpa_sumcheck_provers<FDomain>(
251		evaluation_order: EvaluationOrder,
252		provers: &[Self],
253		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
254	) -> Result<
255		GPAProver<
256			FDomain,
257			P,
258			IndexComposition<BivariateProduct, 2>,
259			impl MultilinearPoly<P> + Send + Sync + 'a,
260			Backend,
261		>,
262		Error,
263	>
264	where
265		FDomain: Field,
266		P: PackedExtension<FDomain>,
267	{
268		// test same layer
269		let Some(first_prover) = provers.first() else {
270			unreachable!();
271		};
272
273		// construct witness
274		let n_claims = provers.len();
275		let n_multilinears = provers.len() * 2;
276		let current_layer_no = first_prover.current_layer_no();
277
278		let mut composite_claims = Vec::with_capacity(n_claims);
279		let mut multilinears = Vec::with_capacity(n_multilinears);
280
281		for (i, prover) in provers.iter().enumerate() {
282			let indices = [2 * i, 2 * i + 1];
283
284			let composite_claim = CompositeSumClaim {
285				sum: prover.current_layer_claim.eval,
286				composition: IndexComposition::new(n_multilinears, indices, BivariateProduct {})?,
287			};
288
289			composite_claims.push(composite_claim);
290			multilinears.extend(prover.next_layer_halves[current_layer_no].clone());
291		}
292
293		let first_layer_mle_advice = provers
294			.iter()
295			.map(|prover| prover.layers[current_layer_no].clone())
296			.collect::<Vec<_>>();
297
298		Ok(GPAProver::new(
299			evaluation_order,
300			multilinears,
301			Some(first_layer_mle_advice),
302			composite_claims,
303			evaluation_domain_factory,
304			&first_prover.current_layer_claim.eval_point,
305			&first_prover.backend,
306		)?)
307	}
308
309	fn finalize_batch_layer_proof(
310		&mut self,
311		zero_eval: F,
312		one_eval: F,
313		sumcheck_challenge: Vec<F>,
314		gpa_challenge: F,
315	) -> Result<(), Error> {
316		if self.current_layer_no() >= self.input_vars() {
317			bail!(Error::TooManyRounds);
318		}
319		let new_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
320		let mut layer_challenge = sumcheck_challenge;
321		layer_challenge.push(gpa_challenge);
322
323		self.current_layer_claim = LayerClaim {
324			eval_point: layer_challenge,
325			eval: new_eval,
326		};
327
328		Ok(())
329	}
330
331	fn finalize(self) -> Result<LayerClaim<F>, Error> {
332		if self.current_layer_no() != self.input_vars() {
333			bail!(Error::PrematureFinalize);
334		}
335
336		let final_layer_claim = LayerClaim {
337			eval_point: self.current_layer_claim.eval_point,
338			eval: self.current_layer_claim.eval,
339		};
340		Ok(final_layer_claim)
341	}
342}