binius_core/protocols/gkr_exp/
witness.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, RepackedExtension};
4use binius_math::{MLEDirectAdapter, MultilinearExtension};
5use binius_maybe_rayon::{
6	iter::{IndexedParallelIterator, IntoParallelIterator},
7	prelude::ParallelIterator,
8};
9use binius_utils::bail;
10use tracing::instrument;
11
12use super::error::Error;
13use crate::{protocols::sumcheck::equal_n_vars_check, witness::MultilinearWitness};
14
15#[derive(Clone)]
16pub struct BaseExpWitness<'a, P: PackedField> {
17	/// Multilinears that represent an integers by its bits.
18	pub exponent: Vec<MultilinearWitness<'a, P>>,
19	/// Circuit layer-multilinears
20	pub single_bit_output_layers_data: Vec<MultilinearWitness<'a, P>>,
21	/// The base to be used for exponentiation.
22	pub base: BaseWitness<'a, P>,
23}
24
25#[derive(Clone)]
26pub enum BaseWitness<'a, P: PackedField> {
27	Static(P::Scalar),
28	Dynamic(MultilinearWitness<'a, P>),
29}
30
31#[instrument(skip_all, name = "gkr_exp::evaluate_single_bit_output_packed")]
32fn evaluate_single_bit_output_packed<P, PExpBase>(
33	exponent_bit_witness: MultilinearWitness<P>,
34	base: &Base<PExpBase>,
35	previous_single_bit_output: &[PExpBase],
36) -> Vec<PExpBase>
37where
38	P: PackedField,
39	PExpBase: PackedField,
40{
41	let one = PExpBase::one();
42
43	previous_single_bit_output
44		.into_par_iter()
45		.enumerate()
46		.map(|(i, &prev_out)| {
47			let (base, prev_out) = match base {
48				Base::Static(g) => (*g, prev_out),
49				Base::Dynamic(evals) => (evals[i], prev_out.square()),
50			};
51
52			let exponent_bit = PExpBase::from_fn(|j| {
53				let ext_bit = exponent_bit_witness
54					.evaluate_on_hypercube(PExpBase::WIDTH * i + j)
55					.unwrap_or_else(|_| P::Scalar::zero());
56
57				if ext_bit == P::Scalar::one() {
58					PExpBase::Scalar::one()
59				} else {
60					PExpBase::Scalar::zero()
61				}
62			});
63
64			prev_out * (one + exponent_bit * (one + base))
65		})
66		.collect::<Vec<_>>()
67}
68
69enum Base<'a, PExpBase: PackedField> {
70	Static(PExpBase),
71	Dynamic(&'a [PExpBase]),
72}
73
74fn evaluate_first_layer_output_packed<P, PExpBase>(
75	exponent_bit_witness: MultilinearWitness<P>,
76	base: &Base<PExpBase>,
77) -> Vec<PExpBase>
78where
79	P: PackedField,
80	PExpBase: PackedField,
81{
82	let one = PExpBase::one();
83
84	(0..1
85		<< exponent_bit_witness
86			.n_vars()
87			.saturating_sub(PExpBase::LOG_WIDTH))
88		.into_par_iter()
89		.map(|i| {
90			let base = match base {
91				Base::Static(g) => *g,
92				Base::Dynamic(evals) => evals[i],
93			};
94
95			let exponent_bit = PExpBase::from_fn(|j| {
96				let ext_bit = exponent_bit_witness
97					.evaluate_on_hypercube(PExpBase::WIDTH * i + j)
98					.unwrap_or_else(|_| P::Scalar::zero());
99
100				if ext_bit == P::Scalar::one() {
101					PExpBase::Scalar::one()
102				} else {
103					PExpBase::Scalar::zero()
104				}
105			});
106
107			one + exponent_bit * (one + base)
108		})
109		.collect::<Vec<_>>()
110}
111
112impl<'a, P> BaseExpWitness<'a, P>
113where
114	P: PackedField,
115{
116	/// Constructs a witness where the base is the static [BinaryField].
117	#[instrument(skip_all, name = "gkr_exp::new_with_static_base")]
118	pub fn new_with_static_base<PExpBase>(
119		exponent: Vec<MultilinearWitness<'a, P>>,
120		base: PExpBase::Scalar,
121	) -> Result<Self, Error>
122	where
123		P: RepackedExtension<PExpBase>,
124		PExpBase::Scalar: BinaryField,
125		PExpBase: PackedField,
126	{
127		let exponent_bit_width = exponent.len();
128
129		if exponent_bit_width == 0 {
130			bail!(Error::EmptyExp)
131		}
132
133		if exponent.len() > PExpBase::Scalar::N_BITS {
134			bail!(Error::SmallBaseField)
135		}
136
137		equal_n_vars_check(&exponent)?;
138
139		let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
140
141		let mut packed_base_power_static = PExpBase::broadcast(base);
142
143		single_bit_output_layers_data[0] = evaluate_first_layer_output_packed(
144			exponent[0].clone(),
145			&Base::Static(packed_base_power_static),
146		);
147
148		for layer_idx_from_left in 1..exponent_bit_width {
149			packed_base_power_static = packed_base_power_static.square();
150
151			single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
152				exponent[layer_idx_from_left].clone(),
153				&Base::Static(packed_base_power_static),
154				&single_bit_output_layers_data[layer_idx_from_left - 1],
155			);
156		}
157
158		let single_bit_output_layers_data = single_bit_output_layers_data
159			.into_iter()
160			.map(|single_bit_output_layers_data| {
161				MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
162					.map(|me| me.specialize_arc_dyn())
163			})
164			.collect::<Result<Vec<_>, binius_math::Error>>()?;
165
166		Ok(Self {
167			exponent,
168			single_bit_output_layers_data,
169			base: BaseWitness::Static(base.into()),
170		})
171	}
172
173	/// Constructs a witness with a specified multilinear base.
174	///
175	/// # Requirements
176	/// * base have `packed_evals()` is Some
177	#[instrument(skip_all, name = "gkr_exp::new_with_dynamic_base")]
178	pub fn new_with_dynamic_base<PExpBase>(
179		exponent: Vec<MultilinearWitness<'a, P>>,
180		base: MultilinearWitness<'a, P>,
181	) -> Result<Self, Error>
182	where
183		P: RepackedExtension<PExpBase>,
184		P::Scalar: BinaryField + ExtensionField<PExpBase::Scalar>,
185		PExpBase::Scalar: BinaryField,
186		PExpBase: PackedField,
187	{
188		let exponent_bit_width = exponent.len();
189
190		if exponent_bit_width == 0 {
191			bail!(Error::EmptyExp)
192		}
193
194		if exponent.len() > PExpBase::Scalar::N_BITS {
195			bail!(Error::SmallBaseField)
196		}
197
198		equal_n_vars_check(&exponent)?;
199
200		if exponent[0].n_vars() != base.n_vars() {
201			bail!(Error::NumberOfVariablesMismatch)
202		}
203
204		let single_bit_output_layers_data =
205			// If P and PExpBase are the same field, MLEDirectAdapter can be used
206			if <P::Scalar as ExtensionField<PExpBase::Scalar>>::DEGREE == 0 {
207				let single_bit_output_layers_data =
208					Self::build_dynamic_single_bit_output_layers_data::<P>(
209						&exponent,
210						base.packed_evals().expect("packed_evals required"),
211					);
212
213				single_bit_output_layers_data
214					.into_iter()
215					.map(|single_bit_output_layers_data| {
216						MultilinearExtension::new(
217							exponent[0].n_vars(),
218							single_bit_output_layers_data,
219						)
220						.map(MLEDirectAdapter::from)
221						.map(|mle| mle.upcast_arc_dyn())
222					})
223					.collect::<Result<Vec<_>, binius_math::Error>>()?
224			} else {
225				let repacked_evals = <P as PackedExtension<PExpBase::Scalar>>::cast_bases(
226					base.packed_evals().expect("packed_evals required"),
227				);
228
229				if repacked_evals.len() != 1 << base.n_vars().saturating_sub(PExpBase::LOG_WIDTH) {
230					bail!(Error::NumberOfVariablesMismatch)
231				}
232
233				let single_bit_output_layers_data =
234					Self::build_dynamic_single_bit_output_layers_data::<PExpBase>(
235						&exponent,
236						repacked_evals,
237					);
238
239				single_bit_output_layers_data
240					.into_iter()
241					.map(|single_bit_output_layers_data| {
242						MultilinearExtension::new(
243							exponent[0].n_vars(),
244							single_bit_output_layers_data,
245						)
246						.map(|me| me.specialize_arc_dyn())
247					})
248					.collect::<Result<Vec<_>, binius_math::Error>>()?
249			};
250
251		Ok(Self {
252			exponent,
253			single_bit_output_layers_data,
254			base: BaseWitness::Dynamic(base),
255		})
256	}
257
258	fn build_dynamic_single_bit_output_layers_data<PExpBase>(
259		exponent: &[MultilinearWitness<'a, P>],
260		evals: &[PExpBase],
261	) -> Vec<Vec<PExpBase>>
262	where
263		P: PackedField,
264		PExpBase::Scalar: BinaryField,
265		PExpBase: PackedField,
266	{
267		let exponent_bit_width = exponent.len();
268
269		let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
270
271		let base = Base::Dynamic(evals);
272
273		single_bit_output_layers_data[0] =
274			evaluate_first_layer_output_packed(exponent[exponent_bit_width - 1].clone(), &base);
275
276		for layer_idx_from_left in 1..exponent_bit_width {
277			single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
278				exponent[exponent_bit_width - layer_idx_from_left - 1].clone(),
279				&base,
280				&single_bit_output_layers_data[layer_idx_from_left - 1],
281			);
282		}
283		single_bit_output_layers_data
284	}
285
286	pub const fn uses_dynamic_base(&self) -> bool {
287		match self.base {
288			BaseWitness::Static(_) => false,
289			BaseWitness::Dynamic(_) => true,
290		}
291	}
292
293	pub fn n_vars(&self) -> usize {
294		self.exponent[0].n_vars()
295	}
296
297	/// Returns the multilinear that corresponds to the exponentiation of the base to an integers.
298	pub fn exponentiation_result_witness(&self) -> MultilinearWitness<'a, P> {
299		self.single_bit_output_layers_data
300			.last()
301			.expect("single_bit_output_layers_data not empty")
302			.clone()
303	}
304}