binius_core/constraint_system/exp.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{
4	as_packed_field::{PackScalar, PackedType},
5	linear_transformation::{PackedTransformationFactory, Transformation},
6	packed::get_packed_slice,
7	underlier::WithUnderlier,
8	ExtensionField, Field, PackedExtension, PackedField, RepackedExtension, TowerField,
9};
10use binius_macros::{DeserializeBytes, SerializeBytes};
11use binius_math::MultilinearExtension;
12use binius_maybe_rayon::prelude::*;
13use binius_utils::bail;
14use itertools::chain;
15use tracing::instrument;
16
17use super::{
18	common::{FExt, FFastExt},
19	error::Error,
20};
21use crate::{
22	constraint_system::channel::OracleOrConst,
23	oracle::{MultilinearOracleSet, OracleId},
24	protocols::{
25		evalcheck::EvalcheckMultilinearClaim,
26		gkr_exp::{self, BaseExpReductionOutput, BaseExpWitness, ExpClaim},
27	},
28	tower::{ProverTowerFamily, ProverTowerUnderlier},
29	witness::{MultilinearExtensionIndex, MultilinearWitness},
30};
31
/// An exponentiation: the oracle `exp_result_id` holds `base` raised to the exponent whose
/// bits are given by the oracles in `bits_ids`.
///
/// NOTE(review): field order likely matters for the derived `SerializeBytes` /
/// `DeserializeBytes` encoding — verify before reordering fields.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Exp<F: Field> {
	/// A vector of `OracleId`s representing the exponent in little-endian bit order
	pub bits_ids: Vec<OracleId>,
	/// The base: either a constant field element (tagged with its tower level) or an oracle.
	pub base: OracleOrConst<F>,
	/// The oracle that receives the exponentiation result.
	pub exp_result_id: OracleId,
}
39
40impl<F: TowerField> Exp<F> {
41	pub fn n_vars(&self, oracles: &MultilinearOracleSet<F>) -> usize {
42		oracles.n_vars(self.exp_result_id)
43	}
44}
45
46pub fn max_n_vars<F: TowerField>(exponents: &[Exp<F>], oracles: &MultilinearOracleSet<F>) -> usize {
47	exponents
48		.iter()
49		.map(|m| m.n_vars(oracles))
50		.max()
51		.unwrap_or(0)
52}
53
/// The per-exponentiation [`BaseExpWitness`]es returned by [`make_exp_witnesses`], packed over
/// `FFastExt<Tower>`.
type MultiplicationWitnesses<'a, U, Tower> =
	Vec<BaseExpWitness<'a, PackedType<U, FFastExt<Tower>>>>;
