binius_math/
tensor_algebra.rs1use std::{
4 iter::Sum,
5 marker::PhantomData,
6 mem,
7 ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{ExtensionField, Field, util::inner_product_unchecked};
11
12#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct TensorAlgebra<F, FE>
23where
24 F: Field,
25 FE: ExtensionField<F>,
26{
27 pub elems: Vec<FE>,
28 _marker: PhantomData<F>,
29}
30
31impl<F, FE> Default for TensorAlgebra<F, FE>
32where
33 F: Field,
34 FE: ExtensionField<F>,
35{
36 fn default() -> Self {
37 Self {
38 elems: vec![FE::default(); FE::DEGREE],
39 _marker: PhantomData,
40 }
41 }
42}
43
44impl<F, FE> TensorAlgebra<F, FE>
45where
46 F: Field,
47 FE: ExtensionField<F>,
48{
49 pub fn new(mut elems: Vec<FE>) -> Self {
55 elems.resize(FE::DEGREE, FE::ZERO);
56 Self {
57 elems,
58 _marker: PhantomData,
59 }
60 }
61
62 pub const fn kappa() -> usize {
64 FE::LOG_DEGREE
65 }
66
67 pub const fn byte_size() -> usize {
69 mem::size_of::<FE>() << Self::kappa()
70 }
71
72 pub fn one() -> Self {
74 let mut one = Self::default();
75 one.elems[0] = FE::ONE;
76 one
77 }
78
79 pub fn vertical_elems(&self) -> &[FE] {
81 &self.elems
82 }
83
84 pub fn tensor(vertical: FE, horizontal: FE) -> Self {
86 let elems = horizontal
87 .iter_bases()
88 .map(|base| vertical * base)
89 .collect();
90 Self {
91 elems,
92 _marker: PhantomData,
93 }
94 }
95
96 pub fn from_vertical(x: FE) -> Self {
98 let mut elems = vec![FE::ZERO; FE::DEGREE];
99 elems[0] = x;
100 Self {
101 elems,
102 _marker: PhantomData,
103 }
104 }
105
106 pub fn try_extract_vertical(&self) -> Option<FE> {
108 self.elems
109 .iter()
110 .skip(1)
111 .all(|&elem| elem == FE::ZERO)
112 .then_some(self.elems[0])
113 }
114
115 pub fn scale_vertical(mut self, scalar: FE) -> Self {
117 for elem_i in &mut self.elems {
118 *elem_i *= scalar;
119 }
120 self
121 }
122}
123
124impl<F: Field, FE: ExtensionField<F>> TensorAlgebra<F, FE> {
125 pub fn scale_horizontal(self, scalar: FE) -> Self {
131 self.transpose().scale_vertical(scalar).transpose()
132 }
133
134 pub fn transpose(mut self) -> Self {
138 FE::square_transpose(&mut self.elems)
139 .expect("transpose dimensions are square by struct invariant");
140 self
141 }
142
143 pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
149 inner_product_unchecked::<FE, _>(self.transpose().elems, coeffs.iter().copied())
150 }
151}
152
153impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
154where
155 F: Field,
156 FE: ExtensionField<F>,
157{
158 type Output = Self;
159
160 fn add(mut self, rhs: &Self) -> Self {
161 self.add_assign(rhs);
162 self
163 }
164}
165
166impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
167where
168 F: Field,
169 FE: ExtensionField<F>,
170{
171 type Output = Self;
172
173 fn sub(mut self, rhs: &Self) -> Self {
174 self.sub_assign(rhs);
175 self
176 }
177}
178
179impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
180where
181 F: Field,
182 FE: ExtensionField<F>,
183{
184 fn add_assign(&mut self, rhs: &Self) {
185 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
186 *self_i += *rhs_i;
187 }
188 }
189}
190
191impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
192where
193 F: Field,
194 FE: ExtensionField<F>,
195{
196 fn sub_assign(&mut self, rhs: &Self) {
197 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
198 *self_i -= *rhs_i;
199 }
200 }
201}
202
203impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
204where
205 F: Field,
206 FE: ExtensionField<F>,
207{
208 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
209 iter.fold(Self::default(), |sum, item| sum + item)
210 }
211}
212
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b as B1, Random};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::B128;

	type F = B1;
	type FE = B128;
	type Elem = TensorAlgebra<F, FE>;

	/// `from_vertical(v).scale_horizontal(h)` must equal the pure tensor `v ⊗ h`.
	#[test]
	fn test_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let hztl = FE::random(&mut rng);

		let expected = Elem::from_vertical(vert).scale_horizontal(hztl);
		assert_eq!(Elem::tensor(vert, hztl), expected);
	}

	#[test]
	fn test_try_extract_vertical() {
		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let elem = Elem::from_vertical(vert);
		assert_eq!(elem.try_extract_vertical(), Some(vert));

		// Scaling horizontally by a value outside the subfield `F` leaves the vertical subring.
		let hztl = FE::new(1111);
		let elem = elem.scale_horizontal(hztl);
		assert_eq!(elem.try_extract_vertical(), None);

		// Undoing that scaling with the inverse brings the element back to the vertical subring.
		let hztl_inv = hztl.invert().unwrap();
		let elem = elem.scale_horizontal(hztl_inv);
		assert_eq!(elem.try_extract_vertical(), Some(vert));

		// Scaling horizontally by an element of `F` keeps the element vertical.
		let hztl_subfield = FE::from(F::ONE);
		let elem = elem.scale_horizontal(hztl_subfield);
		assert_eq!(elem.try_extract_vertical(), Some(vert * hztl_subfield));
	}

	/// Transposing twice must return the original element.
	#[test]
	fn test_transpose_involution() {
		let mut rng = StdRng::seed_from_u64(1);

		let elem = Elem::tensor(FE::random(&mut rng), FE::random(&mut rng));
		assert_eq!(elem.clone().transpose().transpose(), elem);
	}

	/// `Default` is the additive identity, `+`/`-` round-trip, and `Sum` agrees with `+`.
	#[test]
	fn test_add_sub_sum() {
		let mut rng = StdRng::seed_from_u64(2);

		let a = Elem::tensor(FE::random(&mut rng), FE::random(&mut rng));
		let b = Elem::tensor(FE::random(&mut rng), FE::random(&mut rng));

		assert_eq!(Elem::default() + &a, a);
		assert_eq!(a.clone() + &b - &b, a);
		assert_eq!([a.clone(), b.clone()].iter().sum::<Elem>(), a.clone() + &b);
	}

	/// `one()` is the purely vertical element `FE::ONE`.
	#[test]
	fn test_one_is_vertical_one() {
		assert_eq!(Elem::one().try_extract_vertical(), Some(FE::ONE));
	}
}