binius_core/protocols/gkr_exp/
witness.rs1use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, RepackedExtension};
4use binius_math::{MLEDirectAdapter, MultilinearExtension};
5use binius_maybe_rayon::{
6 iter::{IndexedParallelIterator, IntoParallelIterator},
7 prelude::ParallelIterator,
8};
9use binius_utils::bail;
10use tracing::instrument;
11
12use super::error::Error;
13use crate::{protocols::sumcheck::equal_n_vars_check, witness::MultilinearWitness};
14
15#[derive(Clone)]
16pub struct BaseExpWitness<'a, P: PackedField> {
17 pub exponent: Vec<MultilinearWitness<'a, P>>,
19 pub single_bit_output_layers_data: Vec<MultilinearWitness<'a, P>>,
21 pub base: BaseWitness<'a, P>,
23}
24
25#[derive(Clone)]
26pub enum BaseWitness<'a, P: PackedField> {
27 Static(P::Scalar),
28 Dynamic(MultilinearWitness<'a, P>),
29}
30
31#[instrument(skip_all, name = "gkr_exp::evaluate_single_bit_output_packed")]
32fn evaluate_single_bit_output_packed<P, PExpBase>(
33 exponent_bit_witness: MultilinearWitness<P>,
34 base: &Base<PExpBase>,
35 previous_single_bit_output: &[PExpBase],
36) -> Vec<PExpBase>
37where
38 P: PackedField,
39 PExpBase: PackedField,
40{
41 let one = PExpBase::one();
42
43 previous_single_bit_output
44 .into_par_iter()
45 .enumerate()
46 .map(|(i, &prev_out)| {
47 let (base, prev_out) = match base {
48 Base::Static(g) => (*g, prev_out),
49 Base::Dynamic(evals) => (evals[i], prev_out.square()),
50 };
51
52 let exponent_bit = PExpBase::from_fn(|j| {
53 let ext_bit = exponent_bit_witness
54 .evaluate_on_hypercube(PExpBase::WIDTH * i + j)
55 .unwrap_or_else(|_| P::Scalar::zero());
56
57 if ext_bit == P::Scalar::one() {
58 PExpBase::Scalar::one()
59 } else {
60 PExpBase::Scalar::zero()
61 }
62 });
63
64 prev_out * (one + exponent_bit * (one + base))
65 })
66 .collect::<Vec<_>>()
67}
68
69enum Base<'a, PExpBase: PackedField> {
70 Static(PExpBase),
71 Dynamic(&'a [PExpBase]),
72}
73
74fn evaluate_first_layer_output_packed<P, PExpBase>(
75 exponent_bit_witness: MultilinearWitness<P>,
76 base: &Base<PExpBase>,
77) -> Vec<PExpBase>
78where
79 P: PackedField,
80 PExpBase: PackedField,
81{
82 let one = PExpBase::one();
83
84 (0..1
85 << exponent_bit_witness
86 .n_vars()
87 .saturating_sub(PExpBase::LOG_WIDTH))
88 .into_par_iter()
89 .map(|i| {
90 let base = match base {
91 Base::Static(g) => *g,
92 Base::Dynamic(evals) => evals[i],
93 };
94
95 let exponent_bit = PExpBase::from_fn(|j| {
96 let ext_bit = exponent_bit_witness
97 .evaluate_on_hypercube(PExpBase::WIDTH * i + j)
98 .unwrap_or_else(|_| P::Scalar::zero());
99
100 if ext_bit == P::Scalar::one() {
101 PExpBase::Scalar::one()
102 } else {
103 PExpBase::Scalar::zero()
104 }
105 });
106
107 one + exponent_bit * (one + base)
108 })
109 .collect::<Vec<_>>()
110}
111
112impl<'a, P> BaseExpWitness<'a, P>
113where
114 P: PackedField,
115{
116 #[instrument(skip_all, name = "gkr_exp::new_with_static_base")]
118 pub fn new_with_static_base<PExpBase>(
119 exponent: Vec<MultilinearWitness<'a, P>>,
120 base: PExpBase::Scalar,
121 ) -> Result<Self, Error>
122 where
123 P: RepackedExtension<PExpBase>,
124 PExpBase::Scalar: BinaryField,
125 PExpBase: PackedField,
126 {
127 let exponent_bit_width = exponent.len();
128
129 if exponent_bit_width == 0 {
130 bail!(Error::EmptyExp)
131 }
132
133 if exponent.len() > PExpBase::Scalar::N_BITS {
134 bail!(Error::SmallBaseField)
135 }
136
137 equal_n_vars_check(&exponent)?;
138
139 let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
140
141 let mut packed_base_power_static = PExpBase::broadcast(base);
142
143 single_bit_output_layers_data[0] = evaluate_first_layer_output_packed(
144 exponent[0].clone(),
145 &Base::Static(packed_base_power_static),
146 );
147
148 for layer_idx_from_left in 1..exponent_bit_width {
149 packed_base_power_static = packed_base_power_static.square();
150
151 single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
152 exponent[layer_idx_from_left].clone(),
153 &Base::Static(packed_base_power_static),
154 &single_bit_output_layers_data[layer_idx_from_left - 1],
155 );
156 }
157
158 let single_bit_output_layers_data = single_bit_output_layers_data
159 .into_iter()
160 .map(|single_bit_output_layers_data| {
161 MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
162 .map(|me| me.specialize_arc_dyn())
163 })
164 .collect::<Result<Vec<_>, binius_math::Error>>()?;
165
166 Ok(Self {
167 exponent,
168 single_bit_output_layers_data,
169 base: BaseWitness::Static(base.into()),
170 })
171 }
172
173 #[instrument(skip_all, name = "gkr_exp::new_with_dynamic_base")]
178 pub fn new_with_dynamic_base<PExpBase>(
179 exponent: Vec<MultilinearWitness<'a, P>>,
180 base: MultilinearWitness<'a, P>,
181 ) -> Result<Self, Error>
182 where
183 P: RepackedExtension<PExpBase>,
184 P::Scalar: BinaryField + ExtensionField<PExpBase::Scalar>,
185 PExpBase::Scalar: BinaryField,
186 PExpBase: PackedField,
187 {
188 let exponent_bit_width = exponent.len();
189
190 if exponent_bit_width == 0 {
191 bail!(Error::EmptyExp)
192 }
193
194 if exponent.len() > PExpBase::Scalar::N_BITS {
195 bail!(Error::SmallBaseField)
196 }
197
198 equal_n_vars_check(&exponent)?;
199
200 if exponent[0].n_vars() != base.n_vars() {
201 bail!(Error::NumberOfVariablesMismatch)
202 }
203
204 let single_bit_output_layers_data =
205 if <P::Scalar as ExtensionField<PExpBase::Scalar>>::DEGREE == 0 {
207 let single_bit_output_layers_data =
208 Self::build_dynamic_single_bit_output_layers_data::<P>(
209 &exponent,
210 base.packed_evals().expect("packed_evals required"),
211 );
212
213 single_bit_output_layers_data
214 .into_iter()
215 .map(|single_bit_output_layers_data| {
216 MultilinearExtension::new(
217 exponent[0].n_vars(),
218 single_bit_output_layers_data,
219 )
220 .map(MLEDirectAdapter::from)
221 .map(|mle| mle.upcast_arc_dyn())
222 })
223 .collect::<Result<Vec<_>, binius_math::Error>>()?
224 } else {
225 let repacked_evals = <P as PackedExtension<PExpBase::Scalar>>::cast_bases(
226 base.packed_evals().expect("packed_evals required"),
227 );
228
229 if repacked_evals.len() != 1 << base.n_vars().saturating_sub(PExpBase::LOG_WIDTH) {
230 bail!(Error::NumberOfVariablesMismatch)
231 }
232
233 let single_bit_output_layers_data =
234 Self::build_dynamic_single_bit_output_layers_data::<PExpBase>(
235 &exponent,
236 repacked_evals,
237 );
238
239 single_bit_output_layers_data
240 .into_iter()
241 .map(|single_bit_output_layers_data| {
242 MultilinearExtension::new(
243 exponent[0].n_vars(),
244 single_bit_output_layers_data,
245 )
246 .map(|me| me.specialize_arc_dyn())
247 })
248 .collect::<Result<Vec<_>, binius_math::Error>>()?
249 };
250
251 Ok(Self {
252 exponent,
253 single_bit_output_layers_data,
254 base: BaseWitness::Dynamic(base),
255 })
256 }
257
258 fn build_dynamic_single_bit_output_layers_data<PExpBase>(
259 exponent: &[MultilinearWitness<'a, P>],
260 evals: &[PExpBase],
261 ) -> Vec<Vec<PExpBase>>
262 where
263 P: PackedField,
264 PExpBase::Scalar: BinaryField,
265 PExpBase: PackedField,
266 {
267 let exponent_bit_width = exponent.len();
268
269 let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
270
271 let base = Base::Dynamic(evals);
272
273 single_bit_output_layers_data[0] =
274 evaluate_first_layer_output_packed(exponent[exponent_bit_width - 1].clone(), &base);
275
276 for layer_idx_from_left in 1..exponent_bit_width {
277 single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
278 exponent[exponent_bit_width - layer_idx_from_left - 1].clone(),
279 &base,
280 &single_bit_output_layers_data[layer_idx_from_left - 1],
281 );
282 }
283 single_bit_output_layers_data
284 }
285
286 pub const fn uses_dynamic_base(&self) -> bool {
287 match self.base {
288 BaseWitness::Static(_) => false,
289 BaseWitness::Dynamic(_) => true,
290 }
291 }
292
293 pub fn n_vars(&self) -> usize {
294 self.exponent[0].n_vars()
295 }
296
297 pub fn exponentiation_result_witness(&self) -> MultilinearWitness<'a, P> {
299 self.single_bit_output_layers_data
300 .last()
301 .expect("single_bit_output_layers_data not empty")
302 .clone()
303 }
304}