binius_core/protocols/gkr_exp/
provers.rs1use binius_field::{BinaryField, PackedField};
4use binius_math::EvaluationOrder;
5use binius_utils::bail;
6
7use super::{
8 common::{ExpClaim, LayerClaim},
9 compositions::{ExpCompositions, IndexedExpComposition},
10 error::Error,
11 utils::first_layer_inverse,
12 witness::{BaseExpWitness, BaseWitness},
13};
14use crate::{
15 composition::{FixedDimIndexCompositions, IndexComposition},
16 protocols::sumcheck::CompositeSumClaim,
17 witness::MultilinearWitness,
18};
19
20pub trait ExpProver<'a, P: PackedField> {
21 fn exponent_bit_width(&self) -> usize;
22
23 fn is_last_layer(&self, layer_no: usize) -> bool {
24 self.exponent_bit_width() - 1 - layer_no == 0
25 }
26
27 fn layer_claim_eval_point(&self) -> &[P::Scalar];
29
30 fn layer_composite_sum_claim(
34 &self,
35 layer_no: usize,
36 composite_claims_n_multilinears: usize,
37 multilinears_index: usize,
38 ) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error>;
39
40 fn layer_n_multilinears(&self, layer_no: usize) -> usize;
42
43 fn layer_n_claims(&self, layer_no: usize) -> usize;
45
46 fn finish_layer(
49 &mut self,
50 evaluation_order: EvaluationOrder,
51 layer_no: usize,
52 multilinear_evals: &[P::Scalar],
53 r: &[P::Scalar],
54 ) -> Vec<LayerClaim<P::Scalar>>;
55}
56
57struct ExpCommonProver<'a, P: PackedField> {
58 witness: BaseExpWitness<'a, P>,
59 current_layer_claim: LayerClaim<P::Scalar>,
60}
61
62impl<'a, P: PackedField> ExpCommonProver<'a, P> {
63 fn new(witness: BaseExpWitness<'a, P>, claim: ExpClaim<P::Scalar>) -> Self {
64 Self {
65 witness,
66 current_layer_claim: claim.into(),
67 }
68 }
69
70 pub fn exponent_bit_width(&self) -> usize {
71 self.witness.exponent.len()
72 }
73
74 fn current_layer_single_bit_output_layers_data(
75 &self,
76 layer_no: usize,
77 ) -> MultilinearWitness<'a, P> {
78 let index = self.witness.single_bit_output_layers_data.len() - layer_no - 2;
79
80 self.witness.single_bit_output_layers_data[index].clone()
81 }
82
83 pub fn eval_point(&self) -> &[P::Scalar] {
84 &self.current_layer_claim.eval_point
85 }
86
87 fn current_layer_exponent_bit(&self, index: usize) -> MultilinearWitness<'a, P> {
88 self.witness.exponent[index].clone()
89 }
90
91 pub fn is_last_layer(&self, layer_no: usize) -> bool {
92 self.exponent_bit_width() - 1 - layer_no == 0
93 }
94}
95
96pub struct StaticExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
97
98impl<'a, P: PackedField> StaticExpProver<'a, P> {
99 pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
100 if witness.uses_dynamic_base() {
101 bail!(Error::IncorrectWitnessType);
102 }
103
104 Ok(Self(ExpCommonProver::new(witness, claim.clone())))
105 }
106}
107
108impl<'a, P> ExpProver<'a, P> for StaticExpProver<'a, P>
109where
110 P::Scalar: BinaryField,
111 P: PackedField,
112{
113 fn exponent_bit_width(&self) -> usize {
114 self.0.exponent_bit_width()
115 }
116
117 fn layer_composite_sum_claim(
118 &self,
119 layer_no: usize,
120 composite_claims_n_multilinears: usize,
121 multilinears_index: usize,
122 ) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
123 if self.0.is_last_layer(layer_no) {
124 return Ok(None);
125 }
126
127 let internal_layer_index = self.exponent_bit_width() - 1 - layer_no;
128
129 let this_layer_input = self.0.current_layer_single_bit_output_layers_data(layer_no);
130
131 let exponent_bit = self
132 .0
133 .current_layer_exponent_bit(internal_layer_index)
134 .clone();
135
136 let this_layer_multilinears = vec![this_layer_input, exponent_bit];
137
138 let this_layer_input_index = multilinears_index;
139 let exponent_bit_index = multilinears_index + 1;
140
141 let base = match self.0.witness.base.clone() {
142 BaseWitness::Static(base) => base,
143 _ => unreachable!("witness must contain static base"),
144 };
145
146 let base_power_static = base.pow(1 << internal_layer_index);
147
148 let composition = IndexComposition::new(
149 composite_claims_n_multilinears,
150 [this_layer_input_index, exponent_bit_index],
151 ExpCompositions::StaticBase { base_power_static },
152 )?;
153
154 let composition = FixedDimIndexCompositions::Bivariate(composition);
155
156 let this_layer_composite_claim = CompositeSumClaim {
157 sum: self.0.current_layer_claim.eval,
158 composition,
159 };
160
161 Ok(Some(CompositeSumClaimWithMultilinears {
162 claim: this_layer_composite_claim,
163 multilinears: this_layer_multilinears,
164 }))
165 }
166
167 fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
168 self.0.eval_point()
169 }
170
171 fn finish_layer(
172 &mut self,
173 evaluation_order: EvaluationOrder,
174 layer_no: usize,
175 multilinear_evals: &[P::Scalar],
176 r: &[P::Scalar],
177 ) -> Vec<LayerClaim<P::Scalar>> {
178 let exponent_bit_claim = if self.is_last_layer(layer_no) {
179 let LayerClaim { eval_point, eval } = self.0.current_layer_claim.clone();
182
183 let base = match self.0.witness.base.clone() {
184 BaseWitness::Static(base) => base,
185 _ => unreachable!("witness must contain static base"),
186 };
187
188 LayerClaim::<P::Scalar> {
189 eval_point,
190 eval: first_layer_inverse(eval, base),
191 }
192 } else {
193 let n_vars = self.layer_claim_eval_point().len();
194
195 let this_layer_input_eval = multilinear_evals[0];
196
197 let exponent_bit_eval = multilinear_evals[1];
198
199 let eval_point = match evaluation_order {
200 EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
201 EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
202 };
203
204 if !self.is_last_layer(layer_no) {
205 self.0.current_layer_claim = LayerClaim {
206 eval: this_layer_input_eval,
207 eval_point: eval_point.clone(),
208 };
209 }
210
211 LayerClaim {
212 eval: exponent_bit_eval,
213 eval_point,
214 }
215 };
216
217 vec![exponent_bit_claim]
218 }
219
220 fn layer_n_multilinears(&self, layer_no: usize) -> usize {
221 if self.is_last_layer(layer_no) {
222 0
223 } else {
224 2
226 }
227 }
228
229 fn layer_n_claims(&self, layer_no: usize) -> usize {
230 if self.is_last_layer(layer_no) { 0 } else { 1 }
231 }
232}
233
234pub struct DynamicBaseExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
235
236impl<'a, P: PackedField> DynamicBaseExpProver<'a, P> {
237 pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
238 if !witness.uses_dynamic_base() {
239 bail!(Error::IncorrectWitnessType);
240 }
241
242 Ok(Self(ExpCommonProver::new(witness, claim.clone())))
243 }
244}
245
246pub struct CompositeSumClaimWithMultilinears<'a, P: PackedField> {
247 pub claim: CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>,
248 pub multilinears: Vec<MultilinearWitness<'a, P>>,
249}
250
251impl<'a, P: PackedField> ExpProver<'a, P> for DynamicBaseExpProver<'a, P> {
252 fn exponent_bit_width(&self) -> usize {
253 self.0.exponent_bit_width()
254 }
255
256 fn layer_composite_sum_claim(
257 &self,
258 layer_no: usize,
259 composite_claims_n_multilinears: usize,
260 multilinears_index: usize,
261 ) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
262 let base = match self.0.witness.base.clone() {
263 BaseWitness::Dynamic(base) => base,
264 _ => unreachable!("DynamicBase witness must contain base"),
265 };
266
267 let exponent_bit = self.0.current_layer_exponent_bit(layer_no);
268
269 let (composition, this_layer_multilinears) = if self.0.is_last_layer(layer_no) {
270 let this_layer_multilinears = vec![base, exponent_bit];
271
272 let base_index = multilinears_index;
273 let exponent_bit_index = multilinears_index + 1;
274
275 let composition = IndexComposition::new(
276 composite_claims_n_multilinears,
277 [base_index, exponent_bit_index],
278 ExpCompositions::DynamicBaseLastLayer,
279 )?;
280 let composition = FixedDimIndexCompositions::Bivariate(composition);
281 (composition, this_layer_multilinears)
282 } else {
283 let this_layer_input = self
284 .0
285 .current_layer_single_bit_output_layers_data(layer_no)
286 .clone();
287
288 let this_layer_multilinears = vec![this_layer_input, exponent_bit, base];
289
290 let this_layer_input_index = multilinears_index;
291 let exponent_bit_index = multilinears_index + 1;
292 let base_index = multilinears_index + 2;
293
294 let composition = IndexComposition::new(
295 composite_claims_n_multilinears,
296 [this_layer_input_index, exponent_bit_index, base_index],
297 ExpCompositions::DynamicBase,
298 )?;
299 let composition = FixedDimIndexCompositions::Trivariate(composition);
300 (composition, this_layer_multilinears)
301 };
302
303 let this_layer_composite_claim = CompositeSumClaim {
304 sum: self.0.current_layer_claim.eval,
305 composition,
306 };
307
308 Ok(Some(CompositeSumClaimWithMultilinears {
309 claim: this_layer_composite_claim,
310 multilinears: this_layer_multilinears,
311 }))
312 }
313
314 fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
315 self.0.eval_point()
316 }
317
318 fn finish_layer(
319 &mut self,
320 evaluation_order: EvaluationOrder,
321 layer_no: usize,
322 multilinear_evals: &[P::Scalar],
323 r: &[P::Scalar],
324 ) -> Vec<LayerClaim<P::Scalar>> {
325 let n_vars = self.0.eval_point().len();
326
327 let eval_point = match evaluation_order {
328 EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
329 EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
330 };
331
332 let mut claims = Vec::with_capacity(2);
333
334 let exponent_bit_eval = multilinear_evals[1];
335
336 let exponent_bit_claim = LayerClaim {
337 eval: exponent_bit_eval,
338 eval_point: eval_point.clone(),
339 };
340
341 claims.push(exponent_bit_claim);
342
343 if self.is_last_layer(layer_no) {
344 let base_eval = multilinear_evals[0];
345
346 let base_claim = LayerClaim {
347 eval: base_eval,
348 eval_point,
349 };
350 claims.push(base_claim)
351 } else {
352 let this_layer_input_eval = multilinear_evals[0];
353
354 self.0.current_layer_claim = LayerClaim {
355 eval: this_layer_input_eval,
356 eval_point: eval_point.clone(),
357 };
358
359 let base_eval = multilinear_evals[2];
360
361 let base_claim = LayerClaim {
362 eval: base_eval,
363 eval_point,
364 };
365
366 claims.push(base_claim)
367 }
368
369 claims
370 }
371
372 fn layer_n_multilinears(&self, layer_no: usize) -> usize {
373 if self.is_last_layer(layer_no) {
374 2
376 } else {
377 3
379 }
380 }
381
382 fn layer_n_claims(&self, _layer_no: usize) -> usize {
383 1
384 }
385}