binius_core/protocols/gkr_exp/
witness.rs1use std::{cmp::min, slice};
4
5use binius_field::{
6 ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, Field, PackedExtension,
7 PackedField,
8};
9use binius_math::{MLEEmbeddingAdapter, MultilinearExtension};
10use binius_maybe_rayon::{
11 prelude::{IndexedParallelIterator, ParallelIterator},
12 slice::ParallelSliceMut,
13};
14use binius_utils::bail;
15use bytemuck::zeroed_vec;
16
17use super::error::Error;
18use crate::{protocols::sumcheck::equal_n_vars_check, witness::MultilinearWitness};
19
20#[derive(Clone)]
21pub struct BaseExpWitness<'a, P: PackedField, FBase: Field> {
22 pub exponent: Vec<MultilinearWitness<'a, P>>,
24 pub single_bit_output_layers_data: Vec<MultilinearWitness<'a, P>>,
26 pub base: BaseWitness<'a, P, FBase>,
28}
29
30#[derive(Clone)]
31pub enum BaseWitness<'a, P: PackedField, FBase: Field> {
32 Constant(FBase),
33 Dynamic(MultilinearWitness<'a, P>),
34}
35
36fn copy_witness_into_vec<P, PE>(poly: &MultilinearWitness<PE>) -> Vec<P>
37where
38 P: PackedField,
39 PE: PackedExtension<P::Scalar, PackedSubfield = P>,
40 PE::Scalar: ExtensionField<P::Scalar>,
41{
42 let mut input_layer: Vec<P> = zeroed_vec(1 << poly.n_vars().saturating_sub(P::LOG_WIDTH));
43
44 let log_degree = PE::Scalar::LOG_DEGREE;
45 if poly.n_vars() >= P::LOG_WIDTH {
46 let log_chunk_size = min(poly.n_vars() - P::LOG_WIDTH, 12);
47 input_layer
48 .par_chunks_mut(1 << log_chunk_size)
49 .enumerate()
50 .for_each(|(i, chunk)| {
51 poly.subcube_evals(
52 log_chunk_size + P::LOG_WIDTH,
53 i,
54 log_degree,
55 PE::cast_exts_mut(chunk),
56 )
57 .expect("")
58 });
59 } else {
60 poly.subcube_evals(
61 poly.n_vars(),
62 0,
63 log_degree,
64 PE::cast_exts_mut(slice::from_mut(&mut input_layer[0])),
65 )
66 .expect("index is between 0 and 2^{n_vars - log_chunk_size}; log_embedding degree is 0");
67 }
68
69 input_layer
70}
71
72fn evaluate_single_bit_output_packed<PBits, PBase>(
73 exponent_bit: &[PBits],
74 base: BaseEvals<PBase>,
75 previous_single_bit_output: &[PBase],
76) -> Vec<PBase>
77where
78 PBits: PackedField<Scalar = BinaryField1b>,
79 PBase: PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
80 PBase::Scalar: BinaryField,
81{
82 debug_assert_eq!(
83 PBits::WIDTH * exponent_bit.len(),
84 PBase::WIDTH * previous_single_bit_output.len()
85 );
86
87 let mut result = previous_single_bit_output.to_vec();
88
89 let _ = ext_base_op_par(&mut result, exponent_bit, |i, prev_out, exp_bit_broadcasted| {
90 let (base, prev_out) = match &base {
91 BaseEvals::Constant(g) => (*g, prev_out),
92 BaseEvals::Dynamic(g) => (g[i], prev_out.square()),
93 };
94
95 prev_out
96 * (PBase::cast_ext(PBase::cast_base(base - PBase::one()) * exp_bit_broadcasted)
97 + PBase::one())
98 });
99
100 result
101}
102
103enum BaseEvals<'a, PBase: PackedField> {
104 Constant(PBase),
105 Dynamic(&'a [PBase]),
106}
107
108fn evaluate_first_layer_output_packed<PBits, PBase>(
109 exponent_bit: &[PBits],
110 base: BaseEvals<PBase>,
111) -> Vec<PBase>
112where
113 PBits: PackedField<Scalar = BinaryField1b>,
114 PBase: PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
115 PBase::Scalar: BinaryField,
116{
117 let mut result = vec![PBase::zero(); exponent_bit.len() * PBase::Scalar::DEGREE];
118
119 let _ = ext_base_op_par(&mut result, exponent_bit, |i, _, exp_bit_broadcasted| {
120 let base = match &base {
121 BaseEvals::Constant(g) => *g,
122 BaseEvals::Dynamic(g) => g[i],
123 };
124
125 PBase::cast_ext(PBase::cast_base(base - PBase::one()) * exp_bit_broadcasted) + PBase::one()
126 });
127
128 result
129}
130
131impl<'a, P, FBase> BaseExpWitness<'a, P, FBase>
132where
133 P: PackedField,
134 FBase: BinaryField,
135{
136 pub fn new_with_constant_base<PBits, PBase>(
138 exponent: Vec<MultilinearWitness<'a, P>>,
139 base: PBase::Scalar,
140 ) -> Result<Self, Error>
141 where
142 PBits: PackedField<Scalar = BinaryField1b>,
143 PBase: PackedField<Scalar = FBase> + PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
144 P: PackedExtension<PBits::Scalar, PackedSubfield = PBits>
145 + PackedExtension<PBase::Scalar, PackedSubfield = PBase>,
146 P::Scalar: ExtensionField<PBase::Scalar>,
147 {
148 let exponent_bit_width = exponent.len();
149
150 if exponent_bit_width == 0 {
151 bail!(Error::EmptyExp)
152 }
153
154 if exponent.len() > PBase::Scalar::N_BITS {
155 bail!(Error::SmallBaseField)
156 }
157
158 equal_n_vars_check(&exponent)?;
159
160 let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
161
162 let mut packed_base_power_constant = PBase::broadcast(base);
163
164 single_bit_output_layers_data[0] = evaluate_first_layer_output_packed::<PBits, PBase>(
165 ©_witness_into_vec(&exponent[0]),
166 BaseEvals::Constant(packed_base_power_constant),
167 );
168
169 for layer_idx_from_left in 1..exponent_bit_width {
170 packed_base_power_constant = packed_base_power_constant.square();
171
172 single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
173 ©_witness_into_vec(&exponent[layer_idx_from_left]),
174 BaseEvals::Constant(packed_base_power_constant),
175 &single_bit_output_layers_data[layer_idx_from_left - 1],
176 );
177 }
178
179 let single_bit_output_layers_data = single_bit_output_layers_data
180 .into_iter()
181 .map(|single_bit_output_layers_data| {
182 MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
183 .map(MLEEmbeddingAdapter::<PBase, P>::from)
184 .map(|mle| mle.upcast_arc_dyn())
185 })
186 .collect::<Result<Vec<_>, binius_math::Error>>()?;
187
188 Ok(Self {
189 exponent,
190 single_bit_output_layers_data,
191 base: BaseWitness::Constant(base),
192 })
193 }
194
195 pub fn new_with_dynamic_base<PBits, PBase>(
197 exponent: Vec<MultilinearWitness<'a, P>>,
198 base: MultilinearWitness<'a, P>,
199 ) -> Result<Self, Error>
200 where
201 PBits: PackedField<Scalar = BinaryField1b>,
202 PBase: PackedField<Scalar = FBase> + PackedExtension<PBits::Scalar, PackedSubfield = PBits>,
203 P: PackedExtension<PBits::Scalar, PackedSubfield = PBits>
204 + PackedExtension<PBase::Scalar, PackedSubfield = PBase>,
205 P::Scalar: ExtensionField<PBase::Scalar>,
206 {
207 let exponent_bit_width = exponent.len();
208
209 if exponent_bit_width == 0 {
210 bail!(Error::EmptyExp)
211 }
212
213 if exponent.len() > PBase::Scalar::N_BITS {
214 bail!(Error::SmallBaseField)
215 }
216
217 equal_n_vars_check(&exponent)?;
218
219 let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
220
221 let base_evals = copy_witness_into_vec::<PBase, P>(&base);
222
223 single_bit_output_layers_data[0] = evaluate_first_layer_output_packed::<PBits, PBase>(
224 ©_witness_into_vec(&exponent[exponent_bit_width - 1]),
225 BaseEvals::Dynamic(&base_evals),
226 );
227
228 for layer_idx_from_left in 1..exponent_bit_width {
229 single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
230 ©_witness_into_vec(&exponent[exponent_bit_width - layer_idx_from_left - 1]),
231 BaseEvals::Dynamic(&base_evals),
232 &single_bit_output_layers_data[layer_idx_from_left - 1],
233 );
234 }
235
236 let single_bit_output_layers_data = single_bit_output_layers_data
237 .into_iter()
238 .map(|single_bit_output_layers_data| {
239 MultilinearExtension::from_values(single_bit_output_layers_data)
240 .map(MLEEmbeddingAdapter::<PBase, P>::from)
241 .map(|mle| mle.upcast_arc_dyn())
242 })
243 .collect::<Result<Vec<_>, binius_math::Error>>()?;
244
245 Ok(Self {
246 exponent,
247 single_bit_output_layers_data,
248 base: BaseWitness::Dynamic(base),
249 })
250 }
251
252 pub const fn uses_dynamic_base(&self) -> bool {
253 match self.base {
254 BaseWitness::Constant(_) => false,
255 BaseWitness::Dynamic(_) => true,
256 }
257 }
258
259 pub fn n_vars(&self) -> usize {
260 self.exponent[0].n_vars()
261 }
262
263 pub fn exponentiation_result_witness(&self) -> MultilinearWitness<'a, P> {
265 self.single_bit_output_layers_data
266 .last()
267 .expect("single_bit_output_layers_data not empty")
268 .clone()
269 }
270}