// binius_core/transparent/tower_basis.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use binius_field::{BinaryField128b, Field, PackedField, TowerField};
6use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes};
7use binius_math::MultilinearExtension;
8use binius_utils::{bail, DeserializeBytes};
9
10use crate::polynomial::{Error, MultivariatePoly};
11
/// Represents the $\mathcal{T}_{\iota}$-basis of $\mathcal{T}_{\iota+k}$.
///
/// Recall that $\mathcal{T}_{\iota}$ is defined as
/// * Let $\mathbb{F} := \mathbb{F}_2[X_0, \ldots, X_{\iota-1}]$
/// * Let $\mathcal{J} := (X_0^2 + X_0 + 1, X_1^2 + X_1 X_0 + 1, \ldots, X_{\iota-1}^2 + X_{\iota-1}X_{\iota-2} + 1)$
/// * $\mathcal{T}_{\iota} := \mathbb{F} / \mathcal{J}$
///
/// and $\mathcal{T}_{\iota}$ has the following $\mathbb{F}_2$-basis:
/// * $1, X_0, X_1, X_0X_1, X_2, \ldots, X_0 X_1 \ldots X_{\iota-1}$
///
/// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$:
/// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1} \ldots X_{\iota+k-1}$
///
/// As a `MultivariatePoly`, this value is a $k$-variate polynomial; its evaluation and its
/// multilinear extension are provided by the `impl` blocks in this file.
#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)]
pub struct TowerBasis<F: Field> {
	// Number of variables $k$; the basis has $2^k$ elements.
	k: usize,
	// Tower level $\iota$ of the coefficient subfield $\mathcal{T}_{\iota}$.
	iota: usize,
	_marker: PhantomData<F>,
}
30
31inventory::submit! {
32	<dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
33		"TowerBasis",
34		|buf, mode| Ok(Box::new(TowerBasis::<BinaryField128b>::deserialize(&mut *buf, mode)?))
35	)
36}
37
38impl<F: TowerField> TowerBasis<F> {
39	pub fn new(k: usize, iota: usize) -> Result<Self, Error> {
40		if iota + k > F::TOWER_LEVEL {
41			bail!(Error::ArgumentRangeError {
42				arg: "iota + k".into(),
43				range: 0..F::TOWER_LEVEL + 1,
44			});
45		}
46		Ok(Self {
47			k,
48			iota,
49			_marker: Default::default(),
50		})
51	}
52
53	pub fn multilinear_extension<P: PackedField<Scalar = F>>(
54		&self,
55	) -> Result<MultilinearExtension<P>, Error> {
56		let n_values = (1 << self.k) / P::WIDTH;
57		let values = (0..n_values)
58			.map(|i| {
59				let mut packed_value = P::default();
60				for j in 0..P::WIDTH {
61					let basis_idx = i * P::WIDTH + j;
62					let value = TowerField::basis(self.iota, basis_idx)?;
63					packed_value.set(j, value);
64				}
65				Ok(packed_value)
66			})
67			.collect::<Result<Vec<_>, Error>>()?;
68
69		Ok(MultilinearExtension::from_values(values)?)
70	}
71}
72
73#[erased_serialize_bytes]
74impl<F> MultivariatePoly<F> for TowerBasis<F>
75where
76	F: TowerField,
77{
78	fn n_vars(&self) -> usize {
79		self.k
80	}
81
82	fn degree(&self) -> usize {
83		self.k
84	}
85
86	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
87		if query.len() != self.k {
88			bail!(Error::IncorrectQuerySize { expected: self.k });
89		}
90
91		let mut result = F::ONE;
92		for (i, query_i) in query.iter().enumerate() {
93			let r_comp = F::ONE - query_i;
94			let basis_elt = <F as TowerField>::basis(self.iota + i, 1)?;
95			result *= r_comp + *query_i * basis_elt;
96		}
97		Ok(result)
98	}
99
100	fn binary_tower_level(&self) -> usize {
101		self.iota + self.k
102	}
103}
104
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{BinaryField128b, BinaryField32b, PackedBinaryField4x32b};
	use binius_hal::{make_portable_backend, ComputationBackendExt};
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	/// Checks that `evaluate` agrees with evaluating the scalar (unpacked) multilinear
	/// extension at a random point, for `TowerBasis::new(k, iota)` over `BinaryField128b`.
	fn test_consistency(iota: usize, k: usize) {
		type F = BinaryField128b;
		let mut rng = StdRng::seed_from_u64(0);
		let backend = make_portable_backend();

		let basis = TowerBasis::<F>::new(k, iota).unwrap();
		let challenge = repeat_with(|| <F as Field>::random(&mut rng))
			.take(k)
			.collect::<Vec<_>>();

		let eval1 = basis.evaluate(&challenge).unwrap();
		let multilin_query = backend.multilinear_query::<F>(&challenge).unwrap();
		let mle = basis.multilinear_extension::<F>().unwrap();
		let eval2 = mle.evaluate(&multilin_query).unwrap();

		assert_eq!(eval1, eval2);
	}

	/// Same consistency check as `test_consistency`, but with the extension packed into
	/// `PackedBinaryField4x32b`.
	#[test]
	fn test_consistency_packing() {
		let iota = 2;
		let kappa = 3;
		type F = BinaryField32b;
		type P = PackedBinaryField4x32b;
		let mut rng = StdRng::seed_from_u64(0);
		let backend = make_portable_backend();

		let basis = TowerBasis::<F>::new(kappa, iota).unwrap();
		let challenge = repeat_with(|| <F as Field>::random(&mut rng))
			.take(kappa)
			.collect::<Vec<_>>();
		let eval1 = basis.evaluate(&challenge).unwrap();
		let multilin_query = backend.multilinear_query::<F>(&challenge).unwrap();
		let mle = basis.multilinear_extension::<P>().unwrap();
		let eval2 = mle.evaluate(&multilin_query).unwrap();
		assert_eq!(eval1, eval2);
	}

	/// Runs the consistency check for every valid `(iota, k)` pair over `BinaryField128b`.
	#[test]
	fn test_consistency_all() {
		let max_level = <BinaryField128b as TowerField>::TOWER_LEVEL;
		for iota in 0..=max_level {
			for k in 0..=(max_level - iota) {
				test_consistency(iota, k);
			}
		}
	}

	/// `new` accepts `iota + k == TOWER_LEVEL` and rejects anything beyond it.
	#[test]
	fn test_new_rejects_out_of_range() {
		type F = BinaryField32b;
		let max_level = <F as TowerField>::TOWER_LEVEL;

		assert!(TowerBasis::<F>::new(max_level, 0).is_ok());
		assert!(TowerBasis::<F>::new(0, max_level).is_ok());
		assert!(TowerBasis::<F>::new(max_level + 1, 0).is_err());
		assert!(TowerBasis::<F>::new(1, max_level).is_err());
	}
}