binius_core/protocols/gkr_exp/
batch_prove.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::{bail, sorting::is_sorted_ascending};
7use itertools::izip;
8use tracing::instrument;
9
10use super::{
11	common::{BaseExpReductionOutput, ExpClaim, GKRExpProver, GKRExpProverBuilder, LayerClaim},
12	compositions::IndexedExpComposition,
13	error::Error,
14	provers::{
15		CompositeSumClaimWithMultilinears, DynamicBaseExpProver, ExpProver, StaticExpProver,
16	},
17	witness::BaseExpWitness,
18};
19use crate::{
20	fiat_shamir::Challenger,
21	protocols::sumcheck::{
22		self, immediate_switchover_heuristic, BatchSumcheckOutput, CompositeSumClaim,
23	},
24	transcript::ProverTranscript,
25	witness::MultilinearWitness,
26};
27
28/// Prove a batched GKR exponentiation protocol execution.
29///
30/// The protocol can be batched over multiple instances by grouping consecutive provers over
31/// `eval_points` in internal `LayerClaims` into `GkrExpProvers`. To achieve this, we use
32/// [`crate::composition::IndexComposition`]. Since exponents can have different bit sizes, resulting
33/// in a varying number of layers, we group them starting from the first layer to maximize the
34/// opportunity to share the same evaluation point.
35///
36/// # Requirements
37/// - Witnesses and claims must be in the same order as in [`super::batch_verify`] during proof verification.
38/// - Witnesses and claims must be sorted in descending order by n_vars.
39/// - Witnesses and claims must be of the same length.
40/// - The `i`th witness must correspond to the `i`th claim.
41///
42/// # Recommendations
43/// - Witnesses and claims should be grouped by evaluation points from the claims.
44#[instrument(skip_all, name = "gkr_exp::batch_prove")]
45pub fn batch_prove<'a, F, P, FDomain, Challenger_, Backend>(
46	evaluation_order: EvaluationOrder,
47	witnesses: impl IntoIterator<Item = BaseExpWitness<'a, P>>,
48	claims: &[ExpClaim<F>],
49	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
50	transcript: &mut ProverTranscript<Challenger_>,
51	backend: &Backend,
52) -> Result<BaseExpReductionOutput<F>, Error>
53where
54	F: ExtensionField<FDomain> + TowerField,
55	FDomain: Field,
56	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
57	Backend: ComputationBackend,
58	Challenger_: Challenger,
59{
60	let witnesses = witnesses.into_iter().collect::<Vec<_>>();
61
62	if witnesses.len() != claims.len() {
63		bail!(Error::MismatchedWitnessClaimLength);
64	}
65
66	let mut layers_claims = Vec::new();
67
68	if witnesses.is_empty() {
69		return Ok(BaseExpReductionOutput { layers_claims });
70	}
71
72	// Check that the witnesses are in descending order by n_vars
73	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
74		bail!(Error::ClaimsOutOfOrder);
75	}
76
77	let mut provers = make_provers(witnesses, claims)?;
78
79	let max_exponent_bit_number = provers
80		.iter()
81		.map(|p| p.exponent_bit_width())
82		.max()
83		.unwrap_or(0);
84
85	for layer_no in 0..max_exponent_bit_number {
86		let gkr_sumcheck_provers = build_layer_gkr_sumcheck_provers(
87			evaluation_order,
88			&mut provers,
89			layer_no,
90			evaluation_domain_factory.clone(),
91			backend,
92		)?;
93
94		let sumcheck_proof_output = sumcheck::batch_prove(gkr_sumcheck_provers, transcript)?;
95
96		let layer_exponent_claims = build_layer_exponent_bit_claims(
97			evaluation_order,
98			&mut provers,
99			sumcheck_proof_output,
100			layer_no,
101		)?;
102
103		layers_claims.push(layer_exponent_claims);
104
105		provers.retain(|prover| !prover.is_last_layer(layer_no));
106	}
107
108	Ok(BaseExpReductionOutput { layers_claims })
109}
110
111type GKRExpProvers<'a, F, P, FDomain, Backend> =
112	Vec<GKRExpProver<'a, FDomain, P, IndexedExpComposition<F>, MultilinearWitness<'a, P>, Backend>>;
113
114/// Groups consecutive provers by their `eval_point` and reduces them to sumcheck provers.
115#[instrument(skip_all, level = "debug")]
116fn build_layer_gkr_sumcheck_provers<'a, P, FDomain, Backend>(
117	evaluation_order: EvaluationOrder,
118	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
119	layer_no: usize,
120	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
121	backend: &'a Backend,
122) -> Result<GKRExpProvers<'a, P::Scalar, P, FDomain, Backend>, Error>
123where
124	FDomain: Field,
125	P: PackedField + PackedExtension<FDomain>,
126	P::Scalar: TowerField + ExtensionField<FDomain>,
127	Backend: ComputationBackend,
128{
129	assert!(!provers.is_empty());
130
131	let mut composite_claims = Vec::new();
132	let mut multilinears = Vec::new();
133
134	let first_eval_point = provers[0].layer_claim_eval_point().to_vec();
135	let mut eval_points = vec![first_eval_point];
136
137	let mut active_index = 0;
138
139	// group provers by evaluation points and build composite sum claims.
140	for i in 0..provers.len() {
141		if provers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
142			let CompositeSumClaimsWithMultilinears {
143				composite_claims: eval_point_composite_claims,
144				multilinears: eval_point_multilinears,
145			} = build_eval_point_claims::<P>(&mut provers[active_index..i], layer_no)?;
146
147			if eval_point_composite_claims.is_empty() {
148				// extract the last point because provers with this point will not participate in the sumcheck.
149				eval_points.pop();
150			} else {
151				composite_claims.push(eval_point_composite_claims);
152				multilinears.push(eval_point_multilinears);
153			}
154
155			eval_points.push(provers[i].layer_claim_eval_point().to_vec());
156			active_index = i;
157		}
158
159		if i == provers.len() - 1 {
160			let CompositeSumClaimsWithMultilinears {
161				composite_claims: eval_point_composite_claims,
162				multilinears: eval_point_multilinears,
163			} = build_eval_point_claims::<P>(&mut provers[active_index..], layer_no)?;
164
165			if !eval_point_composite_claims.is_empty() {
166				composite_claims.push(eval_point_composite_claims);
167				multilinears.push(eval_point_multilinears);
168			}
169		}
170	}
171
172	izip!(composite_claims, multilinears, eval_points)
173		.map(|(composite_claims, multilinears, eval_point)| {
174			GKRExpProverBuilder::<'a, P, _, Backend>::with_switchover(
175				multilinears,
176				immediate_switchover_heuristic,
177				backend,
178			)?
179			.build(
180				evaluation_order,
181				&eval_point,
182				composite_claims,
183				evaluation_domain_factory.clone(),
184			)
185		})
186		.collect::<Result<Vec<_>, _>>()
187		.map_err(Error::from)
188}
189
190struct CompositeSumClaimsWithMultilinears<'a, P: PackedField> {
191	composite_claims: Vec<CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>>,
192	multilinears: Vec<MultilinearWitness<'a, P>>,
193}
194
195/// Builds composite claims and multilinears for provers that share the same `eval_point` from their internal [LayerClaim]s.
196fn build_eval_point_claims<'a, P>(
197	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
198	layer_no: usize,
199) -> Result<CompositeSumClaimsWithMultilinears<'a, P>, Error>
200where
201	P: PackedField,
202{
203	let (composite_claims_n_multilinears, n_claims) =
204		provers
205			.iter()
206			.fold((0, 0), |(n_multilinears, n_claims), prover| {
207				let layer_n_multilinears = prover.layer_n_multilinears(layer_no);
208				let layer_n_claims = prover.layer_n_claims(layer_no);
209
210				(n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
211			});
212
213	let mut multilinears = Vec::with_capacity(composite_claims_n_multilinears);
214
215	let mut composite_claims = Vec::with_capacity(n_claims);
216
217	for prover in provers {
218		let multilinears_index = multilinears.len();
219
220		let meta = prover.layer_composite_sum_claim(
221			layer_no,
222			composite_claims_n_multilinears,
223			multilinears_index,
224		)?;
225
226		if let Some(meta) = meta {
227			let CompositeSumClaimWithMultilinears {
228				claim,
229				multilinears: this_layer_multilinears,
230			} = meta;
231
232			composite_claims.push(claim);
233
234			multilinears.extend(this_layer_multilinears);
235		}
236	}
237	Ok(CompositeSumClaimsWithMultilinears {
238		composite_claims,
239		multilinears,
240	})
241}
242
243/// Reduces the sumcheck output to [LayerClaim]s and updates the internal provers [LayerClaim]s for the next layer.
244fn build_layer_exponent_bit_claims<'a, P>(
245	evaluation_order: EvaluationOrder,
246	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
247	mut sumcheck_output: BatchSumcheckOutput<P::Scalar>,
248	layer_no: usize,
249) -> Result<Vec<LayerClaim<P::Scalar>>, Error>
250where
251	P: PackedField,
252{
253	let mut eval_claims_on_exponent_bit_columns = Vec::new();
254
255	// extract eq_ind_evals
256	for multilinear_evals in &mut sumcheck_output.multilinear_evals {
257		multilinear_evals.pop();
258	}
259
260	let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
261
262	for prover in provers {
263		let this_prover_n_multilinears = prover.layer_n_multilinears(layer_no);
264
265		let this_prover_multilinear_evals = multilinear_evals
266			.by_ref()
267			.take(this_prover_n_multilinears)
268			.collect::<Vec<_>>();
269
270		let exponent_bit_claims = prover.finish_layer(
271			evaluation_order,
272			layer_no,
273			&this_prover_multilinear_evals,
274			&sumcheck_output.challenges,
275		);
276
277		eval_claims_on_exponent_bit_columns.extend(exponent_bit_claims);
278	}
279
280	Ok(eval_claims_on_exponent_bit_columns)
281}
282
283/// Creates a vector of boxed [ExpProver]s from the given witnesses and claims.
284fn make_provers<'a, P>(
285	witnesses: Vec<BaseExpWitness<'a, P>>,
286	claims: &[ExpClaim<P::Scalar>],
287) -> Result<Vec<Box<dyn ExpProver<'a, P> + 'a>>, Error>
288where
289	P: PackedField,
290	P::Scalar: BinaryField,
291{
292	witnesses
293		.into_iter()
294		.zip(claims)
295		.map(|(witness, claim)| {
296			if witness.uses_dynamic_base() {
297				DynamicBaseExpProver::new(witness, claim)
298					.map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
299			} else {
300				StaticExpProver::<'a, P>::new(witness, claim)
301					.map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
302			}
303		})
304		.collect::<Result<Vec<_>, Error>>()
305}