binius_core/protocols/gkr_exp/
batch_verify.rs1use binius_field::{BinaryField, Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6
7use super::{
8 common::{BaseExpReductionOutput, ExpClaim, LayerClaim},
9 compositions::IndexedExpComposition,
10 error::{Error, VerificationError},
11 verifiers::{DynamicExpVerifier, ExpVerifier, StaticBaseExpVerifier},
12};
13use crate::{
14 fiat_shamir::Challenger,
15 polynomial::MultivariatePoly,
16 protocols::sumcheck::{self, BatchSumcheckOutput, EqIndSumcheckClaim},
17 transcript::VerifierTranscript,
18 transparent::eq_ind::EqIndPartialEval,
19};
20
21pub fn batch_verify<F, Challenger_>(
32 evaluation_order: EvaluationOrder,
33 claims: &[ExpClaim<F>],
34 transcript: &mut VerifierTranscript<Challenger_>,
35) -> Result<BaseExpReductionOutput<F>, Error>
36where
37 F: TowerField,
38 Challenger_: Challenger,
39{
40 let mut layers_claims = Vec::new();
41
42 if claims.is_empty() {
43 return Ok(BaseExpReductionOutput { layers_claims });
44 }
45
46 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
48 bail!(Error::ClaimsOutOfOrder);
49 }
50
51 let mut verifiers = make_verifiers(claims)?;
52
53 let max_exponent_bit_number = verifiers
54 .iter()
55 .map(|p| p.exponent_bit_width())
56 .max()
57 .unwrap_or(0);
58
59 for layer_no in 0..max_exponent_bit_number {
60 let EqIndSumcheckClaimsWithEvalPoints {
61 eq_ind_sumcheck_claims,
62 eval_points,
63 } = build_layer_eq_ind_sumcheck_claims(&verifiers, layer_no)?;
64
65 let regular_sumcheck_claims =
66 sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
67
68 let sumcheck_verification_output =
69 sumcheck::batch_verify(evaluation_order, ®ular_sumcheck_claims, transcript)?;
70
71 let layer_exponent_claims = build_layer_exponent_bit_claims(
72 evaluation_order,
73 &mut verifiers,
74 sumcheck_verification_output,
75 eval_points,
76 layer_no,
77 )?;
78
79 layers_claims.push(layer_exponent_claims);
80
81 verifiers.retain(|verifier| !verifier.is_last_layer(layer_no));
82 }
83
84 Ok(BaseExpReductionOutput { layers_claims })
85}
86
87struct EqIndSumcheckClaimsWithEvalPoints<F: Field> {
88 eq_ind_sumcheck_claims: Vec<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>,
89 eval_points: Vec<Vec<F>>,
90}
91
92fn build_layer_eq_ind_sumcheck_claims<'a, F>(
94 verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
95 layer_no: usize,
96) -> Result<EqIndSumcheckClaimsWithEvalPoints<F>, Error>
97where
98 F: Field,
99{
100 let mut eq_ind_sumcheck_claims = Vec::new();
101
102 let first_eval_point = verifiers[0].layer_claim_eval_point().to_vec();
103 let mut eval_points = vec![first_eval_point];
104
105 let mut active_index = 0;
106
107 for i in 0..verifiers.len() {
109 if verifiers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
110 let eq_ind_sumcheck_claim =
111 build_eval_point_claims(&verifiers[active_index..i], layer_no)?;
112
113 if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
114 eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
115 } else {
116 eval_points.pop();
118 }
119
120 eval_points.push(verifiers[i].layer_claim_eval_point().to_vec());
121
122 active_index = i;
123 }
124
125 if i == verifiers.len() - 1 {
126 let eq_ind_sumcheck_claim =
127 build_eval_point_claims(&verifiers[active_index..], layer_no)?;
128
129 if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
130 eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
131 }
132 }
133 }
134
135 Ok(EqIndSumcheckClaimsWithEvalPoints {
136 eq_ind_sumcheck_claims,
137 eval_points,
138 })
139}
140
141fn build_eval_point_claims<'a, F>(
145 verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
146 layer_no: usize,
147) -> Result<Option<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>, Error>
148where
149 F: Field,
150{
151 let (composite_claims_n_multilinears, n_claims) =
152 verifiers
153 .iter()
154 .fold((0, 0), |(n_multilinears, n_claims), verifier| {
155 let layer_n_multilinears = verifier.layer_n_multilinears(layer_no);
156 let layer_n_claims = verifier.layer_n_claims(layer_no);
157
158 (n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
159 });
160
161 if composite_claims_n_multilinears == 0 {
162 return Ok(None);
163 }
164
165 let n_vars = verifiers[0].layer_claim_eval_point().len();
166
167 let mut multilinears_index = 0;
168
169 let mut composite_sums = Vec::with_capacity(n_claims);
170
171 for verifier in verifiers {
172 let composite_sum_claim = verifier.layer_composite_sum_claim(
173 layer_no,
174 composite_claims_n_multilinears,
175 multilinears_index,
176 )?;
177
178 if let Some(composite_sum_claim) = composite_sum_claim {
179 composite_sums.push(composite_sum_claim);
180 }
181
182 multilinears_index += verifier.layer_n_multilinears(layer_no);
183 }
184
185 Ok(Some(EqIndSumcheckClaim::new(n_vars, composite_claims_n_multilinears, composite_sums)?))
186}
187
188pub fn build_layer_exponent_bit_claims<'a, F>(
190 evaluation_order: EvaluationOrder,
191 verifiers: &mut [Box<dyn ExpVerifier<F> + 'a>],
192 mut sumcheck_output: BatchSumcheckOutput<F>,
193 eval_points: Vec<Vec<F>>,
194 layer_no: usize,
195) -> Result<Vec<LayerClaim<F>>, Error>
196where
197 F: TowerField,
198{
199 let mut eval_claims_on_exponent_bit_columns = Vec::new();
200
201 for (multilinear_evals, current_eval_point) in sumcheck_output
202 .multilinear_evals
203 .iter_mut()
204 .zip(eval_points.into_iter())
205 {
206 let n_vars = current_eval_point.len();
207
208 let eval_point = match evaluation_order {
209 EvaluationOrder::LowToHigh => {
210 sumcheck_output.challenges[sumcheck_output.challenges.len() - n_vars..].to_vec()
211 }
212 EvaluationOrder::HighToLow => sumcheck_output.challenges[..n_vars].to_vec(),
213 };
214
215 let expected_eq_ind_eval =
216 EqIndPartialEval::new(current_eval_point).evaluate(&eval_point)?;
217
218 let eq_ind_eval = multilinear_evals
219 .pop()
220 .expect("multilinear_evals contains the evaluation of the equality indicator");
221
222 if expected_eq_ind_eval != eq_ind_eval {
223 return Err(VerificationError::IncorrectEqIndEvaluation.into());
224 }
225 }
226
227 let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
228
229 for verifier in verifiers {
230 let this_verifier_n_multilinears = verifier.layer_n_multilinears(layer_no);
231
232 let this_verifier_multilinear_evals = multilinear_evals
233 .by_ref()
234 .take(this_verifier_n_multilinears)
235 .collect::<Vec<_>>();
236
237 let layer_claims = verifier.finish_layer(
238 evaluation_order,
239 layer_no,
240 &this_verifier_multilinear_evals,
241 &sumcheck_output.challenges,
242 );
243
244 eval_claims_on_exponent_bit_columns.extend(layer_claims);
245 }
246
247 Ok(eval_claims_on_exponent_bit_columns)
248}
249
250fn make_verifiers<'a, F>(claims: &[ExpClaim<F>]) -> Result<Vec<Box<dyn ExpVerifier<F> + 'a>>, Error>
252where
253 F: BinaryField,
254{
255 claims
256 .iter()
257 .map(|claim| {
258 if claim.static_base.is_none() {
259 DynamicExpVerifier::new(claim)
260 .map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
261 } else {
262 StaticBaseExpVerifier::<F>::new(claim)
263 .map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
264 }
265 })
266 .collect::<Result<Vec<_>, Error>>()
267}