binius_core/constraint_system/
exp.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{
4	as_packed_field::{PackScalar, PackedType},
5	linear_transformation::{PackedTransformationFactory, Transformation},
6	packed::get_packed_slice_checked,
7	tower::{ProverTowerFamily, ProverTowerUnderlier},
8	underlier::WithUnderlier,
9	ExtensionField, Field, PackedExtension, PackedField, RepackedExtension, TowerField,
10};
11use binius_macros::{DeserializeBytes, SerializeBytes};
12use binius_math::MultilinearExtension;
13use binius_maybe_rayon::prelude::*;
14use binius_utils::bail;
15use itertools::chain;
16use tracing::instrument;
17
18use super::{
19	common::{FExt, FFastExt},
20	error::Error,
21};
22use crate::{
23	constraint_system::channel::OracleOrConst,
24	oracle::{MultilinearOracleSet, OracleId},
25	protocols::{
26		evalcheck::EvalcheckMultilinearClaim,
27		gkr_exp::{self, BaseExpReductionOutput, BaseExpWitness, ExpClaim},
28	},
29	witness::{MultilinearExtensionIndex, MultilinearWitness},
30};
31
32#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
33pub struct Exp<F: Field> {
34	/// A vector of `OracleId`s representing the exponent in little-endian bit order
35	pub bits_ids: Vec<OracleId>,
36	pub base: OracleOrConst<F>,
37	pub exp_result_id: OracleId,
38}
39
40impl<F: TowerField> Exp<F> {
41	pub fn n_vars(&self, oracles: &MultilinearOracleSet<F>) -> usize {
42		oracles.n_vars(self.exp_result_id)
43	}
44}
45
46pub fn max_n_vars<F: TowerField>(exponents: &[Exp<F>], oracles: &MultilinearOracleSet<F>) -> usize {
47	exponents
48		.iter()
49		.map(|m| m.n_vars(oracles))
50		.max()
51		.unwrap_or(0)
52}
53
54type MultiplicationWitnesses<'a, U, Tower> =
55	Vec<BaseExpWitness<'a, PackedType<U, FFastExt<Tower>>>>;
56
57/// Constructs [`BaseExpWitness`] instances and adds the exponentiation-result witnesses
58/// to the MultiplicationWitnesses.
59#[instrument(skip_all, name = "exp::make_exp_witnesses")]
60pub fn make_exp_witnesses<'a, U, Tower>(
61	witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
62	oracles: &MultilinearOracleSet<FExt<Tower>>,
63	exponents: &[Exp<Tower::B128>],
64) -> Result<MultiplicationWitnesses<'a, U, Tower>, Error>
65where
66	U: ProverTowerUnderlier<Tower>,
67	Tower: ProverTowerFamily,
68	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
69	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
70{
71	// Since dynamic witnesses may need the `exp_result` of static witnesses,
72	// we start processing with static ones first.
73	let static_exponents_iter = exponents
74		.iter()
75		.filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
76	let dynamic_exponents_iter = exponents
77		.iter()
78		.filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
79
80	chain!(static_exponents_iter, dynamic_exponents_iter)
81		.map(|exp| {
82			let fast_exponent_witnesses =
83				get_fast_exponent_witnesses::<U, Tower>(witness, &exp.bits_ids)?;
84
85			let (exp_witness, tower_level) = match exp.base {
86				OracleOrConst::Const { base, tower_level } => {
87					let witness = gkr_exp::BaseExpWitness::new_with_static_base::<
88						PackedType<U, FFastExt<Tower>>,
89					>(fast_exponent_witnesses, base.into())?;
90					(witness, tower_level)
91				}
92				OracleOrConst::Oracle(base_id) => {
93					let fast_base_witnesses =
94						to_fast_witness::<U, Tower>(witness.get_multilin_poly(base_id)?)?;
95
96					let witness = gkr_exp::BaseExpWitness::new_with_dynamic_base::<
97						PackedType<U, FFastExt<Tower>>,
98					>(fast_exponent_witnesses, fast_base_witnesses)?;
99
100					let tower_level = oracles.tower_level(base_id);
101
102					(witness, tower_level)
103				}
104			};
105
106			let exp_result_witness = match tower_level {
107				0..=3 => repack_witness::<U, Tower, Tower::B8>(
108					exp_witness.exponentiation_result_witness(),
109				)?,
110				4 => repack_witness::<U, Tower, Tower::B16>(
111					exp_witness.exponentiation_result_witness(),
112				)?,
113				5 => repack_witness::<U, Tower, Tower::B32>(
114					exp_witness.exponentiation_result_witness(),
115				)?,
116				6 => repack_witness::<U, Tower, Tower::B64>(
117					exp_witness.exponentiation_result_witness(),
118				)?,
119				7 => repack_witness::<U, Tower, Tower::B128>(
120					exp_witness.exponentiation_result_witness(),
121				)?,
122				_ => bail!(Error::IncorrectTowerLevel),
123			};
124
125			witness.update_multilin_poly([(exp.exp_result_id, exp_result_witness)])?;
126
127			Ok(exp_witness)
128		})
129		.collect::<Result<Vec<_>, Error>>()
130}
131
132pub fn make_claims<F>(
133	exponents: &[Exp<F>],
134	oracles: &MultilinearOracleSet<F>,
135	eval_point: &[F],
136	evals: &[F],
137) -> Result<Vec<ExpClaim<F>>, Error>
138where
139	F: TowerField,
140{
141	let static_exponents_iter = exponents
142		.iter()
143		.filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
144	let dynamic_exponents_iter = exponents
145		.iter()
146		.filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
147	let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
148
149	let constant_bases = exponents_iter
150		.clone()
151		.map(|exp| match exp.base {
152			OracleOrConst::Const { base, .. } => Some(base),
153			OracleOrConst::Oracle(_) => None,
154		})
155		.collect::<Vec<_>>();
156
157	let exponents_ids = exponents_iter
158		.cloned()
159		.map(|exp| exp.bits_ids)
160		.collect::<Vec<_>>();
161
162	gkr_exp::construct_gkr_exp_claims(&exponents_ids, evals, constant_bases, oracles, eval_point)
163		.map_err(Error::from)
164}
165
166pub fn make_eval_claims<F: TowerField>(
167	exponents: &[Exp<F>],
168	base_exp_output: BaseExpReductionOutput<F>,
169) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
170	let static_exponents_iter = exponents
171		.iter()
172		.filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
173	let dynamic_exponents_iter = exponents
174		.iter()
175		.filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
176	let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
177
178	let dynamic_base_ids = exponents_iter
179		.clone()
180		.map(|exp| match exp.base {
181			OracleOrConst::Const { .. } => None,
182			OracleOrConst::Oracle(base_id) => Some(base_id),
183		})
184		.collect::<Vec<_>>();
185
186	let metas = exponents_iter
187		.cloned()
188		.map(|exp| exp.bits_ids)
189		.collect::<Vec<_>>();
190
191	gkr_exp::make_eval_claims(metas, base_exp_output, dynamic_base_ids).map_err(Error::from)
192}
193
194#[instrument(skip_all, name = "exp::repack_witness")]
195// Because the exponentiation witness operates in a large field, the number of leading zeroes
196// depends on the FExpBase. To optimize storage and avoid committing unnecessary zeroes, we
197// can repack B128 into FExpBase.
198fn repack_witness<U, Tower, FExpBase>(
199	witness: MultilinearWitness<PackedType<U, FFastExt<Tower>>>,
200) -> Result<MultilinearWitness<PackedType<U, FExt<Tower>>>, Error>
201where
202	U: ProverTowerUnderlier<Tower> + PackScalar<FExpBase>,
203	Tower: ProverTowerFamily,
204	FExpBase: TowerField,
205	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
206	PackedType<U, Tower::B128>: RepackedExtension<PackedType<U, FExpBase>>,
207{
208	let from_fast = Tower::packed_transformation_from_fast();
209
210	let f_exp_log_width = PackedType::<U, FExpBase>::LOG_WIDTH;
211	let log_width = PackedType::<U, FFastExt<Tower>>::LOG_WIDTH;
212	let n_vars = witness.n_vars();
213
214	let ext_degree = <Tower::B128 as ExtensionField<FExpBase>>::DEGREE;
215
216	const MAX_SUBCUBE_VARS: usize = 8;
217	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
218
219	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
220
221	let mut repacked_evals =
222		vec![PackedType::<U, FExpBase>::default(); 1 << n_vars.saturating_sub(f_exp_log_width)];
223
224	repacked_evals
225		.par_chunks_mut((subcube_packed_size / ext_degree).max(1))
226		.enumerate()
227		.for_each(|(subcube_index, repacked_evals)| {
228			let mut subcube_evals =
229				vec![PackedType::<U, FExt<Tower>>::default(); subcube_packed_size];
230
231			let underliers =
232				PackedType::<U, FExt<Tower>>::to_underliers_ref_mut(&mut subcube_evals);
233
234			let fast_subcube_evals =
235				PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
236
237			witness
238				.subcube_evals(subcube_vars, subcube_index, 0, fast_subcube_evals)
239				.expect("repacked_evals chunks are ext_degree times smaller");
240
241			for underlier in underliers.iter_mut() {
242				let src = PackedType::<U, FFastExt<Tower>>::from_underlier(*underlier);
243				let dest = from_fast.transform(&src);
244				*underlier = PackedType::<U, FExt<Tower>>::to_underlier(dest);
245			}
246
247			let demoted = PackedType::<U, Tower::B128>::cast_bases(&subcube_evals);
248
249			if ext_degree == 1 {
250				repacked_evals.clone_from_slice(demoted);
251			} else {
252				demoted.chunks(ext_degree).zip(repacked_evals).for_each(
253					|(demoted, repacked_evals)| {
254						*repacked_evals = PackedType::<U, FExpBase>::from_fn(|i| {
255							get_packed_slice_checked(demoted, i * ext_degree)
256								.unwrap_or(FExpBase::ZERO)
257						});
258					},
259				)
260			}
261		});
262
263	Ok(MultilinearExtension::new(witness.n_vars(), repacked_evals)?.specialize_arc_dyn())
264}
265
266fn to_fast_witness<U, Tower>(
267	witness: MultilinearWitness<PackedType<U, FExt<Tower>>>,
268) -> Result<MultilinearWitness<PackedType<U, FFastExt<Tower>>>, Error>
269where
270	U: ProverTowerUnderlier<Tower>,
271	Tower: ProverTowerFamily,
272	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
273{
274	let to_fast = Tower::packed_transformation_to_fast();
275
276	let log_width = PackedType::<U, FExt<Tower>>::LOG_WIDTH;
277	let n_vars = witness.n_vars();
278
279	let mut fast_packed_evals =
280		vec![PackedType::<U, FFastExt<Tower>>::default(); 1 << n_vars.saturating_sub(log_width)];
281
282	const MAX_SUBCUBE_VARS: usize = 8;
283	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
284
285	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
286
287	fast_packed_evals
288		.par_chunks_mut(subcube_packed_size)
289		.enumerate()
290		.for_each(|(subcube_index, fast_subcube)| {
291			let underliers = PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
292
293			let subcube_evals = PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
294			witness
295				.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
296				.expect("fast_packed_evals has correct size");
297
298			for underlier in underliers.iter_mut() {
299				let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
300				let dest = to_fast.transform(&src);
301				*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
302			}
303		});
304
305	MultilinearExtension::new(witness.n_vars(), fast_packed_evals)
306		.map(|mle| mle.specialize_arc_dyn())
307		.map_err(Error::from)
308}
309
310type FastExponentWitnesses<'a, U, Tower> =
311	Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>;
312
313/// Casts witness from 1B to FastB128.
314/// TODO: Update when we start using byteslicing.
315fn get_fast_exponent_witnesses<'a, U, Tower>(
316	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
317	ids: &[OracleId],
318) -> Result<FastExponentWitnesses<'a, U, Tower>, Error>
319where
320	U: ProverTowerUnderlier<Tower>,
321	Tower: ProverTowerFamily,
322	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
323		+ RepackedExtension<PackedType<U, Tower::B1>>,
324{
325	ids.iter()
326		.map(|&id| {
327			let exp_witness = witness.get_multilin_poly(id)?;
328
329			let packed_evals = exp_witness
330				.packed_evals()
331				.expect("poly contain packed_evals");
332
333			let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
334
335			MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
336				.map(|mle| mle.specialize_arc_dyn())
337				.map_err(Error::from)
338		})
339		.collect::<Result<Vec<_>, _>>()
340}