binius_core/protocols/gkr_exp/
provers.rs1use binius_field::{BinaryField, PackedField};
4use binius_math::EvaluationOrder;
5use binius_utils::bail;
6
7use super::{
8 common::{ExpClaim, LayerClaim},
9 compositions::{ExpCompositions, IndexedExpComposition},
10 error::Error,
11 utils::first_layer_inverse,
12 witness::{BaseExpWitness, BaseWitness},
13};
14use crate::{
15 composition::{FixedDimIndexCompositions, IndexComposition},
16 protocols::sumcheck::CompositeSumClaim,
17 witness::MultilinearWitness,
18};
19
20pub trait ExpProver<'a, P: PackedField> {
21 fn exponent_bit_width(&self) -> usize;
22
23 fn is_last_layer(&self, layer_no: usize) -> bool {
24 self.exponent_bit_width() - 1 - layer_no == 0
25 }
26
27 fn layer_claim_eval_point(&self) -> &[P::Scalar];
29
30 fn layer_composite_sum_claim(
34 &self,
35 layer_no: usize,
36 composite_claims_n_multilinears: usize,
37 multilinears_index: usize,
38 ) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error>;
39
40 fn layer_n_multilinears(&self, layer_no: usize) -> usize;
42
43 fn layer_n_claims(&self, layer_no: usize) -> usize;
45
46 fn finish_layer(
49 &mut self,
50 evaluation_order: EvaluationOrder,
51 layer_no: usize,
52 multilinear_evals: &[P::Scalar],
53 r: &[P::Scalar],
54 ) -> Vec<LayerClaim<P::Scalar>>;
55}
56
57struct ExpCommonProver<'a, P: PackedField> {
58 witness: BaseExpWitness<'a, P>,
59 current_layer_claim: LayerClaim<P::Scalar>,
60}
61
62impl<'a, P: PackedField> ExpCommonProver<'a, P> {
63 fn new(witness: BaseExpWitness<'a, P>, claim: ExpClaim<P::Scalar>) -> Self {
64 Self {
65 witness,
66 current_layer_claim: claim.into(),
67 }
68 }
69
70 pub fn exponent_bit_width(&self) -> usize {
71 self.witness.exponent.len()
72 }
73
74 fn current_layer_single_bit_output_layers_data(
75 &self,
76 layer_no: usize,
77 ) -> MultilinearWitness<'a, P> {
78 let index = self.witness.single_bit_output_layers_data.len() - layer_no - 2;
79
80 self.witness.single_bit_output_layers_data[index].clone()
81 }
82
83 pub fn eval_point(&self) -> &[P::Scalar] {
84 &self.current_layer_claim.eval_point
85 }
86
87 fn current_layer_exponent_bit(&self, index: usize) -> MultilinearWitness<'a, P> {
88 self.witness.exponent[index].clone()
89 }
90
91 pub fn is_last_layer(&self, layer_no: usize) -> bool {
92 self.exponent_bit_width() - 1 - layer_no == 0
93 }
94}
95
96pub struct StaticExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
97
98impl<'a, P: PackedField> StaticExpProver<'a, P> {
99 pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
100 if witness.uses_dynamic_base() {
101 bail!(Error::IncorrectWitnessType);
102 }
103
104 Ok(Self(ExpCommonProver::new(witness, claim.clone())))
105 }
106}
107
108impl<'a, P> ExpProver<'a, P> for StaticExpProver<'a, P>
109where
110 P::Scalar: BinaryField,
111 P: PackedField,
112{
113 fn exponent_bit_width(&self) -> usize {
114 self.0.exponent_bit_width()
115 }
116
117 fn layer_composite_sum_claim(
118 &self,
119 layer_no: usize,
120 composite_claims_n_multilinears: usize,
121 multilinears_index: usize,
122 ) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
123 if self.0.is_last_layer(layer_no) {
124 return Ok(None);
125 }
126
127 let internal_layer_index = self.exponent_bit_width() - 1 - layer_no;
128
129 let this_layer_input = self.0.current_layer_single_bit_output_layers_data(layer_no);
130
131 let exponent_bit = self
132 .0
133 .current_layer_exponent_bit(internal_layer_index)
134 .clone();
135
136 let this_layer_multilinears = vec![this_layer_input, exponent_bit];
137
138 let this_layer_input_index = multilinears_index;
139 let exponent_bit_index = multilinears_index + 1;
140
141 let base = match self.0.witness.base.clone() {
142 BaseWitness::Static(base) => base,
143 _ => unreachable!("witness must contain static base"),
144 };
145
146 let base_power_static = base.pow(1 << internal_layer_index);
147
148 let composition = IndexComposition::new(
149 composite_claims_n_multilinears,
150 [this_layer_input_index, exponent_bit_index],
151 ExpCompositions::StaticBase { base_power_static },
152 )?;
153
154 let composition = FixedDimIndexCompositions::Bivariate(composition);
155
156 let this_layer_composite_claim = CompositeSumClaim {
157 sum: self.0.current_layer_claim.eval,
158 composition,
159 };
160
161 Ok(Some(CompositeSumClaimWithMultilinears {
162 claim: this_layer_composite_claim,
163 multilinears: this_layer_multilinears,
164 }))
165 }
166
167 fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
168 self.0.eval_point()
169 }
170
171 fn finish_layer(
172 &mut self,
173 evaluation_order: EvaluationOrder,
174 layer_no: usize,
175 multilinear_evals: &[P::Scalar],
176 r: &[P::Scalar],
177 ) -> Vec<LayerClaim<P::Scalar>> {
178 let exponent_bit_claim = if self.is_last_layer(layer_no) {
179 let LayerClaim { eval_point, eval } = self.0.current_layer_claim.clone();
182
183 let base = match self.0.witness.base.clone() {
184 BaseWitness::Static(base) => base,
185 _ => unreachable!("witness must contain static base"),
186 };
187
188 LayerClaim::<P::Scalar> {
189 eval_point,
190 eval: first_layer_inverse(eval, base),
191 }
192 } else {
193 let n_vars = self.layer_claim_eval_point().len();
194
195 let this_layer_input_eval = multilinear_evals[0];
196
197 let exponent_bit_eval = multilinear_evals[1];
198
199 let eval_point = match evaluation_order {
200 EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
201 EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
202 };
203
204 if !self.is_last_layer(layer_no) {
205 self.0.current_layer_claim = LayerClaim {
206 eval: this_layer_input_eval,
207 eval_point: eval_point.clone(),
208 };
209 }
210
211 LayerClaim {
212 eval: exponent_bit_eval,
213 eval_point,
214 }
215 };
216
217 vec![exponent_bit_claim]
218 }
219
220 fn layer_n_multilinears(&self, layer_no: usize) -> usize {
221 if self.is_last_layer(layer_no) {
222 0
223 } else {
224 2
226 }
227 }
228
229 fn layer_n_claims(&self, layer_no: usize) -> usize {
230 if self.is_last_layer(layer_no) {
231 0
232 } else {
233 1
234 }
235 }
236}
237
238pub struct DynamicBaseExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
239
240impl<'a, P: PackedField> DynamicBaseExpProver<'a, P> {
241 pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
242 if !witness.uses_dynamic_base() {
243 bail!(Error::IncorrectWitnessType);
244 }
245
246 Ok(Self(ExpCommonProver::new(witness, claim.clone())))
247 }
248}
249
250pub struct CompositeSumClaimWithMultilinears<'a, P: PackedField> {
251 pub claim: CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>,
252 pub multilinears: Vec<MultilinearWitness<'a, P>>,
253}
254
255impl<'a, P: PackedField> ExpProver<'a, P> for DynamicBaseExpProver<'a, P> {
256 fn exponent_bit_width(&self) -> usize {
257 self.0.exponent_bit_width()
258 }
259
260 fn layer_composite_sum_claim(
261 &self,
262 layer_no: usize,
263 composite_claims_n_multilinears: usize,
264 multilinears_index: usize,
265 ) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
266 let base = match self.0.witness.base.clone() {
267 BaseWitness::Dynamic(base) => base,
268 _ => unreachable!("DynamicBase witness must contain base"),
269 };
270
271 let exponent_bit = self.0.current_layer_exponent_bit(layer_no);
272
273 let (composition, this_layer_multilinears) = if self.0.is_last_layer(layer_no) {
274 let this_layer_multilinears = vec![base, exponent_bit];
275
276 let base_index = multilinears_index;
277 let exponent_bit_index = multilinears_index + 1;
278
279 let composition = IndexComposition::new(
280 composite_claims_n_multilinears,
281 [base_index, exponent_bit_index],
282 ExpCompositions::DynamicBaseLastLayer,
283 )?;
284 let composition = FixedDimIndexCompositions::Bivariate(composition);
285 (composition, this_layer_multilinears)
286 } else {
287 let this_layer_input = self
288 .0
289 .current_layer_single_bit_output_layers_data(layer_no)
290 .clone();
291
292 let this_layer_multilinears = vec![this_layer_input, exponent_bit, base];
293
294 let this_layer_input_index = multilinears_index;
295 let exponent_bit_index = multilinears_index + 1;
296 let base_index = multilinears_index + 2;
297
298 let composition = IndexComposition::new(
299 composite_claims_n_multilinears,
300 [this_layer_input_index, exponent_bit_index, base_index],
301 ExpCompositions::DynamicBase,
302 )?;
303 let composition = FixedDimIndexCompositions::Trivariate(composition);
304 (composition, this_layer_multilinears)
305 };
306
307 let this_layer_composite_claim = CompositeSumClaim {
308 sum: self.0.current_layer_claim.eval,
309 composition,
310 };
311
312 Ok(Some(CompositeSumClaimWithMultilinears {
313 claim: this_layer_composite_claim,
314 multilinears: this_layer_multilinears,
315 }))
316 }
317
318 fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
319 self.0.eval_point()
320 }
321
322 fn finish_layer(
323 &mut self,
324 evaluation_order: EvaluationOrder,
325 layer_no: usize,
326 multilinear_evals: &[P::Scalar],
327 r: &[P::Scalar],
328 ) -> Vec<LayerClaim<P::Scalar>> {
329 let n_vars = self.0.eval_point().len();
330
331 let eval_point = match evaluation_order {
332 EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
333 EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
334 };
335
336 let mut claims = Vec::with_capacity(2);
337
338 let exponent_bit_eval = multilinear_evals[1];
339
340 let exponent_bit_claim = LayerClaim {
341 eval: exponent_bit_eval,
342 eval_point: eval_point.clone(),
343 };
344
345 claims.push(exponent_bit_claim);
346
347 if self.is_last_layer(layer_no) {
348 let base_eval = multilinear_evals[0];
349
350 let base_claim = LayerClaim {
351 eval: base_eval,
352 eval_point,
353 };
354 claims.push(base_claim)
355 } else {
356 let this_layer_input_eval = multilinear_evals[0];
357
358 self.0.current_layer_claim = LayerClaim {
359 eval: this_layer_input_eval,
360 eval_point: eval_point.clone(),
361 };
362
363 let base_eval = multilinear_evals[2];
364
365 let base_claim = LayerClaim {
366 eval: base_eval,
367 eval_point,
368 };
369
370 claims.push(base_claim)
371 }
372
373 claims
374 }
375
376 fn layer_n_multilinears(&self, layer_no: usize) -> usize {
377 if self.is_last_layer(layer_no) {
378 2
380 } else {
381 3
383 }
384 }
385
386 fn layer_n_claims(&self, _layer_no: usize) -> usize {
387 1
388 }
389}