binius_core/protocols/gkr_exp/
batch_verify.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6
7use super::{
8	common::{BaseExpReductionOutput, ExpClaim, LayerClaim},
9	compositions::IndexedExpComposition,
10	error::{Error, VerificationError},
11	verifiers::{DynamicExpVerifier, ExpVerifier, StaticBaseExpVerifier},
12};
13use crate::{
14	fiat_shamir::Challenger,
15	polynomial::MultivariatePoly,
16	protocols::sumcheck::{self, BatchSumcheckOutput, EqIndSumcheckClaim},
17	transcript::VerifierTranscript,
18	transparent::eq_ind::EqIndPartialEval,
19};
20
21/// Verify a batched GKR exponentiation protocol execution.
22///
23/// The protocol can be batched over multiple instances by grouping consecutive verifiers over
24/// eval_points in [ExpClaim] into [EqIndSumcheckClaim]s. To achieve this, we use
25/// [crate::composition::IndexComposition], where eq indicator is always the last element. Since
26/// exponents can have different bit sizes, resulting in a varying number of layers, we group
27/// them starting from the first layer to maximize the opportunity to share the same evaluation point.
28///
29/// # Requirements
30/// - Claims must be sorted in descending order by `n_vars`.
31pub fn batch_verify<F, Challenger_>(
32	evaluation_order: EvaluationOrder,
33	claims: &[ExpClaim<F>],
34	transcript: &mut VerifierTranscript<Challenger_>,
35) -> Result<BaseExpReductionOutput<F>, Error>
36where
37	F: TowerField,
38	Challenger_: Challenger,
39{
40	let mut layers_claims = Vec::new();
41
42	if claims.is_empty() {
43		return Ok(BaseExpReductionOutput { layers_claims });
44	}
45
46	// Check that the witnesses are in descending order by n_vars
47	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
48		bail!(Error::ClaimsOutOfOrder);
49	}
50
51	let mut verifiers = make_verifiers(claims)?;
52
53	let max_exponent_bit_number = verifiers
54		.iter()
55		.map(|p| p.exponent_bit_width())
56		.max()
57		.unwrap_or(0);
58
59	for layer_no in 0..max_exponent_bit_number {
60		let EqIndSumcheckClaimsWithEvalPoints {
61			eq_ind_sumcheck_claims,
62			eval_points,
63		} = build_layer_eq_ind_sumcheck_claims(&verifiers, layer_no)?;
64
65		let regular_sumcheck_claims =
66			sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
67
68		let sumcheck_verification_output =
69			sumcheck::batch_verify(evaluation_order, &regular_sumcheck_claims, transcript)?;
70
71		let layer_exponent_claims = build_layer_exponent_bit_claims(
72			evaluation_order,
73			&mut verifiers,
74			sumcheck_verification_output,
75			eval_points,
76			layer_no,
77		)?;
78
79		layers_claims.push(layer_exponent_claims);
80
81		verifiers.retain(|verifier| !verifier.is_last_layer(layer_no));
82	}
83
84	Ok(BaseExpReductionOutput { layers_claims })
85}
86
87struct EqIndSumcheckClaimsWithEvalPoints<F: Field> {
88	eq_ind_sumcheck_claims: Vec<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>,
89	eval_points: Vec<Vec<F>>,
90}
91
92/// Groups consecutive verifier by their `eval_point` and reduces them to sumcheck claims.
93fn build_layer_eq_ind_sumcheck_claims<'a, F>(
94	verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
95	layer_no: usize,
96) -> Result<EqIndSumcheckClaimsWithEvalPoints<F>, Error>
97where
98	F: Field,
99{
100	let mut eq_ind_sumcheck_claims = Vec::new();
101
102	let first_eval_point = verifiers[0].layer_claim_eval_point().to_vec();
103	let mut eval_points = vec![first_eval_point];
104
105	let mut active_index = 0;
106
107	// group verifiers by evaluation points and build sumcheck claims.
108	for i in 0..verifiers.len() {
109		if verifiers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
110			let eq_ind_sumcheck_claim =
111				build_eval_point_claims(&verifiers[active_index..i], layer_no)?;
112
113			if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
114				eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
115			} else {
116				// extract the last point because verifiers with this point will not participate in the sumcheck.
117				eval_points.pop();
118			}
119
120			eval_points.push(verifiers[i].layer_claim_eval_point().to_vec());
121
122			active_index = i;
123		}
124
125		if i == verifiers.len() - 1 {
126			let eq_ind_sumcheck_claim =
127				build_eval_point_claims(&verifiers[active_index..], layer_no)?;
128
129			if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
130				eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
131			}
132		}
133	}
134
135	Ok(EqIndSumcheckClaimsWithEvalPoints {
136		eq_ind_sumcheck_claims,
137		eval_points,
138	})
139}
140
141/// Builds sumcheck claim for verifiers that share the same `eval_point` from their internal
142/// [ExpClaim]s. The batched multilinears are structured as a single concatenated vector of all
143/// multilinears used by the verifiers, with the eq indicator positioned at the end.
144fn build_eval_point_claims<'a, F>(
145	verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
146	layer_no: usize,
147) -> Result<Option<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>, Error>
148where
149	F: Field,
150{
151	let (composite_claims_n_multilinears, n_claims) =
152		verifiers
153			.iter()
154			.fold((0, 0), |(n_multilinears, n_claims), verifier| {
155				let layer_n_multilinears = verifier.layer_n_multilinears(layer_no);
156				let layer_n_claims = verifier.layer_n_claims(layer_no);
157
158				(n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
159			});
160
161	if composite_claims_n_multilinears == 0 {
162		return Ok(None);
163	}
164
165	let n_vars = verifiers[0].layer_claim_eval_point().len();
166
167	let mut multilinears_index = 0;
168
169	let mut composite_sums = Vec::with_capacity(n_claims);
170
171	for verifier in verifiers {
172		let composite_sum_claim = verifier.layer_composite_sum_claim(
173			layer_no,
174			composite_claims_n_multilinears,
175			multilinears_index,
176		)?;
177
178		if let Some(composite_sum_claim) = composite_sum_claim {
179			composite_sums.push(composite_sum_claim);
180		}
181
182		multilinears_index += verifier.layer_n_multilinears(layer_no);
183	}
184
185	Ok(Some(EqIndSumcheckClaim::new(n_vars, composite_claims_n_multilinears, composite_sums)?))
186}
187
188/// Reduces the sumcheck output to [LayerClaim]s and updates the internal verifier [ExpClaim]s for the next layer.
189pub fn build_layer_exponent_bit_claims<'a, F>(
190	evaluation_order: EvaluationOrder,
191	verifiers: &mut [Box<dyn ExpVerifier<F> + 'a>],
192	mut sumcheck_output: BatchSumcheckOutput<F>,
193	eval_points: Vec<Vec<F>>,
194	layer_no: usize,
195) -> Result<Vec<LayerClaim<F>>, Error>
196where
197	F: TowerField,
198{
199	let mut eval_claims_on_exponent_bit_columns = Vec::new();
200
201	for (multilinear_evals, current_eval_point) in sumcheck_output
202		.multilinear_evals
203		.iter_mut()
204		.zip(eval_points.into_iter())
205	{
206		let n_vars = current_eval_point.len();
207
208		let eval_point = match evaluation_order {
209			EvaluationOrder::LowToHigh => {
210				sumcheck_output.challenges[sumcheck_output.challenges.len() - n_vars..].to_vec()
211			}
212			EvaluationOrder::HighToLow => sumcheck_output.challenges[..n_vars].to_vec(),
213		};
214
215		let expected_eq_ind_eval =
216			EqIndPartialEval::new(current_eval_point).evaluate(&eval_point)?;
217
218		let eq_ind_eval = multilinear_evals
219			.pop()
220			.expect("multilinear_evals contains the evaluation of the equality indicator");
221
222		if expected_eq_ind_eval != eq_ind_eval {
223			return Err(VerificationError::IncorrectEqIndEvaluation.into());
224		}
225	}
226
227	let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
228
229	for verifier in verifiers {
230		let this_verifier_n_multilinears = verifier.layer_n_multilinears(layer_no);
231
232		let this_verifier_multilinear_evals = multilinear_evals
233			.by_ref()
234			.take(this_verifier_n_multilinears)
235			.collect::<Vec<_>>();
236
237		let layer_claims = verifier.finish_layer(
238			evaluation_order,
239			layer_no,
240			&this_verifier_multilinear_evals,
241			&sumcheck_output.challenges,
242		);
243
244		eval_claims_on_exponent_bit_columns.extend(layer_claims);
245	}
246
247	Ok(eval_claims_on_exponent_bit_columns)
248}
249
250/// Creates a vector of boxed [ExpVerifier]s from the given claims.
251fn make_verifiers<'a, F>(claims: &[ExpClaim<F>]) -> Result<Vec<Box<dyn ExpVerifier<F> + 'a>>, Error>
252where
253	F: BinaryField,
254{
255	claims
256		.iter()
257		.map(|claim| {
258			if claim.static_base.is_none() {
259				DynamicExpVerifier::new(claim)
260					.map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
261			} else {
262				StaticBaseExpVerifier::<F>::new(claim)
263					.map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
264			}
265		})
266		.collect::<Result<Vec<_>, Error>>()
267}