1use binius_field::{
4 as_packed_field::{PackScalar, PackedType},
5 linear_transformation::{PackedTransformationFactory, Transformation},
6 packed::get_packed_slice,
7 underlier::WithUnderlier,
8 ExtensionField, Field, PackedExtension, PackedField, RepackedExtension, TowerField,
9};
10use binius_macros::{DeserializeBytes, SerializeBytes};
11use binius_math::MultilinearExtension;
12use binius_maybe_rayon::prelude::*;
13use binius_utils::bail;
14use itertools::chain;
15use tracing::instrument;
16
17use super::{
18 common::{FExt, FFastExt},
19 error::Error,
20};
21use crate::{
22 constraint_system::channel::OracleOrConst,
23 oracle::{MultilinearOracleSet, OracleId},
24 protocols::{
25 evalcheck::EvalcheckMultilinearClaim,
26 gkr_exp::{self, BaseExpReductionOutput, BaseExpWitness, ExpClaim},
27 },
28 tower::{ProverTowerFamily, ProverTowerUnderlier},
29 witness::{MultilinearExtensionIndex, MultilinearWitness},
30};
31
/// An exponentiation constraint: the column `exp_result_id` is constrained to hold `base`
/// raised to the exponent whose bits are the columns listed in `bits_ids`.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Exp<F: Field> {
    /// Oracles holding the exponent, one oracle per bit. When witnesses are built, each one is
    /// read as a `B1` column. NOTE(review): bit order (LSB- or MSB-first) is defined by
    /// `gkr_exp` — confirm there.
    pub bits_ids: Vec<OracleId>,
    /// The base of the exponentiation. It is either a constant (with its tower level) or an
    /// oracle (whose tower level is looked up in the oracle set).
    pub base: OracleOrConst<F>,
    /// Oracle whose witness is (re)written with the exponentiation result.
    pub exp_result_id: OracleId,
}
39
40impl<F: TowerField> Exp<F> {
41 pub fn n_vars(&self, oracles: &MultilinearOracleSet<F>) -> usize {
42 oracles.n_vars(self.exp_result_id)
43 }
44}
45
46pub fn max_n_vars<F: TowerField>(exponents: &[Exp<F>], oracles: &MultilinearOracleSet<F>) -> usize {
47 exponents
48 .iter()
49 .map(|m| m.n_vars(oracles))
50 .max()
51 .unwrap_or(0)
52}
53
54type MultiplicationWitnesses<'a, U, Tower> =
55 Vec<BaseExpWitness<'a, PackedType<U, FFastExt<Tower>>>>;
56
57#[instrument(skip_all, name = "exp::make_exp_witnesses")]
60pub fn make_exp_witnesses<'a, U, Tower>(
61 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
62 oracles: &MultilinearOracleSet<FExt<Tower>>,
63 exponents: &[Exp<Tower::B128>],
64) -> Result<MultiplicationWitnesses<'a, U, Tower>, Error>
65where
66 U: ProverTowerUnderlier<Tower>,
67 Tower: ProverTowerFamily,
68 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
69 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
70{
71 let static_exponents_iter = exponents
74 .iter()
75 .filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
76 let dynamic_exponents_iter = exponents
77 .iter()
78 .filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
79
80 chain!(static_exponents_iter, dynamic_exponents_iter)
81 .map(|exp| {
82 let fast_exponent_witnesses =
83 get_fast_exponent_witnesses::<U, Tower>(witness, &exp.bits_ids)?;
84
85 let (exp_witness, tower_level) = match exp.base {
86 OracleOrConst::Const { base, tower_level } => {
87 let witness = gkr_exp::BaseExpWitness::new_with_static_base(
88 fast_exponent_witnesses,
89 base.into(),
90 )?;
91 (witness, tower_level)
92 }
93 OracleOrConst::Oracle(base_id) => {
94 let fast_base_witnesses =
95 to_fast_witness::<U, Tower>(witness.get_multilin_poly(base_id)?)?;
96
97 let witness = gkr_exp::BaseExpWitness::new_with_dynamic_base(
98 fast_exponent_witnesses,
99 fast_base_witnesses,
100 )?;
101
102 let tower_level = oracles.tower_level(base_id);
103
104 (witness, tower_level)
105 }
106 };
107
108 let exp_result_witness = match tower_level {
109 0..=3 => repack_witness::<U, Tower, Tower::B8>(
110 exp_witness.exponentiation_result_witness(),
111 )?,
112 4 => repack_witness::<U, Tower, Tower::B16>(
113 exp_witness.exponentiation_result_witness(),
114 )?,
115 5 => repack_witness::<U, Tower, Tower::B32>(
116 exp_witness.exponentiation_result_witness(),
117 )?,
118 6 => repack_witness::<U, Tower, Tower::B64>(
119 exp_witness.exponentiation_result_witness(),
120 )?,
121 7 => repack_witness::<U, Tower, Tower::B128>(
122 exp_witness.exponentiation_result_witness(),
123 )?,
124 _ => bail!(Error::IncorrectTowerLevel),
125 };
126
127 witness.update_multilin_poly([(exp.exp_result_id, exp_result_witness)])?;
128
129 Ok(exp_witness)
130 })
131 .collect::<Result<Vec<_>, Error>>()
132}
133
134pub fn make_claims<F>(
135 exponents: &[Exp<F>],
136 oracles: &MultilinearOracleSet<F>,
137 eval_point: &[F],
138 evals: &[F],
139) -> Result<Vec<ExpClaim<F>>, Error>
140where
141 F: TowerField,
142{
143 let static_exponents_iter = exponents
144 .iter()
145 .filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
146 let dynamic_exponents_iter = exponents
147 .iter()
148 .filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
149 let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
150
151 let constant_bases = exponents_iter
152 .clone()
153 .map(|exp| match exp.base {
154 OracleOrConst::Const { base, .. } => Some(base),
155 OracleOrConst::Oracle(_) => None,
156 })
157 .collect::<Vec<_>>();
158
159 let exponents_ids = exponents_iter
160 .cloned()
161 .map(|exp| exp.bits_ids)
162 .collect::<Vec<_>>();
163
164 gkr_exp::construct_gkr_exp_claims(&exponents_ids, evals, constant_bases, oracles, eval_point)
165 .map_err(Error::from)
166}
167
168pub fn make_eval_claims<F: TowerField>(
169 exponents: &[Exp<F>],
170 base_exp_output: BaseExpReductionOutput<F>,
171) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
172 let static_exponents_iter = exponents
173 .iter()
174 .filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
175 let dynamic_exponents_iter = exponents
176 .iter()
177 .filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
178 let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
179
180 let dynamic_base_ids = exponents_iter
181 .clone()
182 .map(|exp| match exp.base {
183 OracleOrConst::Const { .. } => None,
184 OracleOrConst::Oracle(base_id) => Some(base_id),
185 })
186 .collect::<Vec<_>>();
187
188 let metas = exponents_iter
189 .cloned()
190 .map(|exp| exp.bits_ids)
191 .collect::<Vec<_>>();
192
193 gkr_exp::make_eval_claims(metas, base_exp_output, dynamic_base_ids).map_err(Error::from)
194}
195
196#[instrument(skip_all, name = "exp::repack_witness")]
197fn repack_witness<U, Tower, FExpBase>(
201 witness: MultilinearWitness<PackedType<U, FFastExt<Tower>>>,
202) -> Result<MultilinearWitness<PackedType<U, FExt<Tower>>>, Error>
203where
204 U: ProverTowerUnderlier<Tower> + PackScalar<FExpBase>,
205 Tower: ProverTowerFamily,
206 FExpBase: TowerField,
207 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
208 PackedType<U, Tower::B128>: RepackedExtension<PackedType<U, FExpBase>>,
209{
210 let from_fast = Tower::packed_transformation_from_fast();
211
212 let f_exp_log_width = PackedType::<U, FExpBase>::LOG_WIDTH;
213 let log_width = PackedType::<U, FFastExt<Tower>>::LOG_WIDTH;
214 let n_vars = witness.n_vars();
215
216 let ext_degree = <Tower::B128 as ExtensionField<FExpBase>>::DEGREE;
217
218 const MAX_SUBCUBE_VARS: usize = 8;
219 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
220
221 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
222
223 let mut repacked_evals = vec![
224 PackedType::<U, FExpBase>::default();
225 1 << witness.n_vars().saturating_sub(f_exp_log_width)
226 ];
227
228 repacked_evals
229 .par_chunks_mut(subcube_packed_size / ext_degree)
230 .enumerate()
231 .for_each(|(subcube_index, repacked_evals)| {
232 let mut subcube_evals =
233 vec![PackedType::<U, FExt<Tower>>::default(); subcube_packed_size];
234
235 let underliers =
236 PackedType::<U, FExt<Tower>>::to_underliers_ref_mut(&mut subcube_evals);
237
238 let fast_subcube_evals =
239 PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
240
241 witness
242 .subcube_evals(subcube_vars, subcube_index, 0, fast_subcube_evals)
243 .expect("repacked_evals chunks are ext_degree times smaller");
244
245 for underlier in underliers.iter_mut() {
246 let src = PackedType::<U, FFastExt<Tower>>::from_underlier(*underlier);
247 let dest = from_fast.transform(&src);
248 *underlier = PackedType::<U, FExt<Tower>>::to_underlier(dest);
249 }
250
251 let demoted = PackedType::<U, Tower::B128>::cast_bases(&subcube_evals);
252
253 if ext_degree == 1 {
254 repacked_evals.clone_from_slice(demoted);
255 } else {
256 demoted.chunks(ext_degree).zip(repacked_evals).for_each(
257 |(demoted, repacked_evals)| {
258 *repacked_evals = PackedType::<U, FExpBase>::from_fn(|i| {
259 get_packed_slice(demoted, i * ext_degree)
260 })
261 },
262 )
263 }
264 });
265
266 Ok(MultilinearExtension::new(witness.n_vars(), repacked_evals)?.specialize_arc_dyn())
267}
268
269fn to_fast_witness<U, Tower>(
270 witness: MultilinearWitness<PackedType<U, FExt<Tower>>>,
271) -> Result<MultilinearWitness<PackedType<U, FFastExt<Tower>>>, Error>
272where
273 U: ProverTowerUnderlier<Tower>,
274 Tower: ProverTowerFamily,
275 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
276{
277 let to_fast = Tower::packed_transformation_to_fast();
278
279 let log_width = PackedType::<U, FExt<Tower>>::LOG_WIDTH;
280 let n_vars = witness.n_vars();
281
282 let mut fast_packed_evals =
283 vec![PackedType::<U, FFastExt<Tower>>::default(); 1 << n_vars.saturating_sub(log_width)];
284
285 const MAX_SUBCUBE_VARS: usize = 8;
286 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
287
288 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
289
290 fast_packed_evals
291 .par_chunks_mut(subcube_packed_size)
292 .enumerate()
293 .for_each(|(subcube_index, fast_subcube)| {
294 let underliers = PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
295
296 let subcube_evals = PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
297 witness
298 .subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
299 .expect("fast_packed_evals has correct size");
300
301 for underlier in underliers.iter_mut() {
302 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
303 let dest = to_fast.transform(&src);
304 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
305 }
306 });
307
308 MultilinearExtension::new(witness.n_vars(), fast_packed_evals)
309 .map(|mle| mle.specialize_arc_dyn())
310 .map_err(Error::from)
311}
312
313type FastExponentWitnesses<'a, U, Tower> =
314 Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>;
315
316fn get_fast_exponent_witnesses<'a, U, Tower>(
319 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
320 ids: &[OracleId],
321) -> Result<FastExponentWitnesses<'a, U, Tower>, Error>
322where
323 U: ProverTowerUnderlier<Tower>,
324 Tower: ProverTowerFamily,
325 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
326 + RepackedExtension<PackedType<U, Tower::B1>>,
327{
328 ids.iter()
329 .map(|&id| {
330 let exp_witness = witness.get_multilin_poly(id)?;
331
332 let packed_evals = exp_witness
333 .packed_evals()
334 .expect("poly contain packed_evals");
335
336 let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
337
338 MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
339 .map(|mle| mle.specialize_arc_dyn())
340 .map_err(Error::from)
341 })
342 .collect::<Result<Vec<_>, _>>()
343}