binius_core/protocols/gkr_exp/
witness.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{cmp::min, slice};
4
5use binius_field::{
6	ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, Field, PackedExtension,
7	PackedField,
8};
9use binius_math::{MLEEmbeddingAdapter, MultilinearExtension};
10use binius_maybe_rayon::{
11	prelude::{IndexedParallelIterator, ParallelIterator},
12	slice::ParallelSliceMut,
13};
14use binius_utils::bail;
15use bytemuck::zeroed_vec;
16
17use super::error::Error;
18use crate::{protocols::sumcheck::equal_n_vars_check, witness::MultilinearWitness};
19
/// Witness for the exponentiation circuit: the exponent bits, the base, and one multilinear per
/// intermediate layer of the circuit.
#[derive(Clone)]
pub struct BaseExpWitness<'a, P: PackedField, FBase: Field> {
	/// Multilinears representing the integer exponents bit by bit; `exponent[i]` holds bit `i`
	/// (weight `2^i`) of every exponent.
	pub exponent: Vec<MultilinearWitness<'a, P>>,
	/// One multilinear per exponent bit holding the intermediate outputs of the circuit; the last
	/// entry holds `base^exponent`.
	pub single_bit_output_layers_data: Vec<MultilinearWitness<'a, P>>,
	/// The base to be used for exponentiation.
	pub base: BaseWitness<'a, P, FBase>,
}
29
/// The base of the exponentiation held by a [`BaseExpWitness`].
#[derive(Clone)]
pub enum BaseWitness<'a, P: PackedField, FBase: Field> {
	/// A single field element used as the base for every exponentiation.
	Constant(FBase),
	/// A multilinear whose evaluations supply the base, one value per hypercube point.
	Dynamic(MultilinearWitness<'a, P>),
}
35
36fn copy_witness_into_vec<P, PE>(poly: &MultilinearWitness<PE>) -> Vec<P>
37where
38	P: PackedField,
39	PE: PackedExtension<P::Scalar, PackedSubfield = P>,
40	PE::Scalar: ExtensionField<P::Scalar>,
41{
42	let mut input_layer: Vec<P> = zeroed_vec(1 << poly.n_vars().saturating_sub(P::LOG_WIDTH));
43
44	let log_degree = PE::Scalar::LOG_DEGREE;
45	if poly.n_vars() >= P::LOG_WIDTH {
46		let log_chunk_size = min(poly.n_vars() - P::LOG_WIDTH, 12);
47		input_layer
48			.par_chunks_mut(1 << log_chunk_size)
49			.enumerate()
50			.for_each(|(i, chunk)| {
51				poly.subcube_evals(
52					log_chunk_size + P::LOG_WIDTH,
53					i,
54					log_degree,
55					PE::cast_exts_mut(chunk),
56				)
57				.expect("")
58			});
59	} else {
60		poly.subcube_evals(
61			poly.n_vars(),
62			0,
63			log_degree,
64			PE::cast_exts_mut(slice::from_mut(&mut input_layer[0])),
65		)
66		.expect("index is between 0 and 2^{n_vars - log_chunk_size}; log_embedding degree is 0");
67	}
68
69	input_layer
70}
71
72fn evaluate_single_bit_output_packed<PBits, PBase>(
73	exponent_bit: &[PBits],
74	base: BaseEvals<PBase>,
75	previous_single_bit_output: &[PBase],
76) -> Vec<PBase>
77where
78	PBits: PackedField<Scalar = BinaryField1b>,
79	PBase: PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
80	PBase::Scalar: BinaryField,
81{
82	debug_assert_eq!(
83		PBits::WIDTH * exponent_bit.len(),
84		PBase::WIDTH * previous_single_bit_output.len()
85	);
86
87	let mut result = previous_single_bit_output.to_vec();
88
89	let _ = ext_base_op_par(&mut result, exponent_bit, |i, prev_out, exp_bit_broadcasted| {
90		let (base, prev_out) = match &base {
91			BaseEvals::Constant(g) => (*g, prev_out),
92			BaseEvals::Dynamic(g) => (g[i], prev_out.square()),
93		};
94
95		prev_out
96			* (PBase::cast_ext(PBase::cast_base(base - PBase::one()) * exp_bit_broadcasted)
97				+ PBase::one())
98	});
99
100	result
101}
102
/// Packed base values consumed by the layer-evaluation helpers in this file.
enum BaseEvals<'a, PBase: PackedField> {
	/// One packed value applied to every packed element of a layer (the constructors pass a
	/// broadcast constant here).
	Constant(PBase),
	/// One packed value per packed element of a layer, indexed in lockstep with the output.
	Dynamic(&'a [PBase]),
}
107
108fn evaluate_first_layer_output_packed<PBits, PBase>(
109	exponent_bit: &[PBits],
110	base: BaseEvals<PBase>,
111) -> Vec<PBase>
112where
113	PBits: PackedField<Scalar = BinaryField1b>,
114	PBase: PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
115	PBase::Scalar: BinaryField,
116{
117	let mut result = vec![PBase::zero(); exponent_bit.len() * PBase::Scalar::DEGREE];
118
119	let _ = ext_base_op_par(&mut result, exponent_bit, |i, _, exp_bit_broadcasted| {
120		let base = match &base {
121			BaseEvals::Constant(g) => *g,
122			BaseEvals::Dynamic(g) => g[i],
123		};
124
125		PBase::cast_ext(PBase::cast_base(base - PBase::one()) * exp_bit_broadcasted) + PBase::one()
126	});
127
128	result
129}
130
131impl<'a, P, FBase> BaseExpWitness<'a, P, FBase>
132where
133	P: PackedField,
134	FBase: BinaryField,
135{
136	/// Constructs a witness where the base is the constant [BinaryField].
137	pub fn new_with_constant_base<PBits, PBase>(
138		exponent: Vec<MultilinearWitness<'a, P>>,
139		base: PBase::Scalar,
140	) -> Result<Self, Error>
141	where
142		PBits: PackedField<Scalar = BinaryField1b>,
143		PBase: PackedField<Scalar = FBase> + PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
144		P: PackedExtension<PBits::Scalar, PackedSubfield = PBits>
145			+ PackedExtension<PBase::Scalar, PackedSubfield = PBase>,
146		P::Scalar: ExtensionField<PBase::Scalar>,
147	{
148		let exponent_bit_width = exponent.len();
149
150		if exponent_bit_width == 0 {
151			bail!(Error::EmptyExp)
152		}
153
154		if exponent.len() > PBase::Scalar::N_BITS {
155			bail!(Error::SmallBaseField)
156		}
157
158		equal_n_vars_check(&exponent)?;
159
160		let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
161
162		let mut packed_base_power_constant = PBase::broadcast(base);
163
164		single_bit_output_layers_data[0] = evaluate_first_layer_output_packed::<PBits, PBase>(
165			&copy_witness_into_vec(&exponent[0]),
166			BaseEvals::Constant(packed_base_power_constant),
167		);
168
169		for layer_idx_from_left in 1..exponent_bit_width {
170			packed_base_power_constant = packed_base_power_constant.square();
171
172			single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
173				&copy_witness_into_vec(&exponent[layer_idx_from_left]),
174				BaseEvals::Constant(packed_base_power_constant),
175				&single_bit_output_layers_data[layer_idx_from_left - 1],
176			);
177		}
178
179		let single_bit_output_layers_data = single_bit_output_layers_data
180			.into_iter()
181			.map(|single_bit_output_layers_data| {
182				MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
183					.map(MLEEmbeddingAdapter::<PBase, P>::from)
184					.map(|mle| mle.upcast_arc_dyn())
185			})
186			.collect::<Result<Vec<_>, binius_math::Error>>()?;
187
188		Ok(Self {
189			exponent,
190			single_bit_output_layers_data,
191			base: BaseWitness::Constant(base),
192		})
193	}
194
195	/// Constructs a witness with a specified multilinear base.
196	pub fn new_with_dynamic_base<PBits, PBase>(
197		exponent: Vec<MultilinearWitness<'a, P>>,
198		base: MultilinearWitness<'a, P>,
199	) -> Result<Self, Error>
200	where
201		PBits: PackedField<Scalar = BinaryField1b>,
202		PBase: PackedField<Scalar = FBase> + PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
203		P: PackedExtension<PBits::Scalar, PackedSubfield = PBits>
204			+ PackedExtension<PBase::Scalar, PackedSubfield = PBase>,
205		P::Scalar: ExtensionField<PBase::Scalar>,
206	{
207		let exponent_bit_width = exponent.len();
208
209		if exponent_bit_width == 0 {
210			bail!(Error::EmptyExp)
211		}
212
213		if exponent.len() > PBase::Scalar::N_BITS {
214			bail!(Error::SmallBaseField)
215		}
216
217		equal_n_vars_check(&exponent)?;
218
219		let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
220
221		let base_evals = copy_witness_into_vec::<PBase, P>(&base);
222
223		single_bit_output_layers_data[0] = evaluate_first_layer_output_packed::<PBits, PBase>(
224			&copy_witness_into_vec(&exponent[exponent_bit_width - 1]),
225			BaseEvals::Dynamic(&base_evals),
226		);
227
228		for layer_idx_from_left in 1..exponent_bit_width {
229			single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
230				&copy_witness_into_vec(&exponent[exponent_bit_width - layer_idx_from_left - 1]),
231				BaseEvals::Dynamic(&base_evals),
232				&single_bit_output_layers_data[layer_idx_from_left - 1],
233			);
234		}
235
236		let single_bit_output_layers_data = single_bit_output_layers_data
237			.into_iter()
238			.map(|single_bit_output_layers_data| {
239				MultilinearExtension::from_values(single_bit_output_layers_data)
240					.map(MLEEmbeddingAdapter::<PBase, P>::from)
241					.map(|mle| mle.upcast_arc_dyn())
242			})
243			.collect::<Result<Vec<_>, binius_math::Error>>()?;
244
245		Ok(Self {
246			exponent,
247			single_bit_output_layers_data,
248			base: BaseWitness::Dynamic(base),
249		})
250	}
251
252	pub const fn uses_dynamic_base(&self) -> bool {
253		match self.base {
254			BaseWitness::Constant(_) => false,
255			BaseWitness::Dynamic(_) => true,
256		}
257	}
258
259	pub fn n_vars(&self) -> usize {
260		self.exponent[0].n_vars()
261	}
262
263	/// Returns the multilinear that corresponds to the exponentiation of the base to an integers.
264	pub fn exponentiation_result_witness(&self) -> MultilinearWitness<'a, P> {
265		self.single_bit_output_layers_data
266			.last()
267			.expect("single_bit_output_layers_data not empty")
268			.clone()
269	}
270}