binius_core/protocols/gkr_exp/
provers.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField};
4use binius_math::EvaluationOrder;
5use binius_utils::bail;
6
7use super::{
8	common::{ExpClaim, LayerClaim},
9	compositions::{ExpCompositions, IndexedExpComposition},
10	error::Error,
11	utils::first_layer_inverse,
12	witness::{BaseExpWitness, BaseWitness},
13};
14use crate::{
15	composition::{FixedDimIndexCompositions, IndexComposition},
16	protocols::sumcheck::CompositeSumClaim,
17	witness::MultilinearWitness,
18};
19
20pub trait ExpProver<'a, P: PackedField> {
21	fn exponent_bit_width(&self) -> usize;
22
23	fn is_last_layer(&self, layer_no: usize) -> bool {
24		self.exponent_bit_width() - 1 - layer_no == 0
25	}
26
27	/// return the eval_point of the internal [LayerClaim].
28	fn layer_claim_eval_point(&self) -> &[P::Scalar];
29
30	/// return [CompositeSumClaim] and multilinears that it contains,
31	/// If the prover does not participate in the sumcheck for this layer,
32	/// the function returns `None`.
33	fn layer_composite_sum_claim(
34		&self,
35		layer_no: usize,
36		composite_claims_n_multilinears: usize,
37		multilinears_index: usize,
38	) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error>;
39
40	/// return a tuple of the number of multilinears used by this prover for this layer.
41	fn layer_n_multilinears(&self, layer_no: usize) -> usize;
42
43	/// return a tuple of the number of sumcheck claims used by this prover for this layer.
44	fn layer_n_claims(&self, layer_no: usize) -> usize;
45
46	/// update the prover internal [LayerClaim] and return the [LayerClaim]s of multilinears,
47	/// excluding `this_layer_input`.
48	fn finish_layer(
49		&mut self,
50		evaluation_order: EvaluationOrder,
51		layer_no: usize,
52		multilinear_evals: &[P::Scalar],
53		r: &[P::Scalar],
54	) -> Vec<LayerClaim<P::Scalar>>;
55}
56
57struct ExpCommonProver<'a, P: PackedField> {
58	witness: BaseExpWitness<'a, P>,
59	current_layer_claim: LayerClaim<P::Scalar>,
60}
61
62impl<'a, P: PackedField> ExpCommonProver<'a, P> {
63	fn new(witness: BaseExpWitness<'a, P>, claim: ExpClaim<P::Scalar>) -> Self {
64		Self {
65			witness,
66			current_layer_claim: claim.into(),
67		}
68	}
69
70	pub fn exponent_bit_width(&self) -> usize {
71		self.witness.exponent.len()
72	}
73
74	fn current_layer_single_bit_output_layers_data(
75		&self,
76		layer_no: usize,
77	) -> MultilinearWitness<'a, P> {
78		let index = self.witness.single_bit_output_layers_data.len() - layer_no - 2;
79
80		self.witness.single_bit_output_layers_data[index].clone()
81	}
82
83	pub fn eval_point(&self) -> &[P::Scalar] {
84		&self.current_layer_claim.eval_point
85	}
86
87	fn current_layer_exponent_bit(&self, index: usize) -> MultilinearWitness<'a, P> {
88		self.witness.exponent[index].clone()
89	}
90
91	pub fn is_last_layer(&self, layer_no: usize) -> bool {
92		self.exponent_bit_width() - 1 - layer_no == 0
93	}
94}
95
96pub struct StaticExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
97
98impl<'a, P: PackedField> StaticExpProver<'a, P> {
99	pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
100		if witness.uses_dynamic_base() {
101			bail!(Error::IncorrectWitnessType);
102		}
103
104		Ok(Self(ExpCommonProver::new(witness, claim.clone())))
105	}
106}
107
108impl<'a, P> ExpProver<'a, P> for StaticExpProver<'a, P>
109where
110	P::Scalar: BinaryField,
111	P: PackedField,
112{
113	fn exponent_bit_width(&self) -> usize {
114		self.0.exponent_bit_width()
115	}
116
117	fn layer_composite_sum_claim(
118		&self,
119		layer_no: usize,
120		composite_claims_n_multilinears: usize,
121		multilinears_index: usize,
122	) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
123		if self.0.is_last_layer(layer_no) {
124			return Ok(None);
125		}
126
127		let internal_layer_index = self.exponent_bit_width() - 1 - layer_no;
128
129		let this_layer_input = self.0.current_layer_single_bit_output_layers_data(layer_no);
130
131		let exponent_bit = self
132			.0
133			.current_layer_exponent_bit(internal_layer_index)
134			.clone();
135
136		let this_layer_multilinears = vec![this_layer_input, exponent_bit];
137
138		let this_layer_input_index = multilinears_index;
139		let exponent_bit_index = multilinears_index + 1;
140
141		let base = match self.0.witness.base.clone() {
142			BaseWitness::Static(base) => base,
143			_ => unreachable!("witness must contain static base"),
144		};
145
146		let base_power_static = base.pow(1 << internal_layer_index);
147
148		let composition = IndexComposition::new(
149			composite_claims_n_multilinears,
150			[this_layer_input_index, exponent_bit_index],
151			ExpCompositions::StaticBase { base_power_static },
152		)?;
153
154		let composition = FixedDimIndexCompositions::Bivariate(composition);
155
156		let this_layer_composite_claim = CompositeSumClaim {
157			sum: self.0.current_layer_claim.eval,
158			composition,
159		};
160
161		Ok(Some(CompositeSumClaimWithMultilinears {
162			claim: this_layer_composite_claim,
163			multilinears: this_layer_multilinears,
164		}))
165	}
166
167	fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
168		self.0.eval_point()
169	}
170
171	fn finish_layer(
172		&mut self,
173		evaluation_order: EvaluationOrder,
174		layer_no: usize,
175		multilinear_evals: &[P::Scalar],
176		r: &[P::Scalar],
177	) -> Vec<LayerClaim<P::Scalar>> {
178		let exponent_bit_claim = if self.is_last_layer(layer_no) {
179			// the evaluation of the last exponent bit can be uniquely calculated from the previous
180			// exponentiation layer claim. $a_0(x) = (V_0(x) - 1)/(g - 1)$
181			let LayerClaim { eval_point, eval } = self.0.current_layer_claim.clone();
182
183			let base = match self.0.witness.base.clone() {
184				BaseWitness::Static(base) => base,
185				_ => unreachable!("witness must contain static base"),
186			};
187
188			LayerClaim::<P::Scalar> {
189				eval_point,
190				eval: first_layer_inverse(eval, base),
191			}
192		} else {
193			let n_vars = self.layer_claim_eval_point().len();
194
195			let this_layer_input_eval = multilinear_evals[0];
196
197			let exponent_bit_eval = multilinear_evals[1];
198
199			let eval_point = match evaluation_order {
200				EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
201				EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
202			};
203
204			if !self.is_last_layer(layer_no) {
205				self.0.current_layer_claim = LayerClaim {
206					eval: this_layer_input_eval,
207					eval_point: eval_point.clone(),
208				};
209			}
210
211			LayerClaim {
212				eval: exponent_bit_eval,
213				eval_point,
214			}
215		};
216
217		vec![exponent_bit_claim]
218	}
219
220	fn layer_n_multilinears(&self, layer_no: usize) -> usize {
221		if self.is_last_layer(layer_no) {
222			0
223		} else {
224			// this_layer_input, exponent_bit
225			2
226		}
227	}
228
229	fn layer_n_claims(&self, layer_no: usize) -> usize {
230		if self.is_last_layer(layer_no) { 0 } else { 1 }
231	}
232}
233
234pub struct DynamicBaseExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
235
236impl<'a, P: PackedField> DynamicBaseExpProver<'a, P> {
237	pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
238		if !witness.uses_dynamic_base() {
239			bail!(Error::IncorrectWitnessType);
240		}
241
242		Ok(Self(ExpCommonProver::new(witness, claim.clone())))
243	}
244}
245
246pub struct CompositeSumClaimWithMultilinears<'a, P: PackedField> {
247	pub claim: CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>,
248	pub multilinears: Vec<MultilinearWitness<'a, P>>,
249}
250
251impl<'a, P: PackedField> ExpProver<'a, P> for DynamicBaseExpProver<'a, P> {
252	fn exponent_bit_width(&self) -> usize {
253		self.0.exponent_bit_width()
254	}
255
256	fn layer_composite_sum_claim(
257		&self,
258		layer_no: usize,
259		composite_claims_n_multilinears: usize,
260		multilinears_index: usize,
261	) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
262		let base = match self.0.witness.base.clone() {
263			BaseWitness::Dynamic(base) => base,
264			_ => unreachable!("DynamicBase witness must contain base"),
265		};
266
267		let exponent_bit = self.0.current_layer_exponent_bit(layer_no);
268
269		let (composition, this_layer_multilinears) = if self.0.is_last_layer(layer_no) {
270			let this_layer_multilinears = vec![base, exponent_bit];
271
272			let base_index = multilinears_index;
273			let exponent_bit_index = multilinears_index + 1;
274
275			let composition = IndexComposition::new(
276				composite_claims_n_multilinears,
277				[base_index, exponent_bit_index],
278				ExpCompositions::DynamicBaseLastLayer,
279			)?;
280			let composition = FixedDimIndexCompositions::Bivariate(composition);
281			(composition, this_layer_multilinears)
282		} else {
283			let this_layer_input = self
284				.0
285				.current_layer_single_bit_output_layers_data(layer_no)
286				.clone();
287
288			let this_layer_multilinears = vec![this_layer_input, exponent_bit, base];
289
290			let this_layer_input_index = multilinears_index;
291			let exponent_bit_index = multilinears_index + 1;
292			let base_index = multilinears_index + 2;
293
294			let composition = IndexComposition::new(
295				composite_claims_n_multilinears,
296				[this_layer_input_index, exponent_bit_index, base_index],
297				ExpCompositions::DynamicBase,
298			)?;
299			let composition = FixedDimIndexCompositions::Trivariate(composition);
300			(composition, this_layer_multilinears)
301		};
302
303		let this_layer_composite_claim = CompositeSumClaim {
304			sum: self.0.current_layer_claim.eval,
305			composition,
306		};
307
308		Ok(Some(CompositeSumClaimWithMultilinears {
309			claim: this_layer_composite_claim,
310			multilinears: this_layer_multilinears,
311		}))
312	}
313
314	fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
315		self.0.eval_point()
316	}
317
318	fn finish_layer(
319		&mut self,
320		evaluation_order: EvaluationOrder,
321		layer_no: usize,
322		multilinear_evals: &[P::Scalar],
323		r: &[P::Scalar],
324	) -> Vec<LayerClaim<P::Scalar>> {
325		let n_vars = self.0.eval_point().len();
326
327		let eval_point = match evaluation_order {
328			EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
329			EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
330		};
331
332		let mut claims = Vec::with_capacity(2);
333
334		let exponent_bit_eval = multilinear_evals[1];
335
336		let exponent_bit_claim = LayerClaim {
337			eval: exponent_bit_eval,
338			eval_point: eval_point.clone(),
339		};
340
341		claims.push(exponent_bit_claim);
342
343		if self.is_last_layer(layer_no) {
344			let base_eval = multilinear_evals[0];
345
346			let base_claim = LayerClaim {
347				eval: base_eval,
348				eval_point,
349			};
350			claims.push(base_claim)
351		} else {
352			let this_layer_input_eval = multilinear_evals[0];
353
354			self.0.current_layer_claim = LayerClaim {
355				eval: this_layer_input_eval,
356				eval_point: eval_point.clone(),
357			};
358
359			let base_eval = multilinear_evals[2];
360
361			let base_claim = LayerClaim {
362				eval: base_eval,
363				eval_point,
364			};
365
366			claims.push(base_claim)
367		}
368
369		claims
370	}
371
372	fn layer_n_multilinears(&self, layer_no: usize) -> usize {
373		if self.is_last_layer(layer_no) {
374			// base, exponent_bit
375			2
376		} else {
377			// this_layer_input, exponent_bit, base
378			3
379		}
380	}
381
382	fn layer_n_claims(&self, _layer_no: usize) -> usize {
383		1
384	}
385}