1use std::cmp::Reverse;
4
5use binius_field::{
6 ExtensionField, Field, PackedExtension, PackedField, RepackedExtension, TowerField,
7 as_packed_field::{PackScalar, PackedType},
8 linear_transformation::{PackedTransformationFactory, Transformation},
9 packed::get_packed_slice_checked,
10 tower::{ProverTowerFamily, ProverTowerUnderlier},
11 underlier::WithUnderlier,
12};
13use binius_macros::{DeserializeBytes, SerializeBytes};
14use binius_math::{MultilinearExtension, MultilinearPoly};
15use binius_maybe_rayon::prelude::*;
16use binius_utils::bail;
17use tracing::instrument;
18
19use super::{
20 common::{FExt, FFastExt},
21 error::Error,
22};
23use crate::{
24 constraint_system::channel::OracleOrConst,
25 oracle::{MultilinearOracleSet, OracleId},
26 protocols::{
27 evalcheck::EvalcheckMultilinearClaim,
28 gkr_exp::{self, BaseExpReductionOutput, BaseExpWitness, ExpClaim},
29 },
30 witness::{MultilinearExtensionIndex, MultilinearWitness},
31};
32
/// Describes one exponentiation relation handled by the GKR-based exponentiation protocol
/// (`gkr_exp`).
///
/// The exponent is supplied as a list of bit oracles, and the result of the exponentiation
/// is written to the oracle `exp_result_id` when witnesses are generated.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Exp<F: Field> {
	/// Oracles holding the bits of the exponent, one oracle per bit.
	/// NOTE(review): the bit order (LSB-first vs MSB-first) is defined by `gkr_exp` — confirm there.
	pub bits_ids: Vec<OracleId>,
	/// The base being exponentiated: either a constant field element together with its tower
	/// level, or an oracle.
	pub base: OracleOrConst<F>,
	/// Oracle that holds the exponentiation result.
	pub exp_result_id: OracleId,
}
40
41impl<F: TowerField> Exp<F> {
42 pub fn n_vars(&self, oracles: &MultilinearOracleSet<F>) -> usize {
43 oracles.n_vars(self.exp_result_id)
44 }
45}
46
47pub fn max_n_vars<F: TowerField>(exponents: &[Exp<F>], oracles: &MultilinearOracleSet<F>) -> usize {
48 exponents
49 .iter()
50 .map(|m| m.n_vars(oracles))
51 .max()
52 .unwrap_or(0)
53}
54
/// GKR exponentiation witnesses, packed over the fast 128-bit extension field; this is what
/// `make_exp_witnesses` returns (one entry per input exponentiation).
type MultiplicationWitnesses<'a, U, Tower> =
	Vec<BaseExpWitness<'a, PackedType<U, FFastExt<Tower>>>>;
