1use binius_field::{
4 as_packed_field::{PackScalar, PackedType},
5 linear_transformation::{PackedTransformationFactory, Transformation},
6 packed::get_packed_slice_checked,
7 tower::{ProverTowerFamily, ProverTowerUnderlier},
8 underlier::WithUnderlier,
9 ExtensionField, Field, PackedExtension, PackedField, RepackedExtension, TowerField,
10};
11use binius_macros::{DeserializeBytes, SerializeBytes};
12use binius_math::MultilinearExtension;
13use binius_maybe_rayon::prelude::*;
14use binius_utils::bail;
15use itertools::chain;
16use tracing::instrument;
17
18use super::{
19 common::{FExt, FFastExt},
20 error::Error,
21};
22use crate::{
23 constraint_system::channel::OracleOrConst,
24 oracle::{MultilinearOracleSet, OracleId},
25 protocols::{
26 evalcheck::EvalcheckMultilinearClaim,
27 gkr_exp::{self, BaseExpReductionOutput, BaseExpWitness, ExpClaim},
28 },
29 witness::{MultilinearExtensionIndex, MultilinearWitness},
30};
31
32#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
33pub struct Exp<F: Field> {
34 pub bits_ids: Vec<OracleId>,
36 pub base: OracleOrConst<F>,
37 pub exp_result_id: OracleId,
38}
39
40impl<F: TowerField> Exp<F> {
41 pub fn n_vars(&self, oracles: &MultilinearOracleSet<F>) -> usize {
42 oracles.n_vars(self.exp_result_id)
43 }
44}
45
46pub fn max_n_vars<F: TowerField>(exponents: &[Exp<F>], oracles: &MultilinearOracleSet<F>) -> usize {
47 exponents
48 .iter()
49 .map(|m| m.n_vars(oracles))
50 .max()
51 .unwrap_or(0)
52}
53
54type MultiplicationWitnesses<'a, U, Tower> =
55 Vec<BaseExpWitness<'a, PackedType<U, FFastExt<Tower>>>>;
56
57#[instrument(skip_all, name = "exp::make_exp_witnesses")]
60pub fn make_exp_witnesses<'a, U, Tower>(
61 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
62 oracles: &MultilinearOracleSet<FExt<Tower>>,
63 exponents: &[Exp<Tower::B128>],
64) -> Result<MultiplicationWitnesses<'a, U, Tower>, Error>
65where
66 U: ProverTowerUnderlier<Tower>,
67 Tower: ProverTowerFamily,
68 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
69 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
70{
71 let static_exponents_iter = exponents
74 .iter()
75 .filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
76 let dynamic_exponents_iter = exponents
77 .iter()
78 .filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
79
80 chain!(static_exponents_iter, dynamic_exponents_iter)
81 .map(|exp| {
82 let fast_exponent_witnesses =
83 get_fast_exponent_witnesses::<U, Tower>(witness, &exp.bits_ids)?;
84
85 let (exp_witness, tower_level) = match exp.base {
86 OracleOrConst::Const { base, tower_level } => {
87 let witness = gkr_exp::BaseExpWitness::new_with_static_base::<
88 PackedType<U, FFastExt<Tower>>,
89 >(fast_exponent_witnesses, base.into())?;
90 (witness, tower_level)
91 }
92 OracleOrConst::Oracle(base_id) => {
93 let fast_base_witnesses =
94 to_fast_witness::<U, Tower>(witness.get_multilin_poly(base_id)?)?;
95
96 let witness = gkr_exp::BaseExpWitness::new_with_dynamic_base::<
97 PackedType<U, FFastExt<Tower>>,
98 >(fast_exponent_witnesses, fast_base_witnesses)?;
99
100 let tower_level = oracles.tower_level(base_id);
101
102 (witness, tower_level)
103 }
104 };
105
106 let exp_result_witness = match tower_level {
107 0..=3 => repack_witness::<U, Tower, Tower::B8>(
108 exp_witness.exponentiation_result_witness(),
109 )?,
110 4 => repack_witness::<U, Tower, Tower::B16>(
111 exp_witness.exponentiation_result_witness(),
112 )?,
113 5 => repack_witness::<U, Tower, Tower::B32>(
114 exp_witness.exponentiation_result_witness(),
115 )?,
116 6 => repack_witness::<U, Tower, Tower::B64>(
117 exp_witness.exponentiation_result_witness(),
118 )?,
119 7 => repack_witness::<U, Tower, Tower::B128>(
120 exp_witness.exponentiation_result_witness(),
121 )?,
122 _ => bail!(Error::IncorrectTowerLevel),
123 };
124
125 witness.update_multilin_poly([(exp.exp_result_id, exp_result_witness)])?;
126
127 Ok(exp_witness)
128 })
129 .collect::<Result<Vec<_>, Error>>()
130}
131
132pub fn make_claims<F>(
133 exponents: &[Exp<F>],
134 oracles: &MultilinearOracleSet<F>,
135 eval_point: &[F],
136 evals: &[F],
137) -> Result<Vec<ExpClaim<F>>, Error>
138where
139 F: TowerField,
140{
141 let static_exponents_iter = exponents
142 .iter()
143 .filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
144 let dynamic_exponents_iter = exponents
145 .iter()
146 .filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
147 let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
148
149 let constant_bases = exponents_iter
150 .clone()
151 .map(|exp| match exp.base {
152 OracleOrConst::Const { base, .. } => Some(base),
153 OracleOrConst::Oracle(_) => None,
154 })
155 .collect::<Vec<_>>();
156
157 let exponents_ids = exponents_iter
158 .cloned()
159 .map(|exp| exp.bits_ids)
160 .collect::<Vec<_>>();
161
162 gkr_exp::construct_gkr_exp_claims(&exponents_ids, evals, constant_bases, oracles, eval_point)
163 .map_err(Error::from)
164}
165
166pub fn make_eval_claims<F: TowerField>(
167 exponents: &[Exp<F>],
168 base_exp_output: BaseExpReductionOutput<F>,
169) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
170 let static_exponents_iter = exponents
171 .iter()
172 .filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
173 let dynamic_exponents_iter = exponents
174 .iter()
175 .filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
176 let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
177
178 let dynamic_base_ids = exponents_iter
179 .clone()
180 .map(|exp| match exp.base {
181 OracleOrConst::Const { .. } => None,
182 OracleOrConst::Oracle(base_id) => Some(base_id),
183 })
184 .collect::<Vec<_>>();
185
186 let metas = exponents_iter
187 .cloned()
188 .map(|exp| exp.bits_ids)
189 .collect::<Vec<_>>();
190
191 gkr_exp::make_eval_claims(metas, base_exp_output, dynamic_base_ids).map_err(Error::from)
192}
193
194#[instrument(skip_all, name = "exp::repack_witness")]
195fn repack_witness<U, Tower, FExpBase>(
199 witness: MultilinearWitness<PackedType<U, FFastExt<Tower>>>,
200) -> Result<MultilinearWitness<PackedType<U, FExt<Tower>>>, Error>
201where
202 U: ProverTowerUnderlier<Tower> + PackScalar<FExpBase>,
203 Tower: ProverTowerFamily,
204 FExpBase: TowerField,
205 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
206 PackedType<U, Tower::B128>: RepackedExtension<PackedType<U, FExpBase>>,
207{
208 let from_fast = Tower::packed_transformation_from_fast();
209
210 let f_exp_log_width = PackedType::<U, FExpBase>::LOG_WIDTH;
211 let log_width = PackedType::<U, FFastExt<Tower>>::LOG_WIDTH;
212 let n_vars = witness.n_vars();
213
214 let ext_degree = <Tower::B128 as ExtensionField<FExpBase>>::DEGREE;
215
216 const MAX_SUBCUBE_VARS: usize = 8;
217 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
218
219 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
220
221 let mut repacked_evals =
222 vec![PackedType::<U, FExpBase>::default(); 1 << n_vars.saturating_sub(f_exp_log_width)];
223
224 repacked_evals
225 .par_chunks_mut((subcube_packed_size / ext_degree).max(1))
226 .enumerate()
227 .for_each(|(subcube_index, repacked_evals)| {
228 let mut subcube_evals =
229 vec![PackedType::<U, FExt<Tower>>::default(); subcube_packed_size];
230
231 let underliers =
232 PackedType::<U, FExt<Tower>>::to_underliers_ref_mut(&mut subcube_evals);
233
234 let fast_subcube_evals =
235 PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
236
237 witness
238 .subcube_evals(subcube_vars, subcube_index, 0, fast_subcube_evals)
239 .expect("repacked_evals chunks are ext_degree times smaller");
240
241 for underlier in underliers.iter_mut() {
242 let src = PackedType::<U, FFastExt<Tower>>::from_underlier(*underlier);
243 let dest = from_fast.transform(&src);
244 *underlier = PackedType::<U, FExt<Tower>>::to_underlier(dest);
245 }
246
247 let demoted = PackedType::<U, Tower::B128>::cast_bases(&subcube_evals);
248
249 if ext_degree == 1 {
250 repacked_evals.clone_from_slice(demoted);
251 } else {
252 demoted.chunks(ext_degree).zip(repacked_evals).for_each(
253 |(demoted, repacked_evals)| {
254 *repacked_evals = PackedType::<U, FExpBase>::from_fn(|i| {
255 get_packed_slice_checked(demoted, i * ext_degree)
256 .unwrap_or(FExpBase::ZERO)
257 });
258 },
259 )
260 }
261 });
262
263 Ok(MultilinearExtension::new(witness.n_vars(), repacked_evals)?.specialize_arc_dyn())
264}
265
266fn to_fast_witness<U, Tower>(
267 witness: MultilinearWitness<PackedType<U, FExt<Tower>>>,
268) -> Result<MultilinearWitness<PackedType<U, FFastExt<Tower>>>, Error>
269where
270 U: ProverTowerUnderlier<Tower>,
271 Tower: ProverTowerFamily,
272 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
273{
274 let to_fast = Tower::packed_transformation_to_fast();
275
276 let log_width = PackedType::<U, FExt<Tower>>::LOG_WIDTH;
277 let n_vars = witness.n_vars();
278
279 let mut fast_packed_evals =
280 vec![PackedType::<U, FFastExt<Tower>>::default(); 1 << n_vars.saturating_sub(log_width)];
281
282 const MAX_SUBCUBE_VARS: usize = 8;
283 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
284
285 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
286
287 fast_packed_evals
288 .par_chunks_mut(subcube_packed_size)
289 .enumerate()
290 .for_each(|(subcube_index, fast_subcube)| {
291 let underliers = PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
292
293 let subcube_evals = PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
294 witness
295 .subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
296 .expect("fast_packed_evals has correct size");
297
298 for underlier in underliers.iter_mut() {
299 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
300 let dest = to_fast.transform(&src);
301 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
302 }
303 });
304
305 MultilinearExtension::new(witness.n_vars(), fast_packed_evals)
306 .map(|mle| mle.specialize_arc_dyn())
307 .map_err(Error::from)
308}
309
310type FastExponentWitnesses<'a, U, Tower> =
311 Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>;
312
313fn get_fast_exponent_witnesses<'a, U, Tower>(
316 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
317 ids: &[OracleId],
318) -> Result<FastExponentWitnesses<'a, U, Tower>, Error>
319where
320 U: ProverTowerUnderlier<Tower>,
321 Tower: ProverTowerFamily,
322 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
323 + RepackedExtension<PackedType<U, Tower::B1>>,
324{
325 ids.iter()
326 .map(|&id| {
327 let exp_witness = witness.get_multilin_poly(id)?;
328
329 let packed_evals = exp_witness
330 .packed_evals()
331 .expect("poly contain packed_evals");
332
333 let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
334
335 MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
336 .map(|mle| mle.specialize_arc_dyn())
337 .map_err(Error::from)
338 })
339 .collect::<Result<Vec<_>, _>>()
340}