binius_core/protocols/gkr_exp/
witness.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField};
4use binius_math::{MLEDirectAdapter, MultilinearExtension};
5use binius_maybe_rayon::{
6	iter::{IndexedParallelIterator, IntoParallelIterator},
7	prelude::ParallelIterator,
8};
9use binius_utils::bail;
10use tracing::instrument;
11
12use super::error::Error;
13use crate::{protocols::sumcheck::equal_n_vars_check, witness::MultilinearWitness};
14
15#[derive(Clone)]
16pub struct BaseExpWitness<'a, P: PackedField> {
17	/// Multilinears that represent an integers by its bits.
18	pub exponent: Vec<MultilinearWitness<'a, P>>,
19	/// Circuit layer-multilinears
20	pub single_bit_output_layers_data: Vec<MultilinearWitness<'a, P>>,
21	/// The base to be used for exponentiation.
22	pub base: BaseWitness<'a, P>,
23}
24
25#[derive(Clone)]
26pub enum BaseWitness<'a, P: PackedField> {
27	Static(P::Scalar),
28	Dynamic(MultilinearWitness<'a, P>),
29}
30
31#[instrument(skip_all, name = "gkr_exp::evaluate_single_bit_output_packed")]
32fn evaluate_single_bit_output_packed<P>(
33	exponent_bit_witness: MultilinearWitness<P>,
34	base: Base<P>,
35	previous_single_bit_output: &[P],
36) -> Vec<P>
37where
38	P: PackedField,
39	P::Scalar: BinaryField,
40{
41	let one = P::one();
42
43	previous_single_bit_output
44		.into_par_iter()
45		.enumerate()
46		.map(|(i, &prev_out)| {
47			let (base, prev_out) = match &base {
48				Base::Static(g) => (*g, prev_out),
49				Base::Dynamic(poly) => (
50					poly.packed_evals()
51						.expect("base and exponent_bit_witness have the same width")[i],
52					prev_out.square(),
53				),
54			};
55
56			let exponent_bit =
57				P::from_fn(|j| {
58					exponent_bit_witness
59					.evaluate_on_hypercube(P::WIDTH * i + j)
60					.expect("previous_single_bit_output and exponent_bit_witness have the same width")
61				});
62			prev_out * (one + exponent_bit * (one + base))
63		})
64		.collect::<Vec<_>>()
65}
66
67enum Base<'a, P: PackedField> {
68	Static(P),
69	Dynamic(MultilinearWitness<'a, P>),
70}
71
72fn evaluate_first_layer_output_packed<P>(
73	exponent_bit_witness: MultilinearWitness<P>,
74	base: Base<P>,
75) -> Vec<P>
76where
77	P: PackedField,
78	P::Scalar: BinaryField,
79{
80	let one = P::one();
81
82	(0..1 << exponent_bit_witness.n_vars().saturating_sub(P::LOG_WIDTH))
83		.into_par_iter()
84		.map(|i| {
85			let base = match &base {
86				Base::Static(g) => *g,
87				Base::Dynamic(poly) => poly
88					.packed_evals()
89					.expect("base and exponent_bit_witness have the same width")[i],
90			};
91
92			let exponent_bit = P::from_fn(|j| {
93				exponent_bit_witness
94					.evaluate_on_hypercube(P::WIDTH * i + j)
95					.expect("eval on the hypercube exists")
96			});
97
98			one + exponent_bit * (one + base)
99		})
100		.collect::<Vec<_>>()
101}
102
103impl<'a, P> BaseExpWitness<'a, P>
104where
105	P: PackedField,
106{
107	/// Constructs a witness where the base is the static [BinaryField].
108	#[instrument(skip_all, name = "gkr_exp::new_with_static_base")]
109	pub fn new_with_static_base(
110		exponent: Vec<MultilinearWitness<'a, P>>,
111		base: P::Scalar,
112	) -> Result<Self, Error>
113	where
114		P: PackedField,
115		P::Scalar: BinaryField,
116	{
117		let exponent_bit_width = exponent.len();
118
119		if exponent_bit_width == 0 {
120			bail!(Error::EmptyExp)
121		}
122
123		if exponent.len() > P::Scalar::N_BITS {
124			bail!(Error::SmallBaseField)
125		}
126
127		equal_n_vars_check(&exponent)?;
128
129		let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
130
131		let mut packed_base_power_static = P::broadcast(base);
132
133		single_bit_output_layers_data[0] = evaluate_first_layer_output_packed(
134			exponent[0].clone(),
135			Base::Static(packed_base_power_static),
136		);
137
138		for layer_idx_from_left in 1..exponent_bit_width {
139			packed_base_power_static = packed_base_power_static.square();
140
141			single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
142				exponent[layer_idx_from_left].clone(),
143				Base::Static(packed_base_power_static),
144				&single_bit_output_layers_data[layer_idx_from_left - 1],
145			);
146		}
147
148		let single_bit_output_layers_data = single_bit_output_layers_data
149			.into_iter()
150			.map(|single_bit_output_layers_data| {
151				MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
152					.map(MLEDirectAdapter::<P>::from)
153					.map(|mle| mle.upcast_arc_dyn())
154			})
155			.collect::<Result<Vec<_>, binius_math::Error>>()?;
156
157		Ok(Self {
158			exponent,
159			single_bit_output_layers_data,
160			base: BaseWitness::Static(base),
161		})
162	}
163
164	/// Constructs a witness with a specified multilinear base.
165	#[instrument(skip_all, name = "gkr_exp::new_with_dynamic_base")]
166	pub fn new_with_dynamic_base(
167		exponent: Vec<MultilinearWitness<'a, P>>,
168		base: MultilinearWitness<'a, P>,
169	) -> Result<Self, Error>
170	where
171		P: PackedField,
172		P::Scalar: BinaryField,
173	{
174		let exponent_bit_width = exponent.len();
175
176		if exponent_bit_width == 0 {
177			bail!(Error::EmptyExp)
178		}
179
180		if exponent.len() > P::Scalar::N_BITS {
181			bail!(Error::SmallBaseField)
182		}
183
184		equal_n_vars_check(&exponent)?;
185
186		if exponent[0].n_vars() != base.n_vars() {
187			bail!(Error::NumberOfVariablesMismatch)
188		}
189
190		let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
191
192		single_bit_output_layers_data[0] = evaluate_first_layer_output_packed(
193			exponent[exponent_bit_width - 1].clone(),
194			Base::Dynamic(base.clone()),
195		);
196
197		for layer_idx_from_left in 1..exponent_bit_width {
198			single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
199				exponent[exponent_bit_width - layer_idx_from_left - 1].clone(),
200				Base::Dynamic(base.clone()),
201				&single_bit_output_layers_data[layer_idx_from_left - 1],
202			);
203		}
204
205		let single_bit_output_layers_data = single_bit_output_layers_data
206			.into_iter()
207			.map(|single_bit_output_layers_data| {
208				MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
209					.map(MLEDirectAdapter::<P>::from)
210					.map(|mle| mle.upcast_arc_dyn())
211			})
212			.collect::<Result<Vec<_>, binius_math::Error>>()?;
213
214		Ok(Self {
215			exponent,
216			single_bit_output_layers_data,
217			base: BaseWitness::Dynamic(base),
218		})
219	}
220
221	pub const fn uses_dynamic_base(&self) -> bool {
222		match self.base {
223			BaseWitness::Static(_) => false,
224			BaseWitness::Dynamic(_) => true,
225		}
226	}
227
228	pub fn n_vars(&self) -> usize {
229		self.exponent[0].n_vars()
230	}
231
232	/// Returns the multilinear that corresponds to the exponentiation of the base to an integers.
233	pub fn exponentiation_result_witness(&self) -> MultilinearWitness<'a, P> {
234		self.single_bit_output_layers_data
235			.last()
236			.expect("single_bit_output_layers_data not empty")
237			.clone()
238	}
239}