binius_core/transparent/
tower_basis.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use binius_field::{BinaryField128b, Field, PackedField, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes, erased_serialize_bytes};
7use binius_math::MultilinearExtension;
8use binius_utils::{DeserializeBytes, bail};
9
10use crate::polynomial::{Error, MultivariatePoly};
11
/// Represents the $\mathcal{T}_{\iota}$-basis of $\mathcal{T}_{\iota+k}$
///
/// Recall that $\mathcal{T}_{\iota}$ is defined as
/// * Let $\mathbb{F} := \mathbb{F}_2[X_0, \ldots, X_{\iota-1}]$
/// * Let $\mathcal{J} := (X_0^2 + X_0 + 1, \ldots, X_{\iota-1}^2 + X_{\iota-1}X_{\iota-2} + 1)$
/// * $\mathcal{T}_{\iota} := \mathbb{F} / \mathcal{J}$
///
/// and $\mathcal{T}_{\iota}$ has the following $\mathbb{F}_2$-basis:
/// * $1, X_0, X_1, X_0X_1, X_2, \ldots, X_0 X_1 \ldots X_{\iota-1}$
///
/// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$:
/// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1}
///   \ldots X_{\iota+k-1}$
#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)]
pub struct TowerBasis<F: Field> {
	// Number of variables of the polynomial; the basis described above has $2^k$ elements.
	k: usize,
	// Tower level $\iota$ of the subfield the basis is taken over.
	iota: usize,
	// Binds the value to the field type `F` without storing any field element.
	_marker: PhantomData<F>,
}
31
32inventory::submit! {
33	<dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
34		"TowerBasis",
35		|buf, mode| Ok(Box::new(TowerBasis::<BinaryField128b>::deserialize(&mut *buf, mode)?))
36	)
37}
38
39impl<F: TowerField> TowerBasis<F> {
40	pub fn new(k: usize, iota: usize) -> Result<Self, Error> {
41		if iota + k > F::TOWER_LEVEL {
42			bail!(Error::ArgumentRangeError {
43				arg: "iota + k".into(),
44				range: 0..F::TOWER_LEVEL + 1,
45			});
46		}
47		Ok(Self {
48			k,
49			iota,
50			_marker: Default::default(),
51		})
52	}
53
54	pub fn multilinear_extension<P: PackedField<Scalar = F>>(
55		&self,
56	) -> Result<MultilinearExtension<P>, Error> {
57		let n_values = (1 << self.k) / P::WIDTH;
58		let values = (0..n_values)
59			.map(|i| {
60				let mut packed_value = P::default();
61				for j in 0..P::WIDTH {
62					let basis_idx = i * P::WIDTH + j;
63					let value = TowerField::basis(self.iota, basis_idx)?;
64					packed_value.set(j, value);
65				}
66				Ok(packed_value)
67			})
68			.collect::<Result<Vec<_>, Error>>()?;
69
70		Ok(MultilinearExtension::from_values(values)?)
71	}
72}
73
74#[erased_serialize_bytes]
75impl<F> MultivariatePoly<F> for TowerBasis<F>
76where
77	F: TowerField,
78{
79	fn n_vars(&self) -> usize {
80		self.k
81	}
82
83	fn degree(&self) -> usize {
84		self.k
85	}
86
87	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
88		if query.len() != self.k {
89			bail!(Error::IncorrectQuerySize {
90				expected: self.k,
91				actual: query.len()
92			});
93		}
94
95		let mut result = F::ONE;
96		for (i, query_i) in query.iter().enumerate() {
97			let r_comp = F::ONE - query_i;
98			let basis_elt = <F as TowerField>::basis(self.iota + i, 1)?;
99			result *= r_comp + *query_i * basis_elt;
100		}
101		Ok(result)
102	}
103
104	fn binary_tower_level(&self) -> usize {
105		self.iota + self.k
106	}
107}
108
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{BinaryField32b, BinaryField128b, PackedBinaryField4x32b};
	use binius_hal::{ComputationBackendExt, make_portable_backend};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	/// Checks, over `BinaryField128b`, that `evaluate` agrees with evaluating the materialized
	/// multilinear extension at the same random point.
	fn test_consistency(iota: usize, k: usize) {
		type F = BinaryField128b;

		let backend = make_portable_backend();
		let mut rng = StdRng::seed_from_u64(0);
		let point: Vec<F> = repeat_with(|| <F as Field>::random(&mut rng))
			.take(k)
			.collect();

		let poly = TowerBasis::<F>::new(k, iota).unwrap();

		// Direct evaluation through the `MultivariatePoly` implementation.
		let direct = poly.evaluate(&point).unwrap();

		// Evaluation through the materialized multilinear extension.
		let query = backend.multilinear_query::<F>(&point).unwrap();
		let mle = poly.multilinear_extension::<F>().unwrap();
		let via_mle = mle.evaluate(&query).unwrap();

		assert_eq!(direct, via_mle);
	}

	/// Same consistency check, but with the extension materialized into a packed field.
	#[test]
	fn test_consistency_packing() {
		type F = BinaryField32b;
		type P = PackedBinaryField4x32b;

		let (iota, kappa) = (2, 3);

		let backend = make_portable_backend();
		let mut rng = StdRng::seed_from_u64(0);
		let point: Vec<F> = repeat_with(|| <F as Field>::random(&mut rng))
			.take(kappa)
			.collect();

		let poly = TowerBasis::<F>::new(kappa, iota).unwrap();
		let direct = poly.evaluate(&point).unwrap();

		let query = backend.multilinear_query::<F>(&point).unwrap();
		let mle = poly.multilinear_extension::<P>().unwrap();
		let via_mle = mle.evaluate(&query).unwrap();

		assert_eq!(direct, via_mle);
	}

	/// Runs the consistency check for every `(iota, k)` pair with `iota + k <= 7`.
	#[test]
	fn test_consistency_all() {
		let pairs = (0..=7usize).flat_map(|iota| (0..=(7 - iota)).map(move |k| (iota, k)));
		for (iota, k) in pairs {
			test_consistency(iota, k);
		}
	}
}