56
57/// Constructs [`BaseExpWitness`] instances and adds the exponentiation-result witnesses
58/// to the MultiplicationWitnesses.
59#[instrument(skip_all, name = "exp::make_exp_witnesses")]
60pub fn make_exp_witnesses<'a, U, Tower>(
61	witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
62	oracles: &MultilinearOracleSet<FExt<Tower>>,
63	exponents: &[Exp<Tower::B128>],
64) -> Result<MultiplicationWitnesses<'a, U, Tower>, Error>
65where
66	U: ProverTowerUnderlier<Tower>,
67	Tower: ProverTowerFamily,
68	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
69	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
70{
71	// Since dynamic witnesses may need the `exp_result` of static witnesses,
72	// we start processing with static ones first.
73	let static_exponents_iter = exponents
74		.iter()
75		.filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
76	let dynamic_exponents_iter = exponents
77		.iter()
78		.filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
79
80	chain!(static_exponents_iter, dynamic_exponents_iter)
81		.map(|exp| {
82			let fast_exponent_witnesses =
83				get_fast_exponent_witnesses::<U, Tower>(witness, &exp.bits_ids)?;
84
85			let (exp_witness, tower_level) = match exp.base {
86				OracleOrConst::Const { base, tower_level } => {
87					let witness = gkr_exp::BaseExpWitness::new_with_static_base(
88						fast_exponent_witnesses,
89						base.into(),
90					)?;
91					(witness, tower_level)
92				}
93				OracleOrConst::Oracle(base_id) => {
94					let fast_base_witnesses =
95						to_fast_witness::<U, Tower>(witness.get_multilin_poly(base_id)?)?;
96
97					let witness = gkr_exp::BaseExpWitness::new_with_dynamic_base(
98						fast_exponent_witnesses,
99						fast_base_witnesses,
100					)?;
101
102					let tower_level = oracles.tower_level(base_id);
103
104					(witness, tower_level)
105				}
106			};
107
108			let exp_result_witness = match tower_level {
109				0..=3 => repack_witness::<U, Tower, Tower::B8>(
110					exp_witness.exponentiation_result_witness(),
111				)?,
112				4 => repack_witness::<U, Tower, Tower::B16>(
113					exp_witness.exponentiation_result_witness(),
114				)?,
115				5 => repack_witness::<U, Tower, Tower::B32>(
116					exp_witness.exponentiation_result_witness(),
117				)?,
118				6 => repack_witness::<U, Tower, Tower::B64>(
119					exp_witness.exponentiation_result_witness(),
120				)?,
121				7 => repack_witness::<U, Tower, Tower::B128>(
122					exp_witness.exponentiation_result_witness(),
123				)?,
124				_ => bail!(Error::IncorrectTowerLevel),
125			};
126
127			witness.update_multilin_poly([(exp.exp_result_id, exp_result_witness)])?;
128
129			Ok(exp_witness)
130		})
131		.collect::<Result<Vec<_>, Error>>()
132}
133
134pub fn make_claims<F>(
135	exponents: &[Exp<F>],
136	oracles: &MultilinearOracleSet<F>,
137	eval_point: &[F],
138	evals: &[F],
139) -> Result<Vec<ExpClaim<F>>, Error>
140where
141	F: TowerField,
142{
143	let static_exponents_iter = exponents
144		.iter()
145		.filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
146	let dynamic_exponents_iter = exponents
147		.iter()
148		.filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
149	let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
150
151	let constant_bases = exponents_iter
152		.clone()
153		.map(|exp| match exp.base {
154			OracleOrConst::Const { base, .. } => Some(base),
155			OracleOrConst::Oracle(_) => None,
156		})
157		.collect::<Vec<_>>();
158
159	let exponents_ids = exponents_iter
160		.cloned()
161		.map(|exp| exp.bits_ids)
162		.collect::<Vec<_>>();
163
164	gkr_exp::construct_gkr_exp_claims(&exponents_ids, evals, constant_bases, oracles, eval_point)
165		.map_err(Error::from)
166}
167
168pub fn make_eval_claims<F: TowerField>(
169	exponents: &[Exp<F>],
170	base_exp_output: BaseExpReductionOutput<F>,
171) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
172	let static_exponents_iter = exponents
173		.iter()
174		.filter(|exp| matches!(exp.base, OracleOrConst::Const { .. }));
175	let dynamic_exponents_iter = exponents
176		.iter()
177		.filter(|exp| matches!(exp.base, OracleOrConst::Oracle(_)));
178	let exponents_iter = chain!(static_exponents_iter, dynamic_exponents_iter);
179
180	let dynamic_base_ids = exponents_iter
181		.clone()
182		.map(|exp| match exp.base {
183			OracleOrConst::Const { .. } => None,
184			OracleOrConst::Oracle(base_id) => Some(base_id),
185		})
186		.collect::<Vec<_>>();
187
188	let metas = exponents_iter
189		.cloned()
190		.map(|exp| exp.bits_ids)
191		.collect::<Vec<_>>();
192
193	gkr_exp::make_eval_claims(metas, base_exp_output, dynamic_base_ids).map_err(Error::from)
194}
195
196#[instrument(skip_all, name = "exp::repack_witness")]
197// Because the exponentiation witness operates in a large field, the number of leading zeroes
198// depends on the FExpBase. To optimize storage and avoid committing unnecessary zeroes, we
199// can repack B128 into FExpBase.
200fn repack_witness<U, Tower, FExpBase>(
201	witness: MultilinearWitness<PackedType<U, FFastExt<Tower>>>,
202) -> Result<MultilinearWitness<PackedType<U, FExt<Tower>>>, Error>
203where
204	U: ProverTowerUnderlier<Tower> + PackScalar<FExpBase>,
205	Tower: ProverTowerFamily,
206	FExpBase: TowerField,
207	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
208	PackedType<U, Tower::B128>: RepackedExtension<PackedType<U, FExpBase>>,
209{
210	let from_fast = Tower::packed_transformation_from_fast();
211
212	let f_exp_log_width = PackedType::<U, FExpBase>::LOG_WIDTH;
213	let log_width = PackedType::<U, FFastExt<Tower>>::LOG_WIDTH;
214	let n_vars = witness.n_vars();
215
216	let ext_degree = <Tower::B128 as ExtensionField<FExpBase>>::DEGREE;
217
218	const MAX_SUBCUBE_VARS: usize = 8;
219	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
220
221	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
222
223	let mut repacked_evals = vec![
224		PackedType::<U, FExpBase>::default();
225		1 << witness.n_vars().saturating_sub(f_exp_log_width)
226	];
227
228	repacked_evals
229		.par_chunks_mut(subcube_packed_size / ext_degree)
230		.enumerate()
231		.for_each(|(subcube_index, repacked_evals)| {
232			let mut subcube_evals =
233				vec![PackedType::<U, FExt<Tower>>::default(); subcube_packed_size];
234
235			let underliers =
236				PackedType::<U, FExt<Tower>>::to_underliers_ref_mut(&mut subcube_evals);
237
238			let fast_subcube_evals =
239				PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
240
241			witness
242				.subcube_evals(subcube_vars, subcube_index, 0, fast_subcube_evals)
243				.expect("repacked_evals chunks are ext_degree times smaller");
244
245			for underlier in underliers.iter_mut() {
246				let src = PackedType::<U, FFastExt<Tower>>::from_underlier(*underlier);
247				let dest = from_fast.transform(&src);
248				*underlier = PackedType::<U, FExt<Tower>>::to_underlier(dest);
249			}
250
251			let demoted = PackedType::<U, Tower::B128>::cast_bases(&subcube_evals);
252
253			if ext_degree == 1 {
254				repacked_evals.clone_from_slice(demoted);
255			} else {
256				demoted.chunks(ext_degree).zip(repacked_evals).for_each(
257					|(demoted, repacked_evals)| {
258						*repacked_evals = PackedType::<U, FExpBase>::from_fn(|i| {
259							get_packed_slice(demoted, i * ext_degree)
260						})
261					},
262				)
263			}
264		});
265
266	Ok(MultilinearExtension::new(witness.n_vars(), repacked_evals)?.specialize_arc_dyn())
267}
268
269fn to_fast_witness<U, Tower>(
270	witness: MultilinearWitness<PackedType<U, FExt<Tower>>>,
271) -> Result<MultilinearWitness<PackedType<U, FFastExt<Tower>>>, Error>
272where
273	U: ProverTowerUnderlier<Tower>,
274	Tower: ProverTowerFamily,
275	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
276{
277	let to_fast = Tower::packed_transformation_to_fast();
278
279	let log_width = PackedType::<U, FExt<Tower>>::LOG_WIDTH;
280	let n_vars = witness.n_vars();
281
282	let mut fast_packed_evals =
283		vec![PackedType::<U, FFastExt<Tower>>::default(); 1 << n_vars.saturating_sub(log_width)];
284
285	const MAX_SUBCUBE_VARS: usize = 8;
286	let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
287
288	let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
289
290	fast_packed_evals
291		.par_chunks_mut(subcube_packed_size)
292		.enumerate()
293		.for_each(|(subcube_index, fast_subcube)| {
294			let underliers = PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
295
296			let subcube_evals = PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
297			witness
298				.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
299				.expect("fast_packed_evals has correct size");
300
301			for underlier in underliers.iter_mut() {
302				let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
303				let dest = to_fast.transform(&src);
304				*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
305			}
306		});
307
308	MultilinearExtension::new(witness.n_vars(), fast_packed_evals)
309		.map(|mle| mle.specialize_arc_dyn())
310		.map_err(Error::from)
311}
312
/// Exponent bit witnesses, as returned by [`get_fast_exponent_witnesses`], viewed as multilinears
/// over `FFastExt<Tower>` packed fields.
type FastExponentWitnesses<'a, U, Tower> =
	Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>;
315
316/// Casts witness from 1B to FastB128.
317/// TODO: Update when we start using byteslicing.
318fn get_fast_exponent_witnesses<'a, U, Tower>(
319	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
320	ids: &[OracleId],
321) -> Result<FastExponentWitnesses<'a, U, Tower>, Error>
322where
323	U: ProverTowerUnderlier<Tower>,
324	Tower: ProverTowerFamily,
325	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
326		+ RepackedExtension<PackedType<U, Tower::B1>>,
327{
328	ids.iter()
329		.map(|&id| {
330			let exp_witness = witness.get_multilin_poly(id)?;
331
332			let packed_evals = exp_witness
333				.packed_evals()
334				.expect("poly contain packed_evals");
335
336			let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
337
338			MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
339				.map(|mle| mle.specialize_arc_dyn())
340				.map_err(Error::from)
341		})
342		.collect::<Result<Vec<_>, _>>()
343}