binius_core/protocols/gkr_exp/
provers.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField};
4use binius_math::EvaluationOrder;
5use binius_utils::bail;
6
7use super::{
8	common::{ExpClaim, LayerClaim},
9	compositions::{ExpCompositions, IndexedExpComposition},
10	error::Error,
11	utils::first_layer_inverse,
12	witness::{BaseExpWitness, BaseWitness},
13};
14use crate::{
15	composition::{FixedDimIndexCompositions, IndexComposition},
16	protocols::sumcheck::CompositeSumClaim,
17	witness::MultilinearWitness,
18};
19
20pub trait ExpProver<'a, P: PackedField> {
21	fn exponent_bit_width(&self) -> usize;
22
23	fn is_last_layer(&self, layer_no: usize) -> bool {
24		self.exponent_bit_width() - 1 - layer_no == 0
25	}
26
27	/// return the eval_point of the internal [LayerClaim].
28	fn layer_claim_eval_point(&self) -> &[P::Scalar];
29
30	/// return [CompositeSumClaim] and multilinears that it contains,
31	/// If the prover does not participate in the sumcheck for this layer,
32	/// the function returns `None`.
33	fn layer_composite_sum_claim(
34		&self,
35		layer_no: usize,
36		composite_claims_n_multilinears: usize,
37		multilinears_index: usize,
38	) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error>;
39
40	/// return a tuple of the number of multilinears used by this prover for this layer.
41	fn layer_n_multilinears(&self, layer_no: usize) -> usize;
42
43	/// return a tuple of the number of sumcheck claims used by this prover for this layer.
44	fn layer_n_claims(&self, layer_no: usize) -> usize;
45
46	/// update the prover internal [LayerClaim] and return the [LayerClaim]s of multilinears,
47	/// excluding `this_layer_input`.
48	fn finish_layer(
49		&mut self,
50		evaluation_order: EvaluationOrder,
51		layer_no: usize,
52		multilinear_evals: &[P::Scalar],
53		r: &[P::Scalar],
54	) -> Vec<LayerClaim<P::Scalar>>;
55}
56
57struct ExpCommonProver<'a, P: PackedField> {
58	witness: BaseExpWitness<'a, P>,
59	current_layer_claim: LayerClaim<P::Scalar>,
60}
61
62impl<'a, P: PackedField> ExpCommonProver<'a, P> {
63	fn new(witness: BaseExpWitness<'a, P>, claim: ExpClaim<P::Scalar>) -> Self {
64		Self {
65			witness,
66			current_layer_claim: claim.into(),
67		}
68	}
69
70	pub fn exponent_bit_width(&self) -> usize {
71		self.witness.exponent.len()
72	}
73
74	fn current_layer_single_bit_output_layers_data(
75		&self,
76		layer_no: usize,
77	) -> MultilinearWitness<'a, P> {
78		let index = self.witness.single_bit_output_layers_data.len() - layer_no - 2;
79
80		self.witness.single_bit_output_layers_data[index].clone()
81	}
82
83	pub fn eval_point(&self) -> &[P::Scalar] {
84		&self.current_layer_claim.eval_point
85	}
86
87	fn current_layer_exponent_bit(&self, index: usize) -> MultilinearWitness<'a, P> {
88		self.witness.exponent[index].clone()
89	}
90
91	pub fn is_last_layer(&self, layer_no: usize) -> bool {
92		self.exponent_bit_width() - 1 - layer_no == 0
93	}
94}
95
96pub struct StaticExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
97
98impl<'a, P: PackedField> StaticExpProver<'a, P> {
99	pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
100		if witness.uses_dynamic_base() {
101			bail!(Error::IncorrectWitnessType);
102		}
103
104		Ok(Self(ExpCommonProver::new(witness, claim.clone())))
105	}
106}
107
108impl<'a, P> ExpProver<'a, P> for StaticExpProver<'a, P>
109where
110	P::Scalar: BinaryField,
111	P: PackedField,
112{
113	fn exponent_bit_width(&self) -> usize {
114		self.0.exponent_bit_width()
115	}
116
117	fn layer_composite_sum_claim(
118		&self,
119		layer_no: usize,
120		composite_claims_n_multilinears: usize,
121		multilinears_index: usize,
122	) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
123		if self.0.is_last_layer(layer_no) {
124			return Ok(None);
125		}
126
127		let internal_layer_index = self.exponent_bit_width() - 1 - layer_no;
128
129		let this_layer_input = self.0.current_layer_single_bit_output_layers_data(layer_no);
130
131		let exponent_bit = self
132			.0
133			.current_layer_exponent_bit(internal_layer_index)
134			.clone();
135
136		let this_layer_multilinears = vec![this_layer_input, exponent_bit];
137
138		let this_layer_input_index = multilinears_index;
139		let exponent_bit_index = multilinears_index + 1;
140
141		let base = match self.0.witness.base.clone() {
142			BaseWitness::Static(base) => base,
143			_ => unreachable!("witness must contain static base"),
144		};
145
146		let base_power_static = base.pow(1 << internal_layer_index);
147
148		let composition = IndexComposition::new(
149			composite_claims_n_multilinears,
150			[this_layer_input_index, exponent_bit_index],
151			ExpCompositions::StaticBase { base_power_static },
152		)?;
153
154		let composition = FixedDimIndexCompositions::Bivariate(composition);
155
156		let this_layer_composite_claim = CompositeSumClaim {
157			sum: self.0.current_layer_claim.eval,
158			composition,
159		};
160
161		Ok(Some(CompositeSumClaimWithMultilinears {
162			claim: this_layer_composite_claim,
163			multilinears: this_layer_multilinears,
164		}))
165	}
166
167	fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
168		self.0.eval_point()
169	}
170
171	fn finish_layer(
172		&mut self,
173		evaluation_order: EvaluationOrder,
174		layer_no: usize,
175		multilinear_evals: &[P::Scalar],
176		r: &[P::Scalar],
177	) -> Vec<LayerClaim<P::Scalar>> {
178		let exponent_bit_claim = if self.is_last_layer(layer_no) {
179			// the evaluation of the last exponent bit can be uniquely calculated from the previous exponentiation layer claim.
180			// $a_0(x) = (V_0(x) - 1)/(g - 1)$
181			let LayerClaim { eval_point, eval } = self.0.current_layer_claim.clone();
182
183			let base = match self.0.witness.base.clone() {
184				BaseWitness::Static(base) => base,
185				_ => unreachable!("witness must contain static base"),
186			};
187
188			LayerClaim::<P::Scalar> {
189				eval_point,
190				eval: first_layer_inverse(eval, base),
191			}
192		} else {
193			let n_vars = self.layer_claim_eval_point().len();
194
195			let this_layer_input_eval = multilinear_evals[0];
196
197			let exponent_bit_eval = multilinear_evals[1];
198
199			let eval_point = match evaluation_order {
200				EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
201				EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
202			};
203
204			if !self.is_last_layer(layer_no) {
205				self.0.current_layer_claim = LayerClaim {
206					eval: this_layer_input_eval,
207					eval_point: eval_point.clone(),
208				};
209			}
210
211			LayerClaim {
212				eval: exponent_bit_eval,
213				eval_point,
214			}
215		};
216
217		vec![exponent_bit_claim]
218	}
219
220	fn layer_n_multilinears(&self, layer_no: usize) -> usize {
221		if self.is_last_layer(layer_no) {
222			0
223		} else {
224			// this_layer_input, exponent_bit
225			2
226		}
227	}
228
229	fn layer_n_claims(&self, layer_no: usize) -> usize {
230		if self.is_last_layer(layer_no) {
231			0
232		} else {
233			1
234		}
235	}
236}
237
238pub struct DynamicBaseExpProver<'a, P: PackedField>(ExpCommonProver<'a, P>);
239
240impl<'a, P: PackedField> DynamicBaseExpProver<'a, P> {
241	pub fn new(witness: BaseExpWitness<'a, P>, claim: &ExpClaim<P::Scalar>) -> Result<Self, Error> {
242		if !witness.uses_dynamic_base() {
243			bail!(Error::IncorrectWitnessType);
244		}
245
246		Ok(Self(ExpCommonProver::new(witness, claim.clone())))
247	}
248}
249
/// A sumcheck claim bundled with the multilinears its composition refers to, as returned by
/// `ExpProver::layer_composite_sum_claim`.
pub struct CompositeSumClaimWithMultilinears<'a, P: PackedField> {
	/// The composite sum claim: the claimed sum together with its indexed composition.
	pub claim: CompositeSumClaim<P::Scalar, IndexedExpComposition<P::Scalar>>,
	/// The multilinears supplied by the prover for this claim.
	pub multilinears: Vec<MultilinearWitness<'a, P>>,
}
254
255impl<'a, P: PackedField> ExpProver<'a, P> for DynamicBaseExpProver<'a, P> {
256	fn exponent_bit_width(&self) -> usize {
257		self.0.exponent_bit_width()
258	}
259
260	fn layer_composite_sum_claim(
261		&self,
262		layer_no: usize,
263		composite_claims_n_multilinears: usize,
264		multilinears_index: usize,
265	) -> Result<Option<CompositeSumClaimWithMultilinears<'a, P>>, Error> {
266		let base = match self.0.witness.base.clone() {
267			BaseWitness::Dynamic(base) => base,
268			_ => unreachable!("DynamicBase witness must contain base"),
269		};
270
271		let exponent_bit = self.0.current_layer_exponent_bit(layer_no);
272
273		let (composition, this_layer_multilinears) = if self.0.is_last_layer(layer_no) {
274			let this_layer_multilinears = vec![base, exponent_bit];
275
276			let base_index = multilinears_index;
277			let exponent_bit_index = multilinears_index + 1;
278
279			let composition = IndexComposition::new(
280				composite_claims_n_multilinears,
281				[base_index, exponent_bit_index],
282				ExpCompositions::DynamicBaseLastLayer,
283			)?;
284			let composition = FixedDimIndexCompositions::Bivariate(composition);
285			(composition, this_layer_multilinears)
286		} else {
287			let this_layer_input = self
288				.0
289				.current_layer_single_bit_output_layers_data(layer_no)
290				.clone();
291
292			let this_layer_multilinears = vec![this_layer_input, exponent_bit, base];
293
294			let this_layer_input_index = multilinears_index;
295			let exponent_bit_index = multilinears_index + 1;
296			let base_index = multilinears_index + 2;
297
298			let composition = IndexComposition::new(
299				composite_claims_n_multilinears,
300				[this_layer_input_index, exponent_bit_index, base_index],
301				ExpCompositions::DynamicBase,
302			)?;
303			let composition = FixedDimIndexCompositions::Trivariate(composition);
304			(composition, this_layer_multilinears)
305		};
306
307		let this_layer_composite_claim = CompositeSumClaim {
308			sum: self.0.current_layer_claim.eval,
309			composition,
310		};
311
312		Ok(Some(CompositeSumClaimWithMultilinears {
313			claim: this_layer_composite_claim,
314			multilinears: this_layer_multilinears,
315		}))
316	}
317
318	fn layer_claim_eval_point(&self) -> &[<P as PackedField>::Scalar] {
319		self.0.eval_point()
320	}
321
322	fn finish_layer(
323		&mut self,
324		evaluation_order: EvaluationOrder,
325		layer_no: usize,
326		multilinear_evals: &[P::Scalar],
327		r: &[P::Scalar],
328	) -> Vec<LayerClaim<P::Scalar>> {
329		let n_vars = self.0.eval_point().len();
330
331		let eval_point = match evaluation_order {
332			EvaluationOrder::LowToHigh => r[r.len() - n_vars..].to_vec(),
333			EvaluationOrder::HighToLow => r[..n_vars].to_vec(),
334		};
335
336		let mut claims = Vec::with_capacity(2);
337
338		let exponent_bit_eval = multilinear_evals[1];
339
340		let exponent_bit_claim = LayerClaim {
341			eval: exponent_bit_eval,
342			eval_point: eval_point.clone(),
343		};
344
345		claims.push(exponent_bit_claim);
346
347		if self.is_last_layer(layer_no) {
348			let base_eval = multilinear_evals[0];
349
350			let base_claim = LayerClaim {
351				eval: base_eval,
352				eval_point,
353			};
354			claims.push(base_claim)
355		} else {
356			let this_layer_input_eval = multilinear_evals[0];
357
358			self.0.current_layer_claim = LayerClaim {
359				eval: this_layer_input_eval,
360				eval_point: eval_point.clone(),
361			};
362
363			let base_eval = multilinear_evals[2];
364
365			let base_claim = LayerClaim {
366				eval: base_eval,
367				eval_point,
368			};
369
370			claims.push(base_claim)
371		}
372
373		claims
374	}
375
376	fn layer_n_multilinears(&self, layer_no: usize) -> usize {
377		if self.is_last_layer(layer_no) {
378			// base, exponent_bit
379			2
380		} else {
381			// this_layer_input, exponent_bit, base
382			3
383		}
384	}
385
386	fn layer_n_claims(&self, _layer_no: usize) -> usize {
387		1
388	}
389}