binius_core/protocols/gkr_exp/
batch_verify.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6
7use super::{
8	common::{BaseExpReductionOutput, ExpClaim, LayerClaim},
9	compositions::IndexedExpComposition,
10	error::{Error, VerificationError},
11	verifiers::{DynamicExpVerifier, ExpVerifier, StaticBaseExpVerifier},
12};
13use crate::{
14	fiat_shamir::Challenger,
15	polynomial::MultivariatePoly,
16	protocols::sumcheck::{self, BatchSumcheckOutput, EqIndSumcheckClaim},
17	transcript::VerifierTranscript,
18	transparent::eq_ind::EqIndPartialEval,
19};
20
21/// Verify a batched GKR exponentiation protocol execution.
22///
23/// The protocol can be batched over multiple instances by grouping consecutive verifiers over
24/// eval_points in [ExpClaim] into [EqIndSumcheckClaim]s. To achieve this, we use
25/// [crate::composition::IndexComposition], where eq indicator is always the last element. Since
26/// exponents can have different bit sizes, resulting in a varying number of layers, we group
27/// them starting from the first layer to maximize the opportunity to share the same evaluation
28/// point.
29///
30/// # Requirements
31/// - Claims must be sorted in descending order by `n_vars`.
32pub fn batch_verify<F, Challenger_>(
33	evaluation_order: EvaluationOrder,
34	claims: &[ExpClaim<F>],
35	transcript: &mut VerifierTranscript<Challenger_>,
36) -> Result<BaseExpReductionOutput<F>, Error>
37where
38	F: TowerField,
39	Challenger_: Challenger,
40{
41	let mut layers_claims = Vec::new();
42
43	if claims.is_empty() {
44		return Ok(BaseExpReductionOutput { layers_claims });
45	}
46
47	// Check that the witnesses are in descending order by n_vars
48	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
49		bail!(Error::ClaimsOutOfOrder);
50	}
51
52	let mut verifiers = make_verifiers(claims)?;
53
54	let max_exponent_bit_number = verifiers
55		.iter()
56		.map(|p| p.exponent_bit_width())
57		.max()
58		.unwrap_or(0);
59
60	for layer_no in 0..max_exponent_bit_number {
61		let EqIndSumcheckClaimsWithEvalPoints {
62			eq_ind_sumcheck_claims,
63			eval_points,
64		} = build_layer_eq_ind_sumcheck_claims(&verifiers, layer_no)?;
65
66		let regular_sumcheck_claims =
67			sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
68
69		let sumcheck_verification_output =
70			sumcheck::batch_verify(evaluation_order, &regular_sumcheck_claims, transcript)?;
71
72		let layer_exponent_claims = build_layer_exponent_bit_claims(
73			evaluation_order,
74			&mut verifiers,
75			sumcheck_verification_output,
76			eval_points,
77			layer_no,
78		)?;
79
80		layers_claims.push(layer_exponent_claims);
81
82		verifiers.retain(|verifier| !verifier.is_last_layer(layer_no));
83	}
84
85	Ok(BaseExpReductionOutput { layers_claims })
86}
87
/// Per-layer output of `build_layer_eq_ind_sumcheck_claims`: the eq-indicator sumcheck claims
/// for one layer, together with the evaluation points the verifiers were grouped by.
struct EqIndSumcheckClaimsWithEvalPoints<F: Field> {
	/// One eq-indicator sumcheck claim per group of consecutive verifiers sharing an
	/// evaluation point (groups that produced no claim are absent).
	eq_ind_sumcheck_claims: Vec<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>,
	/// Evaluation points of the groups. They are passed on to
	/// `build_layer_exponent_bit_claims`, which zips them in order with the per-claim sumcheck
	/// evaluations.
	eval_points: Vec<Vec<F>>,
}
92
93/// Groups consecutive verifier by their `eval_point` and reduces them to sumcheck claims.
94fn build_layer_eq_ind_sumcheck_claims<'a, F>(
95	verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
96	layer_no: usize,
97) -> Result<EqIndSumcheckClaimsWithEvalPoints<F>, Error>
98where
99	F: Field,
100{
101	let mut eq_ind_sumcheck_claims = Vec::new();
102
103	let first_eval_point = verifiers[0].layer_claim_eval_point().to_vec();
104	let mut eval_points = vec![first_eval_point];
105
106	let mut active_index = 0;
107
108	// group verifiers by evaluation points and build sumcheck claims.
109	for i in 0..verifiers.len() {
110		if verifiers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
111			let eq_ind_sumcheck_claim =
112				build_eval_point_claims(&verifiers[active_index..i], layer_no)?;
113
114			if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
115				eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
116			} else {
117				// extract the last point because verifiers with this point will not participate in
118				// the sumcheck.
119				eval_points.pop();
120			}
121
122			eval_points.push(verifiers[i].layer_claim_eval_point().to_vec());
123
124			active_index = i;
125		}
126
127		if i == verifiers.len() - 1 {
128			let eq_ind_sumcheck_claim =
129				build_eval_point_claims(&verifiers[active_index..], layer_no)?;
130
131			if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
132				eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
133			}
134		}
135	}
136
137	Ok(EqIndSumcheckClaimsWithEvalPoints {
138		eq_ind_sumcheck_claims,
139		eval_points,
140	})
141}
142
143/// Builds sumcheck claim for verifiers that share the same `eval_point` from their internal
144/// [ExpClaim]s. The batched multilinears are structured as a single concatenated vector of all
145/// multilinears used by the verifiers, with the eq indicator positioned at the end.
146fn build_eval_point_claims<'a, F>(
147	verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
148	layer_no: usize,
149) -> Result<Option<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>, Error>
150where
151	F: Field,
152{
153	let (composite_claims_n_multilinears, n_claims) =
154		verifiers
155			.iter()
156			.fold((0, 0), |(n_multilinears, n_claims), verifier| {
157				let layer_n_multilinears = verifier.layer_n_multilinears(layer_no);
158				let layer_n_claims = verifier.layer_n_claims(layer_no);
159
160				(n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
161			});
162
163	if composite_claims_n_multilinears == 0 {
164		return Ok(None);
165	}
166
167	let n_vars = verifiers[0].layer_claim_eval_point().len();
168
169	let mut multilinears_index = 0;
170
171	let mut composite_sums = Vec::with_capacity(n_claims);
172
173	for verifier in verifiers {
174		let composite_sum_claim = verifier.layer_composite_sum_claim(
175			layer_no,
176			composite_claims_n_multilinears,
177			multilinears_index,
178		)?;
179
180		if let Some(composite_sum_claim) = composite_sum_claim {
181			composite_sums.push(composite_sum_claim);
182		}
183
184		multilinears_index += verifier.layer_n_multilinears(layer_no);
185	}
186
187	Ok(Some(EqIndSumcheckClaim::new(n_vars, composite_claims_n_multilinears, composite_sums)?))
188}
189
190/// Reduces the sumcheck output to [LayerClaim]s and updates the internal verifier [ExpClaim]s for
191/// the next layer.
192pub fn build_layer_exponent_bit_claims<'a, F>(
193	evaluation_order: EvaluationOrder,
194	verifiers: &mut [Box<dyn ExpVerifier<F> + 'a>],
195	mut sumcheck_output: BatchSumcheckOutput<F>,
196	eval_points: Vec<Vec<F>>,
197	layer_no: usize,
198) -> Result<Vec<LayerClaim<F>>, Error>
199where
200	F: TowerField,
201{
202	let mut eval_claims_on_exponent_bit_columns = Vec::new();
203
204	for (multilinear_evals, current_eval_point) in sumcheck_output
205		.multilinear_evals
206		.iter_mut()
207		.zip(eval_points.into_iter())
208	{
209		let n_vars = current_eval_point.len();
210
211		let eval_point = match evaluation_order {
212			EvaluationOrder::LowToHigh => {
213				sumcheck_output.challenges[sumcheck_output.challenges.len() - n_vars..].to_vec()
214			}
215			EvaluationOrder::HighToLow => sumcheck_output.challenges[..n_vars].to_vec(),
216		};
217
218		let expected_eq_ind_eval =
219			EqIndPartialEval::new(current_eval_point).evaluate(&eval_point)?;
220
221		let eq_ind_eval = multilinear_evals
222			.pop()
223			.expect("multilinear_evals contains the evaluation of the equality indicator");
224
225		if expected_eq_ind_eval != eq_ind_eval {
226			return Err(VerificationError::IncorrectEqIndEvaluation.into());
227		}
228	}
229
230	let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
231
232	for verifier in verifiers {
233		let this_verifier_n_multilinears = verifier.layer_n_multilinears(layer_no);
234
235		let this_verifier_multilinear_evals = multilinear_evals
236			.by_ref()
237			.take(this_verifier_n_multilinears)
238			.collect::<Vec<_>>();
239
240		let layer_claims = verifier.finish_layer(
241			evaluation_order,
242			layer_no,
243			&this_verifier_multilinear_evals,
244			&sumcheck_output.challenges,
245		);
246
247		eval_claims_on_exponent_bit_columns.extend(layer_claims);
248	}
249
250	Ok(eval_claims_on_exponent_bit_columns)
251}
252
253/// Creates a vector of boxed [ExpVerifier]s from the given claims.
254fn make_verifiers<'a, F>(claims: &[ExpClaim<F>]) -> Result<Vec<Box<dyn ExpVerifier<F> + 'a>>, Error>
255where
256	F: BinaryField,
257{
258	claims
259		.iter()
260		.map(|claim| {
261			if claim.static_base.is_none() {
262				DynamicExpVerifier::new(claim)
263					.map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
264			} else {
265				StaticBaseExpVerifier::<F>::new(claim)
266					.map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
267			}
268		})
269		.collect::<Result<Vec<_>, Error>>()
270}