binius_core/protocols/gkr_exp/
batch_verify.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, ExtensionField, Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6
7use super::{
8	common::{BaseExpReductionOutput, ExpClaim, LayerClaim},
9	compositions::VerifierExpComposition,
10	error::{Error, VerificationError},
11	verifiers::{ExpDynamicVerifier, ExpVerifier, GeneratorExpVerifier},
12};
13use crate::{
14	fiat_shamir::Challenger,
15	polynomial::MultivariatePoly,
16	protocols::sumcheck::{self, BatchSumcheckOutput, SumcheckClaim},
17	transcript::VerifierTranscript,
18	transparent::eq_ind::EqIndPartialEval,
19};
20
21/// Verify a batched GKR exponentiation protocol execution.
22///
23/// The protocol can be batched over multiple instances by grouping consecutive verifiers over
24/// eval_points in [ExpClaim] into [SumcheckClaim]s. To achieve this, we use
25/// [crate::composition::IndexComposition], where eq indicator is always the last element. Since
26/// exponents can have different bit sizes, resulting in a varying number of layers, we group
27/// them starting from the first layer to maximize the opportunity to share the same evaluation point.
28///
29/// # Requirements
30/// - Claims must be sorted in descending order by `n_vars`.
31pub fn batch_verify<FBase, F, Challenger_>(
32	evaluation_order: EvaluationOrder,
33	claims: &[ExpClaim<F>],
34	transcript: &mut VerifierTranscript<Challenger_>,
35) -> Result<BaseExpReductionOutput<F>, Error>
36where
37	FBase: TowerField,
38	F: TowerField + ExtensionField<FBase>,
39	Challenger_: Challenger,
40{
41	let mut layers_claims = Vec::new();
42
43	if claims.is_empty() {
44		return Ok(BaseExpReductionOutput { layers_claims });
45	}
46
47	// Check that the witnesses are in descending order by n_vars
48	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
49		bail!(Error::ClaimsOutOfOrder);
50	}
51
52	let mut verifiers = make_verifiers::<_, FBase>(claims)?;
53
54	let max_exponent_bit_number = verifiers
55		.first()
56		.map(|verifier| verifier.exponent_bit_width())
57		.unwrap_or(0);
58
59	for layer_no in 0..max_exponent_bit_number {
60		let SumcheckClaimsWithEvalPoints {
61			sumcheck_claims,
62			eval_points,
63		} = build_layer_gkr_sumcheck_claims(&verifiers, layer_no)?;
64
65		let sumcheck_verification_output =
66			sumcheck::batch_verify(evaluation_order, &sumcheck_claims, transcript)?;
67
68		let layer_exponent_claims = build_layer_exponent_bit_claims(
69			evaluation_order,
70			&mut verifiers,
71			sumcheck_verification_output,
72			eval_points,
73			layer_no,
74		)?;
75
76		layers_claims.push(layer_exponent_claims);
77
78		verifiers.retain(|verifier| !verifier.is_last_layer(layer_no));
79	}
80
81	Ok(BaseExpReductionOutput { layers_claims })
82}
83
84struct SumcheckClaimsWithEvalPoints<F: Field> {
85	sumcheck_claims: Vec<SumcheckClaim<F, VerifierExpComposition<F>>>,
86	eval_points: Vec<Vec<F>>,
87}
88
89/// Groups consecutive verifier by their `eval_point` and reduces them to sumcheck claims.
90fn build_layer_gkr_sumcheck_claims<'a, F>(
91	verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
92	layer_no: usize,
93) -> Result<SumcheckClaimsWithEvalPoints<F>, Error>
94where
95	F: Field,
96{
97	let mut sumcheck_claims = Vec::new();
98
99	let first_eval_point = verifiers[0].layer_claim_eval_point().to_vec();
100	let mut eval_points = vec![first_eval_point];
101
102	let mut active_index = 0;
103
104	// group verifiers by evaluation points and build sumcheck claims.
105	for i in 0..verifiers.len() {
106		if verifiers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
107			let sumcheck_claim = build_eval_point_claims(&verifiers[active_index..i], layer_no)?;
108
109			if let Some(sumcheck_claim) = sumcheck_claim {
110				sumcheck_claims.push(sumcheck_claim);
111			} else {
112				// extract the last point because verifiers with this point will not participate in the sumcheck.
113				eval_points.pop();
114			}
115
116			eval_points.push(verifiers[i].layer_claim_eval_point().to_vec());
117
118			active_index = i;
119		}
120
121		if i == verifiers.len() - 1 {
122			let sumcheck_claim = build_eval_point_claims(&verifiers[active_index..], layer_no)?;
123
124			if let Some(sumcheck_claim) = sumcheck_claim {
125				sumcheck_claims.push(sumcheck_claim);
126			}
127		}
128	}
129
130	Ok(SumcheckClaimsWithEvalPoints {
131		sumcheck_claims,
132		eval_points,
133	})
134}
135
136/// Builds sumcheck claim for verifiers that share the same `eval_point` from their internal
137/// [ExpClaim]s. The batched multilinears are structured as a single concatenated vector of all
138/// multilinears used by the verifiers, with the eq indicator positioned at the end.
139fn build_eval_point_claims<'a, F>(
140	verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
141	layer_no: usize,
142) -> Result<Option<SumcheckClaim<F, VerifierExpComposition<F>>>, Error>
143where
144	F: Field,
145{
146	let (mut composite_claims_n_multilinears, n_claims) =
147		verifiers
148			.iter()
149			.fold((0, 0), |(n_multilinears, n_claims), verifier| {
150				let layer_n_multilinears = verifier.layer_n_multilinears(layer_no);
151				let layer_n_claims = verifier.layer_n_claims(layer_no);
152
153				(n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
154			});
155
156	if composite_claims_n_multilinears == 0 {
157		return Ok(None);
158	}
159
160	let n_vars = verifiers[0].layer_claim_eval_point().len();
161
162	let eq_ind_index = composite_claims_n_multilinears;
163
164	// add eq_ind
165	composite_claims_n_multilinears += 1;
166
167	let mut multilinears_index = 0;
168
169	let mut composite_sums = Vec::with_capacity(n_claims);
170
171	for verifier in verifiers {
172		let composite_sum_claim = verifier.layer_composite_sum_claim(
173			layer_no,
174			composite_claims_n_multilinears,
175			multilinears_index,
176			eq_ind_index,
177		)?;
178
179		if let Some(composite_sum_claim) = composite_sum_claim {
180			composite_sums.push(composite_sum_claim);
181		}
182
183		multilinears_index += verifier.layer_n_multilinears(layer_no);
184	}
185
186	SumcheckClaim::new(n_vars, composite_claims_n_multilinears, composite_sums)
187		.map(Some)
188		.map_err(Error::from)
189}
190
191/// Reduces the sumcheck output to [LayerClaim]s and updates the internal verifier [ExpClaim]s for the next layer.
192pub fn build_layer_exponent_bit_claims<'a, F>(
193	evaluation_order: EvaluationOrder,
194	verifiers: &mut [Box<dyn ExpVerifier<F> + 'a>],
195	mut sumcheck_output: BatchSumcheckOutput<F>,
196	eval_points: Vec<Vec<F>>,
197	layer_no: usize,
198) -> Result<Vec<LayerClaim<F>>, Error>
199where
200	F: TowerField,
201{
202	let mut eval_claims_on_exponent_bit_columns = Vec::new();
203
204	for (multilinear_evals, current_eval_point) in sumcheck_output
205		.multilinear_evals
206		.iter_mut()
207		.zip(eval_points.into_iter())
208	{
209		let n_vars = current_eval_point.len();
210
211		let eval_point = match evaluation_order {
212			EvaluationOrder::LowToHigh => {
213				sumcheck_output.challenges[sumcheck_output.challenges.len() - n_vars..].to_vec()
214			}
215			EvaluationOrder::HighToLow => sumcheck_output.challenges[..n_vars].to_vec(),
216		};
217
218		let expected_eq_ind_eval =
219			EqIndPartialEval::new(current_eval_point).evaluate(&eval_point)?;
220
221		let eq_ind_eval = multilinear_evals
222			.pop()
223			.expect("multilinear_evals contains the evaluation of the equality indicator");
224
225		if expected_eq_ind_eval != eq_ind_eval {
226			return Err(VerificationError::IncorrectEqIndEvaluation.into());
227		}
228	}
229
230	let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
231
232	for verifier in verifiers {
233		let this_verifier_n_multilinears = verifier.layer_n_multilinears(layer_no);
234
235		let this_verifier_multilinear_evals = multilinear_evals
236			.by_ref()
237			.take(this_verifier_n_multilinears)
238			.collect::<Vec<_>>();
239
240		let layer_claims = verifier.finish_layer(
241			evaluation_order,
242			layer_no,
243			&this_verifier_multilinear_evals,
244			&sumcheck_output.challenges,
245		);
246
247		eval_claims_on_exponent_bit_columns.extend(layer_claims);
248	}
249
250	Ok(eval_claims_on_exponent_bit_columns)
251}
252
253/// Creates a vector of boxed [ExpVerifier]s from the given claims.
254fn make_verifiers<'a, F, FBase>(
255	claims: &[ExpClaim<F>],
256) -> Result<Vec<Box<dyn ExpVerifier<F> + 'a>>, Error>
257where
258	FBase: BinaryField,
259	F: BinaryField + ExtensionField<FBase>,
260{
261	claims
262		.iter()
263		.map(|claim| {
264			if claim.uses_dynamic_base {
265				ExpDynamicVerifier::new(claim)
266					.map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
267			} else {
268				GeneratorExpVerifier::<F, FBase>::new(claim)
269					.map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
270			}
271		})
272		.collect::<Result<Vec<_>, Error>>()
273}