binius_core/protocols/gkr_exp/batch_prove.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{
4	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
5	TowerField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::{EvaluationDomainFactory, EvaluationOrder};
9use binius_utils::{bail, sorting::is_sorted_ascending};
10use itertools::izip;
11use tracing::instrument;
12
13use super::{
14	common::{BaseExpReductionOutput, ExpClaim, GKRExpProver, LayerClaim},
15	compositions::ProverExpComposition,
16	error::Error,
17	provers::{
18		CompositeSumClaimWithMultilinears, DynamicBaseExpProver, ExpProver, GeneratorExpProver,
19	},
20	witness::BaseExpWitness,
21};
22use crate::{
23	fiat_shamir::Challenger,
24	protocols::sumcheck::{self, BatchSumcheckOutput, CompositeSumClaim},
25	transcript::ProverTranscript,
26	witness::MultilinearWitness,
27};
28
29/// Prove a batched GKR exponentiation protocol execution.
30///
31/// The protocol can be batched over multiple instances by grouping consecutive provers over
32/// `eval_points` in internal `LayerClaims` into `GkrExpProvers`. To achieve this, we use
33/// [`crate::composition::IndexComposition`]. Since exponents can have different bit sizes, resulting
34/// in a varying number of layers, we group them starting from the first layer to maximize the
35/// opportunity to share the same evaluation point.
36///
37/// # Requirements
38/// - Witnesses and claims must be in the same order as in [`super::batch_verify`] during proof verification.
39/// - Witnesses and claims must be sorted in descending order by n_vars.
40/// - Witnesses and claims must be of the same length.
41/// - The `i`th witness must correspond to the `i`th claim.
42///
43/// # Recommendations
44/// - Witnesses and claims should be grouped by evaluation points from the claims.
45pub fn batch_prove<'a, FBase, F, P, FDomain, Challenger_, Backend>(
46	evaluation_order: EvaluationOrder,
47	witnesses: impl IntoIterator<Item = BaseExpWitness<'a, P, FBase>>,
48	claims: &[ExpClaim<F>],
49	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
50	transcript: &mut ProverTranscript<Challenger_>,
51	backend: &Backend,
52) -> Result<BaseExpReductionOutput<F>, Error>
53where
54	F: ExtensionField<FBase> + ExtensionField<FDomain> + TowerField,
55	FDomain: Field,
56	FBase: TowerField + ExtensionField<FDomain>,
57	P: PackedFieldIndexable<Scalar = F>
58		+ PackedExtension<F, PackedSubfield = P>
59		+ PackedExtension<FDomain>,
60	Backend: ComputationBackend,
61	Challenger_: Challenger,
62{
63	let witnesses = witnesses.into_iter().collect::<Vec<_>>();
64
65	if witnesses.len() != claims.len() {
66		bail!(Error::MismatchedWitnessClaimLength);
67	}
68
69	let mut layers_claims = Vec::new();
70
71	if witnesses.is_empty() {
72		return Ok(BaseExpReductionOutput { layers_claims });
73	}
74
75	// Check that the witnesses are in descending order by n_vars
76	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
77		bail!(Error::ClaimsOutOfOrder);
78	}
79
80	let mut provers = make_provers::<_, FBase>(witnesses, claims)?;
81
82	let max_exponent_bit_number = provers.first().map(|p| p.exponent_bit_width()).unwrap_or(0);
83
84	for layer_no in 0..max_exponent_bit_number {
85		let gkr_sumcheck_provers = build_layer_gkr_sumcheck_provers(
86			evaluation_order,
87			&mut provers,
88			layer_no,
89			evaluation_domain_factory.clone(),
90			backend,
91		)?;
92
93		let sumcheck_proof_output = sumcheck::batch_prove(gkr_sumcheck_provers, transcript)?;
94
95		let layer_exponent_claims = build_layer_exponent_bit_claims(
96			evaluation_order,
97			&mut provers,
98			sumcheck_proof_output,
99			layer_no,
100		)?;
101
102		layers_claims.push(layer_exponent_claims);
103
104		provers.retain(|prover| !prover.is_last_layer(layer_no));
105	}
106
107	Ok(BaseExpReductionOutput { layers_claims })
108}
109
110type GKRExpProvers<'a, F, P, FDomain, Backend> =
111	Vec<GKRExpProver<'a, FDomain, P, ProverExpComposition<F>, MultilinearWitness<'a, P>, Backend>>;
112
113/// Groups consecutive provers by their `eval_point` and reduces them to sumcheck provers.
114#[instrument(skip_all, level = "debug")]
115fn build_layer_gkr_sumcheck_provers<'a, P, FDomain, Backend>(
116	evaluation_order: EvaluationOrder,
117	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
118	layer_no: usize,
119	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
120	backend: &'a Backend,
121) -> Result<GKRExpProvers<'a, P::Scalar, P, FDomain, Backend>, Error>
122where
123	FDomain: Field,
124	P: PackedFieldIndexable + PackedExtension<FDomain>,
125	P::Scalar: TowerField + ExtensionField<FDomain>,
126	Backend: ComputationBackend,
127{
128	assert!(!provers.is_empty());
129
130	let mut composite_claims = Vec::new();
131	let mut multilinears = Vec::new();
132
133	let first_eval_point = provers[0].layer_claim_eval_point().to_vec();
134	let mut eval_points = vec![first_eval_point];
135
136	let mut active_index = 0;
137
138	// group provers by evaluation points and build composite sum claims.
139	for i in 0..provers.len() {
140		if provers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
141			let CompositeSumClaimsWithMultilinears {
142				composite_claims: eval_point_composite_claims,
143				multilinears: eval_point_multilinears,
144			} = build_eval_point_claims::<P>(&mut provers[active_index..i], layer_no)?;
145
146			if eval_point_composite_claims.is_empty() {
147				// extract the last point because provers with this point will not participate in the sumcheck.
148				eval_points.pop();
149			} else {
150				composite_claims.push(eval_point_composite_claims);
151				multilinears.push(eval_point_multilinears);
152			}
153
154			eval_points.push(provers[i].layer_claim_eval_point().to_vec());
155			active_index = i;
156		}
157
158		if i == provers.len() - 1 {
159			let CompositeSumClaimsWithMultilinears {
160				composite_claims: eval_point_composite_claims,
161				multilinears: eval_point_multilinears,
162			} = build_eval_point_claims::<P>(&mut provers[active_index..], layer_no)?;
163
164			if !eval_point_composite_claims.is_empty() {
165				composite_claims.push(eval_point_composite_claims);
166				multilinears.push(eval_point_multilinears);
167			}
168		}
169	}
170
171	izip!(composite_claims, multilinears, eval_points)
172		.map(|(composite_claims, multilinears, eval_point)| {
173			GKRExpProver::<'a, FDomain, P, _, _, Backend>::new(
174				evaluation_order,
175				multilinears,
176				None,
177				composite_claims,
178				evaluation_domain_factory.clone(),
179				&eval_point,
180				backend,
181			)
182		})
183		.collect::<Result<Vec<_>, _>>()
184		.map_err(Error::from)
185}
186
187struct CompositeSumClaimsWithMultilinears<'a, P: PackedField> {
188	composite_claims: Vec<CompositeSumClaim<P::Scalar, ProverExpComposition<P::Scalar>>>,
189	multilinears: Vec<MultilinearWitness<'a, P>>,
190}
191
192/// Builds composite claims and multilinears for provers that share the same `eval_point` from their internal [LayerClaim]s.
193fn build_eval_point_claims<'a, P>(
194	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
195	layer_no: usize,
196) -> Result<CompositeSumClaimsWithMultilinears<'a, P>, Error>
197where
198	P: PackedField,
199{
200	let (composite_claims_n_multilinears, n_claims) =
201		provers
202			.iter()
203			.fold((0, 0), |(n_multilinears, n_claims), prover| {
204				let layer_n_multilinears = prover.layer_n_multilinears(layer_no);
205				let layer_n_claims = prover.layer_n_claims(layer_no);
206
207				(n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
208			});
209
210	let mut multilinears = Vec::with_capacity(composite_claims_n_multilinears);
211
212	let mut composite_claims = Vec::with_capacity(n_claims);
213
214	for prover in provers {
215		let multilinears_index = multilinears.len();
216
217		let meta = prover.layer_composite_sum_claim(
218			layer_no,
219			composite_claims_n_multilinears,
220			multilinears_index,
221		)?;
222
223		if let Some(meta) = meta {
224			let CompositeSumClaimWithMultilinears {
225				claim,
226				multilinears: this_layer_multilinears,
227			} = meta;
228
229			composite_claims.push(claim);
230
231			multilinears.extend(this_layer_multilinears);
232		}
233	}
234	Ok(CompositeSumClaimsWithMultilinears {
235		composite_claims,
236		multilinears,
237	})
238}
239
240/// Reduces the sumcheck output to [LayerClaim]s and updates the internal provers [LayerClaim]s for the next layer.
241fn build_layer_exponent_bit_claims<'a, P>(
242	evaluation_order: EvaluationOrder,
243	provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
244	mut sumcheck_output: BatchSumcheckOutput<P::Scalar>,
245	layer_no: usize,
246) -> Result<Vec<LayerClaim<P::Scalar>>, Error>
247where
248	P: PackedField,
249{
250	let mut eval_claims_on_exponent_bit_columns = Vec::new();
251
252	// extract eq_ind_evals
253	for multilinear_evals in &mut sumcheck_output.multilinear_evals {
254		multilinear_evals.pop();
255	}
256
257	let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
258
259	for prover in provers {
260		let this_prover_n_multilinears = prover.layer_n_multilinears(layer_no);
261
262		let this_prover_multilinear_evals = multilinear_evals
263			.by_ref()
264			.take(this_prover_n_multilinears)
265			.collect::<Vec<_>>();
266
267		let exponent_bit_claims = prover.finish_layer(
268			evaluation_order,
269			layer_no,
270			&this_prover_multilinear_evals,
271			&sumcheck_output.challenges,
272		);
273
274		eval_claims_on_exponent_bit_columns.extend(exponent_bit_claims);
275	}
276
277	Ok(eval_claims_on_exponent_bit_columns)
278}
279
280/// Creates a vector of boxed [ExpProver]s from the given witnesses and claims.
281fn make_provers<'a, P, FBase>(
282	witnesses: Vec<BaseExpWitness<'a, P, FBase>>,
283	claims: &[ExpClaim<P::Scalar>],
284) -> Result<Vec<Box<dyn ExpProver<'a, P> + 'a>>, Error>
285where
286	P: PackedField,
287	FBase: BinaryField,
288	P::Scalar: BinaryField + ExtensionField<FBase>,
289{
290	witnesses
291		.into_iter()
292		.zip(claims)
293		.map(|(witness, claim)| {
294			if witness.uses_dynamic_base() {
295				DynamicBaseExpProver::new(witness, claim)
296					.map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
297			} else {
298				GeneratorExpProver::<'a, P, FBase>::new(witness, claim)
299					.map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
300			}
301		})
302		.collect::<Result<Vec<_>, Error>>()
303}