binius_core/protocols/gkr_exp/
batch_verify.rs1use binius_field::{BinaryField, ExtensionField, Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6
7use super::{
8 common::{BaseExpReductionOutput, ExpClaim, LayerClaim},
9 compositions::VerifierExpComposition,
10 error::{Error, VerificationError},
11 verifiers::{ExpDynamicVerifier, ExpVerifier, GeneratorExpVerifier},
12};
13use crate::{
14 fiat_shamir::Challenger,
15 polynomial::MultivariatePoly,
16 protocols::sumcheck::{self, BatchSumcheckOutput, SumcheckClaim},
17 transcript::VerifierTranscript,
18 transparent::eq_ind::EqIndPartialEval,
19};
20
21pub fn batch_verify<FBase, F, Challenger_>(
32 evaluation_order: EvaluationOrder,
33 claims: &[ExpClaim<F>],
34 transcript: &mut VerifierTranscript<Challenger_>,
35) -> Result<BaseExpReductionOutput<F>, Error>
36where
37 FBase: TowerField,
38 F: TowerField + ExtensionField<FBase>,
39 Challenger_: Challenger,
40{
41 let mut layers_claims = Vec::new();
42
43 if claims.is_empty() {
44 return Ok(BaseExpReductionOutput { layers_claims });
45 }
46
47 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
49 bail!(Error::ClaimsOutOfOrder);
50 }
51
52 let mut verifiers = make_verifiers::<_, FBase>(claims)?;
53
54 let max_exponent_bit_number = verifiers
55 .first()
56 .map(|verifier| verifier.exponent_bit_width())
57 .unwrap_or(0);
58
59 for layer_no in 0..max_exponent_bit_number {
60 let SumcheckClaimsWithEvalPoints {
61 sumcheck_claims,
62 eval_points,
63 } = build_layer_gkr_sumcheck_claims(&verifiers, layer_no)?;
64
65 let sumcheck_verification_output =
66 sumcheck::batch_verify(evaluation_order, &sumcheck_claims, transcript)?;
67
68 let layer_exponent_claims = build_layer_exponent_bit_claims(
69 evaluation_order,
70 &mut verifiers,
71 sumcheck_verification_output,
72 eval_points,
73 layer_no,
74 )?;
75
76 layers_claims.push(layer_exponent_claims);
77
78 verifiers.retain(|verifier| !verifier.is_last_layer(layer_no));
79 }
80
81 Ok(BaseExpReductionOutput { layers_claims })
82}
83
84struct SumcheckClaimsWithEvalPoints<F: Field> {
85 sumcheck_claims: Vec<SumcheckClaim<F, VerifierExpComposition<F>>>,
86 eval_points: Vec<Vec<F>>,
87}
88
89fn build_layer_gkr_sumcheck_claims<'a, F>(
91 verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
92 layer_no: usize,
93) -> Result<SumcheckClaimsWithEvalPoints<F>, Error>
94where
95 F: Field,
96{
97 let mut sumcheck_claims = Vec::new();
98
99 let first_eval_point = verifiers[0].layer_claim_eval_point().to_vec();
100 let mut eval_points = vec![first_eval_point];
101
102 let mut active_index = 0;
103
104 for i in 0..verifiers.len() {
106 if verifiers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
107 let sumcheck_claim = build_eval_point_claims(&verifiers[active_index..i], layer_no)?;
108
109 if let Some(sumcheck_claim) = sumcheck_claim {
110 sumcheck_claims.push(sumcheck_claim);
111 } else {
112 eval_points.pop();
114 }
115
116 eval_points.push(verifiers[i].layer_claim_eval_point().to_vec());
117
118 active_index = i;
119 }
120
121 if i == verifiers.len() - 1 {
122 let sumcheck_claim = build_eval_point_claims(&verifiers[active_index..], layer_no)?;
123
124 if let Some(sumcheck_claim) = sumcheck_claim {
125 sumcheck_claims.push(sumcheck_claim);
126 }
127 }
128 }
129
130 Ok(SumcheckClaimsWithEvalPoints {
131 sumcheck_claims,
132 eval_points,
133 })
134}
135
136fn build_eval_point_claims<'a, F>(
140 verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
141 layer_no: usize,
142) -> Result<Option<SumcheckClaim<F, VerifierExpComposition<F>>>, Error>
143where
144 F: Field,
145{
146 let (mut composite_claims_n_multilinears, n_claims) =
147 verifiers
148 .iter()
149 .fold((0, 0), |(n_multilinears, n_claims), verifier| {
150 let layer_n_multilinears = verifier.layer_n_multilinears(layer_no);
151 let layer_n_claims = verifier.layer_n_claims(layer_no);
152
153 (n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
154 });
155
156 if composite_claims_n_multilinears == 0 {
157 return Ok(None);
158 }
159
160 let n_vars = verifiers[0].layer_claim_eval_point().len();
161
162 let eq_ind_index = composite_claims_n_multilinears;
163
164 composite_claims_n_multilinears += 1;
166
167 let mut multilinears_index = 0;
168
169 let mut composite_sums = Vec::with_capacity(n_claims);
170
171 for verifier in verifiers {
172 let composite_sum_claim = verifier.layer_composite_sum_claim(
173 layer_no,
174 composite_claims_n_multilinears,
175 multilinears_index,
176 eq_ind_index,
177 )?;
178
179 if let Some(composite_sum_claim) = composite_sum_claim {
180 composite_sums.push(composite_sum_claim);
181 }
182
183 multilinears_index += verifier.layer_n_multilinears(layer_no);
184 }
185
186 SumcheckClaim::new(n_vars, composite_claims_n_multilinears, composite_sums)
187 .map(Some)
188 .map_err(Error::from)
189}
190
191pub fn build_layer_exponent_bit_claims<'a, F>(
193 evaluation_order: EvaluationOrder,
194 verifiers: &mut [Box<dyn ExpVerifier<F> + 'a>],
195 mut sumcheck_output: BatchSumcheckOutput<F>,
196 eval_points: Vec<Vec<F>>,
197 layer_no: usize,
198) -> Result<Vec<LayerClaim<F>>, Error>
199where
200 F: TowerField,
201{
202 let mut eval_claims_on_exponent_bit_columns = Vec::new();
203
204 for (multilinear_evals, current_eval_point) in sumcheck_output
205 .multilinear_evals
206 .iter_mut()
207 .zip(eval_points.into_iter())
208 {
209 let n_vars = current_eval_point.len();
210
211 let eval_point = match evaluation_order {
212 EvaluationOrder::LowToHigh => {
213 sumcheck_output.challenges[sumcheck_output.challenges.len() - n_vars..].to_vec()
214 }
215 EvaluationOrder::HighToLow => sumcheck_output.challenges[..n_vars].to_vec(),
216 };
217
218 let expected_eq_ind_eval =
219 EqIndPartialEval::new(current_eval_point).evaluate(&eval_point)?;
220
221 let eq_ind_eval = multilinear_evals
222 .pop()
223 .expect("multilinear_evals contains the evaluation of the equality indicator");
224
225 if expected_eq_ind_eval != eq_ind_eval {
226 return Err(VerificationError::IncorrectEqIndEvaluation.into());
227 }
228 }
229
230 let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
231
232 for verifier in verifiers {
233 let this_verifier_n_multilinears = verifier.layer_n_multilinears(layer_no);
234
235 let this_verifier_multilinear_evals = multilinear_evals
236 .by_ref()
237 .take(this_verifier_n_multilinears)
238 .collect::<Vec<_>>();
239
240 let layer_claims = verifier.finish_layer(
241 evaluation_order,
242 layer_no,
243 &this_verifier_multilinear_evals,
244 &sumcheck_output.challenges,
245 );
246
247 eval_claims_on_exponent_bit_columns.extend(layer_claims);
248 }
249
250 Ok(eval_claims_on_exponent_bit_columns)
251}
252
253fn make_verifiers<'a, F, FBase>(
255 claims: &[ExpClaim<F>],
256) -> Result<Vec<Box<dyn ExpVerifier<F> + 'a>>, Error>
257where
258 FBase: BinaryField,
259 F: BinaryField + ExtensionField<FBase>,
260{
261 claims
262 .iter()
263 .map(|claim| {
264 if claim.uses_dynamic_base {
265 ExpDynamicVerifier::new(claim)
266 .map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
267 } else {
268 GeneratorExpVerifier::<F, FBase>::new(claim)
269 .map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
270 }
271 })
272 .collect::<Result<Vec<_>, Error>>()
273}