binius_core/constraint_system/
exp.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::cmp::Reverse;
4
5use binius_field::{
6	ExtensionField, Field, PackedExtension, PackedField, RepackedExtension, TowerField,
7	as_packed_field::{PackScalar, PackedType},
8	linear_transformation::{PackedTransformationFactory, Transformation},
9	packed::get_packed_slice_checked,
10	tower::{ProverTowerFamily, ProverTowerUnderlier},
11	underlier::WithUnderlier,
12};
13use binius_macros::{DeserializeBytes, SerializeBytes};
14use binius_math::{MultilinearExtension, MultilinearPoly};
15use binius_maybe_rayon::prelude::*;
16use binius_utils::bail;
17use tracing::instrument;
18
19use super::{
20	common::{FExt, FFastExt},
21	error::Error,
22};
23use crate::{
24	constraint_system::channel::OracleOrConst,
25	oracle::{MultilinearOracleSet, OracleId},
26	protocols::{
27		evalcheck::EvalcheckMultilinearClaim,
28		gkr_exp::{self, BaseExpReductionOutput, BaseExpWitness, ExpClaim},
29	},
30	witness::{MultilinearExtensionIndex, MultilinearWitness},
31};
32
/// Description of a single exponentiation relation: a base raised to an exponent whose bits are
/// given by oracles, together with the oracle that holds the result.
///
/// The base is either a constant or an oracle, see [`OracleOrConst`].
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Exp<F: Field> {
	/// A vector of `OracleId`s representing the exponent in little-endian bit order
	pub bits_ids: Vec<OracleId>,
	/// The base being exponentiated: a constant field element or the ID of an oracle.
	pub base: OracleOrConst<F>,
	/// ID of the oracle holding the exponentiation result. Its number of variables is what
	/// [`Exp::n_vars`] reports.
	pub exp_result_id: OracleId,
}
40
41impl<F: TowerField> Exp<F> {
42	pub fn n_vars(&self, oracles: &MultilinearOracleSet<F>) -> usize {
43		oracles.n_vars(self.exp_result_id)
44	}
45}
46
47pub fn max_n_vars<F: TowerField>(exponents: &[Exp<F>], oracles: &MultilinearOracleSet<F>) -> usize {
48	exponents
49		.iter()
50		.map(|m| m.n_vars(oracles))
51		.max()
52		.unwrap_or(0)
53}
54
/// The list of [`BaseExpWitness`]es returned by [`make_exp_witnesses`], one per [`Exp`], packed
/// over the fast extension field of `Tower`.
type MultiplicationWitnesses<'a, U, Tower> =
	Vec<BaseExpWitness<'a, PackedType<U, FFastExt<Tower>>>>;
