binius_core/transparent/
tower_basis.rs1use std::marker::PhantomData;
4
5use binius_field::{BinaryField128b, Field, PackedField, TowerField};
6use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes};
7use binius_math::MultilinearExtension;
8use binius_utils::{bail, DeserializeBytes};
9
10use crate::polynomial::{Error, MultivariatePoly};
11
12#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)]
25pub struct TowerBasis<F: Field> {
26 k: usize,
27 iota: usize,
28 _marker: PhantomData<F>,
29}
30
31inventory::submit! {
32 <dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
33 "TowerBasis",
34 |buf, mode| Ok(Box::new(TowerBasis::<BinaryField128b>::deserialize(&mut *buf, mode)?))
35 )
36}
37
38impl<F: TowerField> TowerBasis<F> {
39 pub fn new(k: usize, iota: usize) -> Result<Self, Error> {
40 if iota + k > F::TOWER_LEVEL {
41 bail!(Error::ArgumentRangeError {
42 arg: "iota + k".into(),
43 range: 0..F::TOWER_LEVEL + 1,
44 });
45 }
46 Ok(Self {
47 k,
48 iota,
49 _marker: Default::default(),
50 })
51 }
52
53 pub fn multilinear_extension<P: PackedField<Scalar = F>>(
54 &self,
55 ) -> Result<MultilinearExtension<P>, Error> {
56 let n_values = (1 << self.k) / P::WIDTH;
57 let values = (0..n_values)
58 .map(|i| {
59 let mut packed_value = P::default();
60 for j in 0..P::WIDTH {
61 let basis_idx = i * P::WIDTH + j;
62 let value = TowerField::basis(self.iota, basis_idx)?;
63 packed_value.set(j, value);
64 }
65 Ok(packed_value)
66 })
67 .collect::<Result<Vec<_>, Error>>()?;
68
69 Ok(MultilinearExtension::from_values(values)?)
70 }
71}
72
73#[erased_serialize_bytes]
74impl<F> MultivariatePoly<F> for TowerBasis<F>
75where
76 F: TowerField,
77{
78 fn n_vars(&self) -> usize {
79 self.k
80 }
81
82 fn degree(&self) -> usize {
83 self.k
84 }
85
86 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
87 if query.len() != self.k {
88 bail!(Error::IncorrectQuerySize { expected: self.k });
89 }
90
91 let mut result = F::ONE;
92 for (i, query_i) in query.iter().enumerate() {
93 let r_comp = F::ONE - query_i;
94 let basis_elt = <F as TowerField>::basis(self.iota + i, 1)?;
95 result *= r_comp + *query_i * basis_elt;
96 }
97 Ok(result)
98 }
99
100 fn binary_tower_level(&self) -> usize {
101 self.iota + self.k
102 }
103}
104
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField128b, BinaryField32b, PackedBinaryField4x32b};
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    /// Checks that `TowerBasis::evaluate` agrees with evaluating the unpacked multilinear
    /// extension over `BinaryField128b` at the same seeded random point.
    fn test_consistency(iota: usize, k: usize) {
        type F = BinaryField128b;
        let backend = make_portable_backend();
        let mut rng = StdRng::seed_from_u64(0);

        let tower_basis = TowerBasis::<F>::new(k, iota).unwrap();
        let point: Vec<F> = repeat_with(|| <F as Field>::random(&mut rng))
            .take(k)
            .collect();

        let direct = tower_basis.evaluate(&point).unwrap();
        let query = backend.multilinear_query::<F>(&point).unwrap();
        let mle = tower_basis.multilinear_extension::<F>().unwrap();
        let via_mle = mle.evaluate(&query).unwrap();

        assert_eq!(direct, via_mle);
    }

    /// Same consistency check, but with the extension packed into `PackedBinaryField4x32b`.
    #[test]
    fn test_consistency_packing() {
        type F = BinaryField32b;
        type P = PackedBinaryField4x32b;
        let (iota, kappa) = (2, 3);
        let backend = make_portable_backend();
        let mut rng = StdRng::seed_from_u64(0);

        let tower_basis = TowerBasis::<F>::new(kappa, iota).unwrap();
        let point: Vec<F> = repeat_with(|| <F as Field>::random(&mut rng))
            .take(kappa)
            .collect();

        let direct = tower_basis.evaluate(&point).unwrap();
        let query = backend.multilinear_query::<F>(&point).unwrap();
        let packed_mle = tower_basis.multilinear_extension::<P>().unwrap();
        let via_mle = packed_mle.evaluate(&query).unwrap();

        assert_eq!(direct, via_mle);
    }

    /// Sweeps every `(iota, k)` pair with `iota + k` up to the top tower level of
    /// `BinaryField128b`.
    #[test]
    fn test_consistency_all() {
        let top = <BinaryField128b as TowerField>::TOWER_LEVEL;
        (0..=top)
            .flat_map(|iota| (0..=top - iota).map(move |k| (iota, k)))
            .for_each(|(iota, k)| test_consistency(iota, k));
    }
}