binius_core/protocols/gkr_exp/
batch_verify.rs1use binius_field::{BinaryField, Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6
7use super::{
8 common::{BaseExpReductionOutput, ExpClaim, LayerClaim},
9 compositions::IndexedExpComposition,
10 error::{Error, VerificationError},
11 verifiers::{DynamicExpVerifier, ExpVerifier, StaticBaseExpVerifier},
12};
13use crate::{
14 fiat_shamir::Challenger,
15 polynomial::MultivariatePoly,
16 protocols::sumcheck::{self, BatchSumcheckOutput, EqIndSumcheckClaim},
17 transcript::VerifierTranscript,
18 transparent::eq_ind::EqIndPartialEval,
19};
20
21pub fn batch_verify<F, Challenger_>(
33 evaluation_order: EvaluationOrder,
34 claims: &[ExpClaim<F>],
35 transcript: &mut VerifierTranscript<Challenger_>,
36) -> Result<BaseExpReductionOutput<F>, Error>
37where
38 F: TowerField,
39 Challenger_: Challenger,
40{
41 let mut layers_claims = Vec::new();
42
43 if claims.is_empty() {
44 return Ok(BaseExpReductionOutput { layers_claims });
45 }
46
47 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
49 bail!(Error::ClaimsOutOfOrder);
50 }
51
52 let mut verifiers = make_verifiers(claims)?;
53
54 let max_exponent_bit_number = verifiers
55 .iter()
56 .map(|p| p.exponent_bit_width())
57 .max()
58 .unwrap_or(0);
59
60 for layer_no in 0..max_exponent_bit_number {
61 let EqIndSumcheckClaimsWithEvalPoints {
62 eq_ind_sumcheck_claims,
63 eval_points,
64 } = build_layer_eq_ind_sumcheck_claims(&verifiers, layer_no)?;
65
66 let regular_sumcheck_claims =
67 sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
68
69 let sumcheck_verification_output =
70 sumcheck::batch_verify(evaluation_order, ®ular_sumcheck_claims, transcript)?;
71
72 let layer_exponent_claims = build_layer_exponent_bit_claims(
73 evaluation_order,
74 &mut verifiers,
75 sumcheck_verification_output,
76 eval_points,
77 layer_no,
78 )?;
79
80 layers_claims.push(layer_exponent_claims);
81
82 verifiers.retain(|verifier| !verifier.is_last_layer(layer_no));
83 }
84
85 Ok(BaseExpReductionOutput { layers_claims })
86}
87
88struct EqIndSumcheckClaimsWithEvalPoints<F: Field> {
89 eq_ind_sumcheck_claims: Vec<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>,
90 eval_points: Vec<Vec<F>>,
91}
92
93fn build_layer_eq_ind_sumcheck_claims<'a, F>(
95 verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
96 layer_no: usize,
97) -> Result<EqIndSumcheckClaimsWithEvalPoints<F>, Error>
98where
99 F: Field,
100{
101 let mut eq_ind_sumcheck_claims = Vec::new();
102
103 let first_eval_point = verifiers[0].layer_claim_eval_point().to_vec();
104 let mut eval_points = vec![first_eval_point];
105
106 let mut active_index = 0;
107
108 for i in 0..verifiers.len() {
110 if verifiers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
111 let eq_ind_sumcheck_claim =
112 build_eval_point_claims(&verifiers[active_index..i], layer_no)?;
113
114 if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
115 eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
116 } else {
117 eval_points.pop();
120 }
121
122 eval_points.push(verifiers[i].layer_claim_eval_point().to_vec());
123
124 active_index = i;
125 }
126
127 if i == verifiers.len() - 1 {
128 let eq_ind_sumcheck_claim =
129 build_eval_point_claims(&verifiers[active_index..], layer_no)?;
130
131 if let Some(eq_ind_sumcheck_claim) = eq_ind_sumcheck_claim {
132 eq_ind_sumcheck_claims.push(eq_ind_sumcheck_claim);
133 }
134 }
135 }
136
137 Ok(EqIndSumcheckClaimsWithEvalPoints {
138 eq_ind_sumcheck_claims,
139 eval_points,
140 })
141}
142
143fn build_eval_point_claims<'a, F>(
147 verifiers: &[Box<dyn ExpVerifier<F> + 'a>],
148 layer_no: usize,
149) -> Result<Option<EqIndSumcheckClaim<F, IndexedExpComposition<F>>>, Error>
150where
151 F: Field,
152{
153 let (composite_claims_n_multilinears, n_claims) =
154 verifiers
155 .iter()
156 .fold((0, 0), |(n_multilinears, n_claims), verifier| {
157 let layer_n_multilinears = verifier.layer_n_multilinears(layer_no);
158 let layer_n_claims = verifier.layer_n_claims(layer_no);
159
160 (n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
161 });
162
163 if composite_claims_n_multilinears == 0 {
164 return Ok(None);
165 }
166
167 let n_vars = verifiers[0].layer_claim_eval_point().len();
168
169 let mut multilinears_index = 0;
170
171 let mut composite_sums = Vec::with_capacity(n_claims);
172
173 for verifier in verifiers {
174 let composite_sum_claim = verifier.layer_composite_sum_claim(
175 layer_no,
176 composite_claims_n_multilinears,
177 multilinears_index,
178 )?;
179
180 if let Some(composite_sum_claim) = composite_sum_claim {
181 composite_sums.push(composite_sum_claim);
182 }
183
184 multilinears_index += verifier.layer_n_multilinears(layer_no);
185 }
186
187 Ok(Some(EqIndSumcheckClaim::new(n_vars, composite_claims_n_multilinears, composite_sums)?))
188}
189
190pub fn build_layer_exponent_bit_claims<'a, F>(
193 evaluation_order: EvaluationOrder,
194 verifiers: &mut [Box<dyn ExpVerifier<F> + 'a>],
195 mut sumcheck_output: BatchSumcheckOutput<F>,
196 eval_points: Vec<Vec<F>>,
197 layer_no: usize,
198) -> Result<Vec<LayerClaim<F>>, Error>
199where
200 F: TowerField,
201{
202 let mut eval_claims_on_exponent_bit_columns = Vec::new();
203
204 for (multilinear_evals, current_eval_point) in sumcheck_output
205 .multilinear_evals
206 .iter_mut()
207 .zip(eval_points.into_iter())
208 {
209 let n_vars = current_eval_point.len();
210
211 let eval_point = match evaluation_order {
212 EvaluationOrder::LowToHigh => {
213 sumcheck_output.challenges[sumcheck_output.challenges.len() - n_vars..].to_vec()
214 }
215 EvaluationOrder::HighToLow => sumcheck_output.challenges[..n_vars].to_vec(),
216 };
217
218 let expected_eq_ind_eval =
219 EqIndPartialEval::new(current_eval_point).evaluate(&eval_point)?;
220
221 let eq_ind_eval = multilinear_evals
222 .pop()
223 .expect("multilinear_evals contains the evaluation of the equality indicator");
224
225 if expected_eq_ind_eval != eq_ind_eval {
226 return Err(VerificationError::IncorrectEqIndEvaluation.into());
227 }
228 }
229
230 let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
231
232 for verifier in verifiers {
233 let this_verifier_n_multilinears = verifier.layer_n_multilinears(layer_no);
234
235 let this_verifier_multilinear_evals = multilinear_evals
236 .by_ref()
237 .take(this_verifier_n_multilinears)
238 .collect::<Vec<_>>();
239
240 let layer_claims = verifier.finish_layer(
241 evaluation_order,
242 layer_no,
243 &this_verifier_multilinear_evals,
244 &sumcheck_output.challenges,
245 );
246
247 eval_claims_on_exponent_bit_columns.extend(layer_claims);
248 }
249
250 Ok(eval_claims_on_exponent_bit_columns)
251}
252
253fn make_verifiers<'a, F>(claims: &[ExpClaim<F>]) -> Result<Vec<Box<dyn ExpVerifier<F> + 'a>>, Error>
255where
256 F: BinaryField,
257{
258 claims
259 .iter()
260 .map(|claim| {
261 if claim.static_base.is_none() {
262 DynamicExpVerifier::new(claim)
263 .map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
264 } else {
265 StaticBaseExpVerifier::<F>::new(claim)
266 .map(|verifier| Box::new(verifier) as Box<dyn ExpVerifier<F> + 'a>)
267 }
268 })
269 .collect::<Result<Vec<_>, Error>>()
270}