binius_core/protocols/gkr_exp/
batch_prove.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::{EvaluationDomainFactory, EvaluationOrder};
6use binius_utils::{bail, sorting::is_sorted_ascending};
7use itertools::izip;
8use tracing::instrument;
9
10use super::{
11	common::{BaseExpReductionOutput, ExpClaim, GKRExpProver, GKRExpProverBuilder, LayerClaim},
12	compositions::IndexedExpComposition,
13	error::Error,
14	provers::{
15		CompositeSumClaimWithMultilinears, DynamicBaseExpProver, ExpProver, StaticExpProver,
16	},
17	witness::BaseExpWitness,
18};
19use crate::{
20	fiat_shamir::Challenger,
21	protocols::sumcheck::{
22		self, BatchSumcheckOutput, CompositeSumClaim, immediate_switchover_heuristic,
23	},
24	transcript::ProverTranscript,
25	witness::MultilinearWitness,
26};
27
28/// Prove a batched GKR exponentiation protocol execution.
29///
30/// The protocol can be batched over multiple instances by grouping consecutive provers over
31/// `eval_points` in internal `LayerClaims` into `GkrExpProvers`. To achieve this, we use
32/// [`crate::composition::IndexComposition`]. Since exponents can have different bit sizes,
33/// resulting in a varying number of layers, we group them starting from the first layer to maximize
34/// the opportunity to share the same evaluation point.
35///
36/// # Requirements
37/// - Witnesses and claims must be in the same order as in [`super::batch_verify`] during proof
38///   verification.
39/// - Witnesses and claims must be sorted in descending order by n_vars.
40/// - Witnesses and claims must be of the same length.
41/// - The `i`th witness must correspond to the `i`th claim.
42///
43/// # Recommendations
44/// - Witnesses and claims should be grouped by evaluation points from the claims.
45#[instrument(skip_all, name = "gkr_exp::batch_prove")]
46pub fn batch_prove<'a, F, P, FDomain, Challenger_, Backend>(
47	evaluation_order: EvaluationOrder,
48	witnesses: impl IntoIterator<Item = BaseExpWitness<'a, P>>,
49	claims: &[ExpClaim<F>],
50	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
51	transcript: &mut ProverTranscript<Challenger_>,
52	backend: &Backend,
53) -> Result<BaseExpReductionOutput<F>, Error>
54where
55	F: ExtensionField<FDomain> + TowerField,
56	FDomain: Field,
57	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
58	Backend: ComputationBackend,
59	Challenger_: Challenger,
60{
61	let witnesses = witnesses.into_iter().collect::<Vec<_>>();
62
63	if witnesses.len() != claims.len() {
64		bail!(Error::MismatchedWitnessClaimLength);
65	}
66
67	let mut layers_claims = Vec::new();
68
69	if witnesses.is_empty() {
70		return Ok(BaseExpReductionOutput { layers_claims });
71	}
72
73	// Check that the witnesses are in descending order by n_vars
74	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
75		bail!(Error::ClaimsOutOfOrder);
76	}
77
78	let mut provers = make_provers(witnesses, claims)?;
79
80	let max_exponent_bit_number = provers
81		.iter()
82		.map(|p| p.exponent_bit_width())
83		.max()
84		.unwrap_or(0);
85
86	for layer_no in 0..max_exponent_bit_number {
87		let _layer_span = tracing::info_span!(
88			"[task] GKR Exp Layer Sumcheck",
89			phase = "exp",
90			perfetto_category = "task.main"
91		)
92		.entered();
93		let gkr_sumcheck_provers = build_layer_gkr_sumcheck_provers(
94			evaluation_order,
95			&mut provers,
96			layer_no,
97			evaluation_domain_factory.clone(),
98			backend,
99		)?;
100
101		let sumcheck_proof_output = sumcheck::batch_prove(gkr_sumcheck_provers, transcript)?;
102
103		let layer_exponent_claims = build_layer_exponent_bit_claims(
104			evaluation_order,
105			&mut provers,
106			sumcheck_proof_output,
107			layer_no,
108		)?;
109
110		layers_claims.push(layer_exponent_claims);
111
112		provers.retain(|prover| !prover.is_last_layer(layer_no));
113	}
114
115	Ok(BaseExpReductionOutput { layers_claims })
116}
117
118type GKRExpProvers<'a, F, P, FDomain, Backend> =
119	Vec<GKRExpProver<'a, FDomain, P, IndexedExpComposition<F>, MultilinearWitness<'a, P>, Backend>>;
120
121/// Groups consecutive provers by their `eval_point` and reduces them to sumcheck provers.
122#[instrument(skip_all, level = "debug")]
123fn build_layer_gkr_sumcheck_provers<'a, P, FDomain, Backend>(
124	evaluation_order: EvaluationOrder,
125	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
126	layer_no: usize,
127	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
128	backend: &'a Backend,
129) -> Result<GKRExpProvers<'a, P::Scalar, P, FDomain, Backend>, Error>
130where
131	FDomain: Field,
132	P: PackedField + PackedExtension<FDomain>,
133	P::Scalar: TowerField + ExtensionField<FDomain>,
134	Backend: ComputationBackend,
135{
136	assert!(!provers.is_empty());
137
138	let mut composite_claims = Vec::new();
139	let mut multilinears = Vec::new();
140
141	let first_eval_point = provers[0].layer_claim_eval_point().to_vec();
142	let mut eval_points = vec![first_eval_point];
143
144	let mut active_index = 0;
145
146	// group provers by evaluation points and build composite sum claims.
147	for i in 0..provers.len() {
148		if provers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
149			let CompositeSumClaimsWithMultilinears {
150				composite_claims: eval_point_composite_claims,
151				multilinears: eval_point_multilinears,
152			} = build_eval_point_claims::<P>(&mut provers[active_index..i], layer_no)?;
153
154			if eval_point_composite_claims.is_empty() {
155				// extract the last point because provers with this point will not participate in
156				// the sumcheck.
157				eval_points.pop();
158			} else {
159				composite_claims.push(eval_point_composite_claims);
160				multilinears.push(eval_point_multilinears);
161			}
162
163			eval_points.push(provers[i].layer_claim_eval_point().to_vec());
164			active_index = i;
165		}
166
167		if i == provers.len() - 1 {
168			let CompositeSumClaimsWithMultilinears {
169				composite_claims: eval_point_composite_claims,
170				multilinears: eval_point_multilinears,
171			} = build_eval_point_claims::<P>(&mut provers[active_index..], layer_no)?;
172
173			if !eval_point_composite_claims.is_empty() {
174				composite_claims.push(eval_point_composite_claims);
175				multilinears.push(eval_point_multilinears);
176			}
177		}
178	}
179
180	izip!(composite_claims, multilinears, eval_points)
181		.map(|(composite_claims, multilinears, eval_point)| {
182			GKRExpProverBuilder::<'a, P, _, Backend>::with_switchover(
183				multilinears,
184				immediate_switchover_heuristic,
185				backend,
186			)?
187			.build(
188				evaluation_order,
189				&eval_point,
190				composite_claims,
191				evaluation_domain_factory.clone(),
192			)
193		})
194		.collect::<Result<Vec<_>, _>>()
195		.map_err(Error::from)
196}
197
198struct CompositeSumClaimsWithMultilinears<'a, P: PackedField> {
199	composite_claims: Vec<CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>>,
200	multilinears: Vec<MultilinearWitness<'a, P>>,
201}
202
203/// Builds composite claims and multilinears for provers that share the same `eval_point` from their
204/// internal [LayerClaim]s.
205fn build_eval_point_claims<'a, P>(
206	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
207	layer_no: usize,
208) -> Result<CompositeSumClaimsWithMultilinears<'a, P>, Error>
209where
210	P: PackedField,
211{
212	let (composite_claims_n_multilinears, n_claims) =
213		provers
214			.iter()
215			.fold((0, 0), |(n_multilinears, n_claims), prover| {
216				let layer_n_multilinears = prover.layer_n_multilinears(layer_no);
217				let layer_n_claims = prover.layer_n_claims(layer_no);
218
219				(n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
220			});
221
222	let mut multilinears = Vec::with_capacity(composite_claims_n_multilinears);
223
224	let mut composite_claims = Vec::with_capacity(n_claims);
225
226	for prover in provers {
227		let multilinears_index = multilinears.len();
228
229		let meta = prover.layer_composite_sum_claim(
230			layer_no,
231			composite_claims_n_multilinears,
232			multilinears_index,
233		)?;
234
235		if let Some(meta) = meta {
236			let CompositeSumClaimWithMultilinears {
237				claim,
238				multilinears: this_layer_multilinears,
239			} = meta;
240
241			composite_claims.push(claim);
242
243			multilinears.extend(this_layer_multilinears);
244		}
245	}
246	Ok(CompositeSumClaimsWithMultilinears {
247		composite_claims,
248		multilinears,
249	})
250}
251
252/// Reduces the sumcheck output to [LayerClaim]s and updates the internal provers [LayerClaim]s for
253/// the next layer.
254fn build_layer_exponent_bit_claims<'a, P>(
255	evaluation_order: EvaluationOrder,
256	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
257	mut sumcheck_output: BatchSumcheckOutput<P::Scalar>,
258	layer_no: usize,
259) -> Result<Vec<LayerClaim<P::Scalar>>, Error>
260where
261	P: PackedField,
262{
263	let mut eval_claims_on_exponent_bit_columns = Vec::new();
264
265	// extract eq_ind_evals
266	for multilinear_evals in &mut sumcheck_output.multilinear_evals {
267		multilinear_evals.pop();
268	}
269
270	let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
271
272	for prover in provers {
273		let this_prover_n_multilinears = prover.layer_n_multilinears(layer_no);
274
275		let this_prover_multilinear_evals = multilinear_evals
276			.by_ref()
277			.take(this_prover_n_multilinears)
278			.collect::<Vec<_>>();
279
280		let exponent_bit_claims = prover.finish_layer(
281			evaluation_order,
282			layer_no,
283			&this_prover_multilinear_evals,
284			&sumcheck_output.challenges,
285		);
286
287		eval_claims_on_exponent_bit_columns.extend(exponent_bit_claims);
288	}
289
290	Ok(eval_claims_on_exponent_bit_columns)
291}
292
293/// Creates a vector of boxed [ExpProver]s from the given witnesses and claims.
294fn make_provers<'a, P>(
295	witnesses: Vec<BaseExpWitness<'a, P>>,
296	claims: &[ExpClaim<P::Scalar>],
297) -> Result<Vec<Box<dyn ExpProver<'a, P> + 'a>>, Error>
298where
299	P: PackedField,
300	P::Scalar: BinaryField,
301{
302	witnesses
303		.into_iter()
304		.zip(claims)
305		.map(|(witness, claim)| {
306			if witness.uses_dynamic_base() {
307				DynamicBaseExpProver::new(witness, claim)
308					.map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
309			} else {
310				StaticExpProver::<'a, P>::new(witness, claim)
311					.map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
312			}
313		})
314		.collect::<Result<Vec<_>, Error>>()
315}