binius_core/transparent/
tower_basis.rs1use std::marker::PhantomData;
4
5use binius_field::{BinaryField128b, Field, PackedField, TowerField};
6use binius_macros::{DeserializeBytes, SerializeBytes, erased_serialize_bytes};
7use binius_math::MultilinearExtension;
8use binius_utils::{DeserializeBytes, bail};
9
10use crate::polynomial::{Error, MultivariatePoly};
11
12#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)]
26pub struct TowerBasis<F: Field> {
27 k: usize,
28 iota: usize,
29 _marker: PhantomData<F>,
30}
31
32inventory::submit! {
33 <dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
34 "TowerBasis",
35 |buf, mode| Ok(Box::new(TowerBasis::<BinaryField128b>::deserialize(&mut *buf, mode)?))
36 )
37}
38
39impl<F: TowerField> TowerBasis<F> {
40 pub fn new(k: usize, iota: usize) -> Result<Self, Error> {
41 if iota + k > F::TOWER_LEVEL {
42 bail!(Error::ArgumentRangeError {
43 arg: "iota + k".into(),
44 range: 0..F::TOWER_LEVEL + 1,
45 });
46 }
47 Ok(Self {
48 k,
49 iota,
50 _marker: Default::default(),
51 })
52 }
53
54 pub fn multilinear_extension<P: PackedField<Scalar = F>>(
55 &self,
56 ) -> Result<MultilinearExtension<P>, Error> {
57 let n_values = (1 << self.k) / P::WIDTH;
58 let values = (0..n_values)
59 .map(|i| {
60 let mut packed_value = P::default();
61 for j in 0..P::WIDTH {
62 let basis_idx = i * P::WIDTH + j;
63 let value = TowerField::basis(self.iota, basis_idx)?;
64 packed_value.set(j, value);
65 }
66 Ok(packed_value)
67 })
68 .collect::<Result<Vec<_>, Error>>()?;
69
70 Ok(MultilinearExtension::from_values(values)?)
71 }
72}
73
74#[erased_serialize_bytes]
75impl<F> MultivariatePoly<F> for TowerBasis<F>
76where
77 F: TowerField,
78{
79 fn n_vars(&self) -> usize {
80 self.k
81 }
82
83 fn degree(&self) -> usize {
84 self.k
85 }
86
87 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
88 if query.len() != self.k {
89 bail!(Error::IncorrectQuerySize {
90 expected: self.k,
91 actual: query.len()
92 });
93 }
94
95 let mut result = F::ONE;
96 for (i, query_i) in query.iter().enumerate() {
97 let r_comp = F::ONE - query_i;
98 let basis_elt = <F as TowerField>::basis(self.iota + i, 1)?;
99 result *= r_comp + *query_i * basis_elt;
100 }
101 Ok(result)
102 }
103
104 fn binary_tower_level(&self) -> usize {
105 self.iota + self.k
106 }
107}
108
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, BinaryField128b, PackedBinaryField4x32b};
    use binius_hal::{ComputationBackendExt, make_portable_backend};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Checks that direct evaluation agrees with evaluating the materialised
    /// multilinear extension at the same random point, over `BinaryField128b`.
    fn check_consistency(iota: usize, k: usize) {
        type F = BinaryField128b;

        let backend = make_portable_backend();
        let mut rng = StdRng::seed_from_u64(0);

        let poly = TowerBasis::<F>::new(k, iota).unwrap();
        let point: Vec<F> = repeat_with(|| <F as Field>::random(&mut rng))
            .take(k)
            .collect();

        let direct = poly.evaluate(&point).unwrap();

        let query = backend.multilinear_query::<F>(&point).unwrap();
        let extension = poly.multilinear_extension::<F>().unwrap();
        let via_extension = extension.evaluate(&query).unwrap();

        assert_eq!(direct, via_extension);
    }

    /// Same consistency check, but with the extension stored in a packed
    /// field of width four over `BinaryField32b`.
    #[test]
    fn test_consistency_packing() {
        type F = BinaryField32b;
        type P = PackedBinaryField4x32b;

        const IOTA: usize = 2;
        const KAPPA: usize = 3;

        let backend = make_portable_backend();
        let mut rng = StdRng::seed_from_u64(0);

        let poly = TowerBasis::<F>::new(KAPPA, IOTA).unwrap();
        let point: Vec<F> = repeat_with(|| <F as Field>::random(&mut rng))
            .take(KAPPA)
            .collect();

        let direct = poly.evaluate(&point).unwrap();

        let query = backend.multilinear_query::<F>(&point).unwrap();
        let extension = poly.multilinear_extension::<P>().unwrap();
        let via_extension = extension.evaluate(&query).unwrap();

        assert_eq!(direct, via_extension);
    }

    /// Runs the consistency check for every `(iota, k)` with `iota + k <= 7`.
    #[test]
    fn test_consistency_all() {
        (0..=7usize)
            .flat_map(|iota| (0..=(7 - iota)).map(move |k| (iota, k)))
            .for_each(|(iota, k)| check_consistency(iota, k));
    }
}