57
58/// Constructs [`BaseExpWitness`] instances and adds the exponentiation-result witnesses
59/// to the MultiplicationWitnesses.
60#[instrument(skip_all, name = "exp::make_exp_witnesses")]
61pub fn make_exp_witnesses<'a, U, Tower>(
62	witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
63	oracles: &MultilinearOracleSet<FExt<Tower>>,
64	exponents: &[Exp<Tower::B128>],
65) -> Result<MultiplicationWitnesses<'a, U, Tower>, Error>
66where
67	U: ProverTowerUnderlier<Tower>,
68	Tower: ProverTowerFamily,
69	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
70	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
71{
72	exponents
73		.iter()
74		.map(|exp| {
75			let fast_exponent_witnesses =
76				get_fast_exponent_witnesses::<U, Tower>(witness, &exp.bits_ids)?;
77
78			let (exp_witness, tower_level) = match exp.base {
79				OracleOrConst::Const { base, tower_level } => {
80					let witness = gkr_exp::BaseExpWitness::new_with_static_base::<
81						PackedType<U, FFastExt<Tower>>,
82					>(fast_exponent_witnesses, base.into())?;
83					(witness, tower_level)
84				}
85				OracleOrConst::Oracle(base_id) => {
86					let fast_base_witnesses =
87						to_fast_witness::<U, Tower>(witness.get_multilin_poly(base_id)?)?;
88
89					let witness = gkr_exp::BaseExpWitness::new_with_dynamic_base::<
90						PackedType<U, FFastExt<Tower>>,
91					>(fast_exponent_witnesses, fast_base_witnesses)?;
92
93					let tower_level = oracles.tower_level(base_id);
94
95					(witness, tower_level)
96				}
97			};
98
99			let exp_result_witness = match tower_level {
100				0..=3 => repack_witness::<U, Tower, Tower::B8>(
101					exp_witness.exponentiation_result_witness(),
102				)?,
103				4 => repack_witness::<U, Tower, Tower::B16>(
104					exp_witness.exponentiation_result_witness(),
105				)?,
106				5 => repack_witness::<U, Tower, Tower::B32>(
107					exp_witness.exponentiation_result_witness(),
108				)?,
109				6 => repack_witness::<U, Tower, Tower::B64>(
110					exp_witness.exponentiation_result_witness(),
111				)?,
112				7 => repack_witness::<U, Tower, Tower::B128>(
113					exp_witness.exponentiation_result_witness(),
114				)?,
115				_ => bail!(Error::IncorrectTowerLevel),
116			};
117
118			witness.update_multilin_poly([(exp.exp_result_id, exp_result_witness)])?;
119
120			Ok(exp_witness)
121		})
122		.collect::<Result<Vec<_>, Error>>()
123}
124
125pub fn make_claims<F>(
126	exponents: &[Exp<F>],
127	oracles: &MultilinearOracleSet<F>,
128	eval_point: &[F],
129	evals: &[F],
130) -> Result<Vec<ExpClaim<F>>, Error>
131where
132	F: TowerField,
133{
134	let constant_bases = exponents
135		.iter()
136		.map(|exp| match exp.base {
137			OracleOrConst::Const { base, .. } => Some(base),
138			OracleOrConst::Oracle(_) => None,
139		})
140		.collect::<Vec<_>>();
141
142	let exponents_ids = exponents
143		.iter()
144		.map(|exp| exp.bits_ids.clone())
145		.collect::<Vec<_>>();
146
147	gkr_exp::construct_gkr_exp_claims(&exponents_ids, evals, constant_bases, oracles, eval_point)
148		.map_err(Error::from)
149}
150
151pub fn make_eval_claims<F: TowerField>(
152	exponents: &[Exp<F>],
153	base_exp_output: BaseExpReductionOutput<F>,
154) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
155	let dynamic_base_ids = exponents
156		.iter()
157		.map(|exp| match exp.base {
158			OracleOrConst::Const { .. } => None,
159			OracleOrConst::Oracle(base_id) => Some(base_id),
160		})
161		.collect::<Vec<_>>();
162
163	let metas = exponents
164		.iter()
165		.map(|exp| exp.bits_ids.clone())
166		.collect::<Vec<_>>();
167
168	gkr_exp::make_eval_claims(metas, base_exp_output, dynamic_base_ids).map_err(Error::from)
169}
170
171#[instrument(skip_all, name = "exp::repack_witness")]
172// Because the exponentiation witness operates in a large field, the number of leading zeroes
173// depends on the FExpBase. To optimize storage and avoid committing unnecessary zeroes, we
174// can repack B128 into FExpBase.
175fn repack_witness<U, Tower, FExpBase>(
176	witness: MultilinearWitness<PackedType<U, FFastExt<Tower>>>,
177) -> Result<MultilinearWitness<PackedType<U, FExt<Tower>>>, Error>
178where
179	U: ProverTowerUnderlier<Tower> + PackScalar<FExpBase>,
180	Tower: ProverTowerFamily,
181	FExpBase: TowerField,
182	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
183	PackedType<U, Tower::B128>: RepackedExtension<PackedType<U, FExpBase>>,
184{
185	let from_fast = Tower::packed_transformation_from_fast();
186
187	let f_exp_log_width = PackedType::<U, FExpBase>::LOG_WIDTH;
188	let log_width = PackedType::<U, FFastExt<Tower>>::LOG_WIDTH;
189	let n_vars = witness.n_vars();
190
191	let ext_degree = <Tower::B128 as ExtensionField<FExpBase>>::DEGREE;
192
193	const MAX_SUBCUBE_VARS: usize = 8;
194	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
195
196	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
197
198	let mut repacked_evals =
199		vec![PackedType::<U, FExpBase>::default(); 1 << n_vars.saturating_sub(f_exp_log_width)];
200
201	repacked_evals
202		.par_chunks_mut((subcube_packed_size / ext_degree).max(1))
203		.enumerate()
204		.for_each(|(subcube_index, repacked_evals)| {
205			let mut subcube_evals =
206				vec![PackedType::<U, FExt<Tower>>::default(); subcube_packed_size];
207
208			let underliers =
209				PackedType::<U, FExt<Tower>>::to_underliers_ref_mut(&mut subcube_evals);
210
211			let fast_subcube_evals =
212				PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
213
214			witness
215				.subcube_evals(subcube_vars, subcube_index, 0, fast_subcube_evals)
216				.expect("repacked_evals chunks are ext_degree times smaller");
217
218			for underlier in underliers.iter_mut() {
219				let src = PackedType::<U, FFastExt<Tower>>::from_underlier(*underlier);
220				let dest = from_fast.transform(&src);
221				*underlier = PackedType::<U, FExt<Tower>>::to_underlier(dest);
222			}
223
224			let demoted = PackedType::<U, Tower::B128>::cast_bases(&subcube_evals);
225
226			if ext_degree == 1 {
227				repacked_evals.clone_from_slice(demoted);
228			} else {
229				demoted.chunks(ext_degree).zip(repacked_evals).for_each(
230					|(demoted, repacked_evals)| {
231						*repacked_evals = PackedType::<U, FExpBase>::from_fn(|i| {
232							get_packed_slice_checked(demoted, i * ext_degree)
233								.unwrap_or(FExpBase::ZERO)
234						});
235					},
236				)
237			}
238		});
239
240	Ok(MultilinearExtension::new(witness.n_vars(), repacked_evals)?.specialize_arc_dyn())
241}
242
243fn to_fast_witness<U, Tower>(
244	witness: MultilinearWitness<PackedType<U, FExt<Tower>>>,
245) -> Result<MultilinearWitness<PackedType<U, FFastExt<Tower>>>, Error>
246where
247	U: ProverTowerUnderlier<Tower>,
248	Tower: ProverTowerFamily,
249	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
250{
251	let to_fast = Tower::packed_transformation_to_fast();
252
253	let log_width = PackedType::<U, FExt<Tower>>::LOG_WIDTH;
254	let n_vars = witness.n_vars();
255
256	let mut fast_packed_evals =
257		vec![PackedType::<U, FFastExt<Tower>>::default(); 1 << n_vars.saturating_sub(log_width)];
258
259	const MAX_SUBCUBE_VARS: usize = 8;
260	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
261
262	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
263
264	fast_packed_evals
265		.par_chunks_mut(subcube_packed_size)
266		.enumerate()
267		.for_each(|(subcube_index, fast_subcube)| {
268			let underliers = PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
269
270			let subcube_evals = PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
271			witness
272				.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
273				.expect("fast_packed_evals has correct size");
274
275			for underlier in underliers.iter_mut() {
276				let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
277				let dest = to_fast.transform(&src);
278				*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
279			}
280		});
281
282	MultilinearExtension::new(witness.n_vars(), fast_packed_evals)
283		.map(|mle| mle.specialize_arc_dyn())
284		.map_err(Error::from)
285}
286
/// The exponent bit witnesses returned by [`get_fast_exponent_witnesses`], one per bit oracle,
/// typed over the fast extension field's packed type.
type FastExponentWitnesses<'a, U, Tower> =
	Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>;
