binius_core/protocols/gkr_exp/
witness.rs1use binius_field::{BinaryField, PackedField};
4use binius_math::{MLEDirectAdapter, MultilinearExtension};
5use binius_maybe_rayon::{
6 iter::{IndexedParallelIterator, IntoParallelIterator},
7 prelude::ParallelIterator,
8};
9use binius_utils::bail;
10use tracing::instrument;
11
12use super::error::Error;
13use crate::{protocols::sumcheck::equal_n_vars_check, witness::MultilinearWitness};
14
15#[derive(Clone)]
16pub struct BaseExpWitness<'a, P: PackedField> {
17 pub exponent: Vec<MultilinearWitness<'a, P>>,
19 pub single_bit_output_layers_data: Vec<MultilinearWitness<'a, P>>,
21 pub base: BaseWitness<'a, P>,
23}
24
25#[derive(Clone)]
26pub enum BaseWitness<'a, P: PackedField> {
27 Static(P::Scalar),
28 Dynamic(MultilinearWitness<'a, P>),
29}
30
31#[instrument(skip_all, name = "gkr_exp::evaluate_single_bit_output_packed")]
32fn evaluate_single_bit_output_packed<P>(
33 exponent_bit_witness: MultilinearWitness<P>,
34 base: Base<P>,
35 previous_single_bit_output: &[P],
36) -> Vec<P>
37where
38 P: PackedField,
39 P::Scalar: BinaryField,
40{
41 let one = P::one();
42
43 previous_single_bit_output
44 .into_par_iter()
45 .enumerate()
46 .map(|(i, &prev_out)| {
47 let (base, prev_out) = match &base {
48 Base::Static(g) => (*g, prev_out),
49 Base::Dynamic(poly) => (
50 poly.packed_evals()
51 .expect("base and exponent_bit_witness have the same width")[i],
52 prev_out.square(),
53 ),
54 };
55
56 let exponent_bit =
57 P::from_fn(|j| {
58 exponent_bit_witness
59 .evaluate_on_hypercube(P::WIDTH * i + j)
60 .expect("previous_single_bit_output and exponent_bit_witness have the same width")
61 });
62 prev_out * (one + exponent_bit * (one + base))
63 })
64 .collect::<Vec<_>>()
65}
66
67enum Base<'a, P: PackedField> {
68 Static(P),
69 Dynamic(MultilinearWitness<'a, P>),
70}
71
72fn evaluate_first_layer_output_packed<P>(
73 exponent_bit_witness: MultilinearWitness<P>,
74 base: Base<P>,
75) -> Vec<P>
76where
77 P: PackedField,
78 P::Scalar: BinaryField,
79{
80 let one = P::one();
81
82 (0..1 << exponent_bit_witness.n_vars().saturating_sub(P::LOG_WIDTH))
83 .into_par_iter()
84 .map(|i| {
85 let base = match &base {
86 Base::Static(g) => *g,
87 Base::Dynamic(poly) => poly
88 .packed_evals()
89 .expect("base and exponent_bit_witness have the same width")[i],
90 };
91
92 let exponent_bit = P::from_fn(|j| {
93 exponent_bit_witness
94 .evaluate_on_hypercube(P::WIDTH * i + j)
95 .expect("eval on the hypercube exists")
96 });
97
98 one + exponent_bit * (one + base)
99 })
100 .collect::<Vec<_>>()
101}
102
103impl<'a, P> BaseExpWitness<'a, P>
104where
105 P: PackedField,
106{
107 #[instrument(skip_all, name = "gkr_exp::new_with_static_base")]
109 pub fn new_with_static_base(
110 exponent: Vec<MultilinearWitness<'a, P>>,
111 base: P::Scalar,
112 ) -> Result<Self, Error>
113 where
114 P: PackedField,
115 P::Scalar: BinaryField,
116 {
117 let exponent_bit_width = exponent.len();
118
119 if exponent_bit_width == 0 {
120 bail!(Error::EmptyExp)
121 }
122
123 if exponent.len() > P::Scalar::N_BITS {
124 bail!(Error::SmallBaseField)
125 }
126
127 equal_n_vars_check(&exponent)?;
128
129 let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
130
131 let mut packed_base_power_static = P::broadcast(base);
132
133 single_bit_output_layers_data[0] = evaluate_first_layer_output_packed(
134 exponent[0].clone(),
135 Base::Static(packed_base_power_static),
136 );
137
138 for layer_idx_from_left in 1..exponent_bit_width {
139 packed_base_power_static = packed_base_power_static.square();
140
141 single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
142 exponent[layer_idx_from_left].clone(),
143 Base::Static(packed_base_power_static),
144 &single_bit_output_layers_data[layer_idx_from_left - 1],
145 );
146 }
147
148 let single_bit_output_layers_data = single_bit_output_layers_data
149 .into_iter()
150 .map(|single_bit_output_layers_data| {
151 MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
152 .map(MLEDirectAdapter::<P>::from)
153 .map(|mle| mle.upcast_arc_dyn())
154 })
155 .collect::<Result<Vec<_>, binius_math::Error>>()?;
156
157 Ok(Self {
158 exponent,
159 single_bit_output_layers_data,
160 base: BaseWitness::Static(base),
161 })
162 }
163
164 #[instrument(skip_all, name = "gkr_exp::new_with_dynamic_base")]
166 pub fn new_with_dynamic_base(
167 exponent: Vec<MultilinearWitness<'a, P>>,
168 base: MultilinearWitness<'a, P>,
169 ) -> Result<Self, Error>
170 where
171 P: PackedField,
172 P::Scalar: BinaryField,
173 {
174 let exponent_bit_width = exponent.len();
175
176 if exponent_bit_width == 0 {
177 bail!(Error::EmptyExp)
178 }
179
180 if exponent.len() > P::Scalar::N_BITS {
181 bail!(Error::SmallBaseField)
182 }
183
184 equal_n_vars_check(&exponent)?;
185
186 if exponent[0].n_vars() != base.n_vars() {
187 bail!(Error::NumberOfVariablesMismatch)
188 }
189
190 let mut single_bit_output_layers_data = vec![Vec::new(); exponent_bit_width];
191
192 single_bit_output_layers_data[0] = evaluate_first_layer_output_packed(
193 exponent[exponent_bit_width - 1].clone(),
194 Base::Dynamic(base.clone()),
195 );
196
197 for layer_idx_from_left in 1..exponent_bit_width {
198 single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed(
199 exponent[exponent_bit_width - layer_idx_from_left - 1].clone(),
200 Base::Dynamic(base.clone()),
201 &single_bit_output_layers_data[layer_idx_from_left - 1],
202 );
203 }
204
205 let single_bit_output_layers_data = single_bit_output_layers_data
206 .into_iter()
207 .map(|single_bit_output_layers_data| {
208 MultilinearExtension::new(exponent[0].n_vars(), single_bit_output_layers_data)
209 .map(MLEDirectAdapter::<P>::from)
210 .map(|mle| mle.upcast_arc_dyn())
211 })
212 .collect::<Result<Vec<_>, binius_math::Error>>()?;
213
214 Ok(Self {
215 exponent,
216 single_bit_output_layers_data,
217 base: BaseWitness::Dynamic(base),
218 })
219 }
220
221 pub const fn uses_dynamic_base(&self) -> bool {
222 match self.base {
223 BaseWitness::Static(_) => false,
224 BaseWitness::Dynamic(_) => true,
225 }
226 }
227
228 pub fn n_vars(&self) -> usize {
229 self.exponent[0].n_vars()
230 }
231
232 pub fn exponentiation_result_witness(&self) -> MultilinearWitness<'a, P> {
234 self.single_bit_output_layers_data
235 .last()
236 .expect("single_bit_output_layers_data not empty")
237 .clone()
238 }
239}