57
58#[instrument(skip_all, name = "exp::make_exp_witnesses")]
61pub fn make_exp_witnesses<'a, U, Tower>(
62 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
63 oracles: &MultilinearOracleSet<FExt<Tower>>,
64 exponents: &[Exp<Tower::B128>],
65) -> Result<MultiplicationWitnesses<'a, U, Tower>, Error>
66where
67 U: ProverTowerUnderlier<Tower>,
68 Tower: ProverTowerFamily,
69 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
70 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
71{
72 exponents
73 .iter()
74 .map(|exp| {
75 let fast_exponent_witnesses =
76 get_fast_exponent_witnesses::<U, Tower>(witness, &exp.bits_ids)?;
77
78 let (exp_witness, tower_level) = match exp.base {
79 OracleOrConst::Const { base, tower_level } => {
80 let witness = gkr_exp::BaseExpWitness::new_with_static_base::<
81 PackedType<U, FFastExt<Tower>>,
82 >(fast_exponent_witnesses, base.into())?;
83 (witness, tower_level)
84 }
85 OracleOrConst::Oracle(base_id) => {
86 let fast_base_witnesses =
87 to_fast_witness::<U, Tower>(witness.get_multilin_poly(base_id)?)?;
88
89 let witness = gkr_exp::BaseExpWitness::new_with_dynamic_base::<
90 PackedType<U, FFastExt<Tower>>,
91 >(fast_exponent_witnesses, fast_base_witnesses)?;
92
93 let tower_level = oracles.tower_level(base_id);
94
95 (witness, tower_level)
96 }
97 };
98
99 let exp_result_witness = match tower_level {
100 0..=3 => repack_witness::<U, Tower, Tower::B8>(
101 exp_witness.exponentiation_result_witness(),
102 )?,
103 4 => repack_witness::<U, Tower, Tower::B16>(
104 exp_witness.exponentiation_result_witness(),
105 )?,
106 5 => repack_witness::<U, Tower, Tower::B32>(
107 exp_witness.exponentiation_result_witness(),
108 )?,
109 6 => repack_witness::<U, Tower, Tower::B64>(
110 exp_witness.exponentiation_result_witness(),
111 )?,
112 7 => repack_witness::<U, Tower, Tower::B128>(
113 exp_witness.exponentiation_result_witness(),
114 )?,
115 _ => bail!(Error::IncorrectTowerLevel),
116 };
117
118 witness.update_multilin_poly([(exp.exp_result_id, exp_result_witness)])?;
119
120 Ok(exp_witness)
121 })
122 .collect::<Result<Vec<_>, Error>>()
123}
124
125pub fn make_claims<F>(
126 exponents: &[Exp<F>],
127 oracles: &MultilinearOracleSet<F>,
128 eval_point: &[F],
129 evals: &[F],
130) -> Result<Vec<ExpClaim<F>>, Error>
131where
132 F: TowerField,
133{
134 let constant_bases = exponents
135 .iter()
136 .map(|exp| match exp.base {
137 OracleOrConst::Const { base, .. } => Some(base),
138 OracleOrConst::Oracle(_) => None,
139 })
140 .collect::<Vec<_>>();
141
142 let exponents_ids = exponents
143 .iter()
144 .map(|exp| exp.bits_ids.clone())
145 .collect::<Vec<_>>();
146
147 gkr_exp::construct_gkr_exp_claims(&exponents_ids, evals, constant_bases, oracles, eval_point)
148 .map_err(Error::from)
149}
150
151pub fn make_eval_claims<F: TowerField>(
152 exponents: &[Exp<F>],
153 base_exp_output: BaseExpReductionOutput<F>,
154) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
155 let dynamic_base_ids = exponents
156 .iter()
157 .map(|exp| match exp.base {
158 OracleOrConst::Const { .. } => None,
159 OracleOrConst::Oracle(base_id) => Some(base_id),
160 })
161 .collect::<Vec<_>>();
162
163 let metas = exponents
164 .iter()
165 .map(|exp| exp.bits_ids.clone())
166 .collect::<Vec<_>>();
167
168 gkr_exp::make_eval_claims(metas, base_exp_output, dynamic_base_ids).map_err(Error::from)
169}
170
/// Converts a witness from the fast 128-bit field representation back to the canonical
/// tower representation and repacks it as a multilinear over the subfield `FExpBase`.
///
/// Evaluations are processed in parallel, one subcube of at most `MAX_SUBCUBE_VARS`
/// variables per task. For each subcube:
/// 1. the 128-bit values are transformed out of the fast basis;
/// 2. the buffer is reinterpreted as packed `FExpBase` elements via `cast_bases`;
/// 3. from every group of `ext_degree` consecutive subfield components, only the first is kept.
///
/// NOTE(review): dropping the remaining components is only lossless if every evaluation
/// lies in the `FExpBase` subfield. Presumably the caller guarantees this by choosing
/// `FExpBase` from the base's tower level — confirm.
///
/// # Errors
///
/// Returns an error if the repacked evaluations cannot form a `MultilinearExtension` with
/// `witness.n_vars()` variables.
///
/// # Panics
///
/// Panics if reading a subcube from `witness` fails.
#[instrument(skip_all, name = "exp::repack_witness")]
fn repack_witness<U, Tower, FExpBase>(
	witness: MultilinearWitness<PackedType<U, FFastExt<Tower>>>,
) -> Result<MultilinearWitness<PackedType<U, FExt<Tower>>>, Error>
where
	U: ProverTowerUnderlier<Tower> + PackScalar<FExpBase>,
	Tower: ProverTowerFamily,
	FExpBase: TowerField,
	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
	PackedType<U, Tower::B128>: RepackedExtension<PackedType<U, FExpBase>>,
{
	let from_fast = Tower::packed_transformation_from_fast();

	let f_exp_log_width = PackedType::<U, FExpBase>::LOG_WIDTH;
	let log_width = PackedType::<U, FFastExt<Tower>>::LOG_WIDTH;
	let n_vars = witness.n_vars();

	// Number of `FExpBase` components in one 128-bit element.
	let ext_degree = <Tower::B128 as ExtensionField<FExpBase>>::DEGREE;

	// Subcubes bound the size of the per-task scratch buffer below.
	const MAX_SUBCUBE_VARS: usize = 8;
	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);

	// Packed 128-bit elements per subcube (at least one).
	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);

	let mut repacked_evals =
		vec![PackedType::<U, FExpBase>::default(); 1 << n_vars.saturating_sub(f_exp_log_width)];

	// One output chunk per subcube. Each chunk is `ext_degree` times smaller than the
	// subcube's packed 128-bit buffer (but never empty).
	repacked_evals
		.par_chunks_mut((subcube_packed_size / ext_degree).max(1))
		.enumerate()
		.for_each(|(subcube_index, repacked_evals)| {
			// Scratch buffer for one subcube. It is filled through a fast-field view of the
			// same underliers and then converted in place.
			let mut subcube_evals =
				vec![PackedType::<U, FExt<Tower>>::default(); subcube_packed_size];

			let underliers =
				PackedType::<U, FExt<Tower>>::to_underliers_ref_mut(&mut subcube_evals);

			let fast_subcube_evals =
				PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);

			witness
				.subcube_evals(subcube_vars, subcube_index, 0, fast_subcube_evals)
				.expect("repacked_evals chunks are ext_degree times smaller");

			// Rewrite every packed element from the fast basis into the canonical basis.
			for underlier in underliers.iter_mut() {
				let src = PackedType::<U, FFastExt<Tower>>::from_underlier(*underlier);
				let dest = from_fast.transform(&src);
				*underlier = PackedType::<U, FExt<Tower>>::to_underlier(dest);
			}

			// View the canonical 128-bit elements as packed `FExpBase` elements.
			let demoted = PackedType::<U, Tower::B128>::cast_bases(&subcube_evals);

			if ext_degree == 1 {
				// `FExpBase` is the 128-bit field itself: copy as-is.
				repacked_evals.clone_from_slice(demoted);
			} else {
				// Keep the first component of each 128-bit element (every `ext_degree`-th
				// subfield scalar). Positions past the end of `demoted` are filled with zero.
				demoted.chunks(ext_degree).zip(repacked_evals).for_each(
					|(demoted, repacked_evals)| {
						*repacked_evals = PackedType::<U, FExpBase>::from_fn(|i| {
							get_packed_slice_checked(demoted, i * ext_degree)
								.unwrap_or(FExpBase::ZERO)
						});
					},
				)
			}
		});

	Ok(MultilinearExtension::new(witness.n_vars(), repacked_evals)?.specialize_arc_dyn())
}
242
243fn to_fast_witness<U, Tower>(
244 witness: MultilinearWitness<PackedType<U, FExt<Tower>>>,
245) -> Result<MultilinearWitness<PackedType<U, FFastExt<Tower>>>, Error>
246where
247 U: ProverTowerUnderlier<Tower>,
248 Tower: ProverTowerFamily,
249 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
250{
251 let to_fast = Tower::packed_transformation_to_fast();
252
253 let log_width = PackedType::<U, FExt<Tower>>::LOG_WIDTH;
254 let n_vars = witness.n_vars();
255
256 let mut fast_packed_evals =
257 vec![PackedType::<U, FFastExt<Tower>>::default(); 1 << n_vars.saturating_sub(log_width)];
258
259 const MAX_SUBCUBE_VARS: usize = 8;
260 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
261
262 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
263
264 fast_packed_evals
265 .par_chunks_mut(subcube_packed_size)
266 .enumerate()
267 .for_each(|(subcube_index, fast_subcube)| {
268 let underliers = PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
269
270 let subcube_evals = PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
271 witness
272 .subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
273 .expect("fast_packed_evals has correct size");
274
275 for underlier in underliers.iter_mut() {
276 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
277 let dest = to_fast.transform(&src);
278 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
279 }
280 });
281
282 MultilinearExtension::new(witness.n_vars(), fast_packed_evals)
283 .map(|mle| mle.specialize_arc_dyn())
284 .map_err(Error::from)
285}
286
/// Exponent-bit witnesses re-specialized over the fast 128-bit packed field; this is what
/// `get_fast_exponent_witnesses` returns (one entry per requested oracle id).
type FastExponentWitnesses<'a, U, Tower> =
	Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>;
289
290fn get_fast_exponent_witnesses<'a, U, Tower>(
293 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
294 ids: &[OracleId],
295) -> Result<FastExponentWitnesses<'a, U, Tower>, Error>
296where
297 U: ProverTowerUnderlier<Tower>,
298 Tower: ProverTowerFamily,
299 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
300 + RepackedExtension<PackedType<U, Tower::B1>>,
301{
302 ids.iter()
303 .map(|&id| {
304 let exp_witness = witness.get_multilin_poly(id)?;
305
306 let packed_evals = exp_witness
307 .packed_evals()
308 .expect("poly contain packed_evals");
309
310 let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
311
312 MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
313 .map(|mle| mle.specialize_arc_dyn())
314 .map_err(Error::from)
315 })
316 .collect::<Result<Vec<_>, _>>()
317}
318
319pub fn reorder_exponents<F: TowerField>(
320 exponents: &mut [Exp<F>],
321 oracles: &MultilinearOracleSet<F>,
322) {
323 exponents.sort_by_key(|exp| match exp.base {
326 OracleOrConst::Const { .. } => (Reverse(exp.n_vars(oracles)), 0),
327 OracleOrConst::Oracle(_) => (Reverse(exp.n_vars(oracles)), 1),
328 });
329}