289
290/// Casts witness from 1B to FastB128.
291/// TODO: Update when we start using byteslicing.
292fn get_fast_exponent_witnesses<'a, U, Tower>(
293	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
294	ids: &[OracleId],
295) -> Result<FastExponentWitnesses<'a, U, Tower>, Error>
296where
297	U: ProverTowerUnderlier<Tower>,
298	Tower: ProverTowerFamily,
299	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
300		+ RepackedExtension<PackedType<U, Tower::B1>>,
301{
302	ids.iter()
303		.map(|&id| {
304			let exp_witness = witness.get_multilin_poly(id)?;
305
306			let packed_evals = exp_witness
307				.packed_evals()
308				.expect("poly contain packed_evals");
309
310			let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
311
312			MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
313				.map(|mle| mle.specialize_arc_dyn())
314				.map_err(Error::from)
315		})
316		.collect::<Result<Vec<_>, _>>()
317}
318
319pub fn reorder_exponents<F: TowerField>(
320	exponents: &mut [Exp<F>],
321	oracles: &MultilinearOracleSet<F>,
322) {
323	// Since dynamic witnesses may need the `exp_result` of static witnesses,
324	// we start processing with static ones first.
325	exponents.sort_by_key(|exp| match exp.base {
326		OracleOrConst::Const { .. } => (Reverse(exp.n_vars(oracles)), 0),
327		OracleOrConst::Oracle(_) => (Reverse(exp.n_vars(oracles)), 1),
328	});
329}