binius_math/
tensor_algebra.rs1use std::{
4 iter::Sum,
5 marker::PhantomData,
6 mem,
7 ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{ExtensionField, Field};
11
12use crate::inner_product::inner_product;
13
14#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct TensorAlgebra<F, FE>
25where
26 F: Field,
27 FE: ExtensionField<F>,
28{
29 pub elems: Vec<FE>,
30 _marker: PhantomData<F>,
31}
32
33impl<F, FE> Default for TensorAlgebra<F, FE>
34where
35 F: Field,
36 FE: ExtensionField<F>,
37{
38 fn default() -> Self {
39 Self {
40 elems: vec![FE::default(); FE::DEGREE],
41 _marker: PhantomData,
42 }
43 }
44}
45
46impl<F, FE> TensorAlgebra<F, FE>
47where
48 F: Field,
49 FE: ExtensionField<F>,
50{
51 pub fn new(mut elems: Vec<FE>) -> Self {
57 elems.resize(FE::DEGREE, FE::ZERO);
58 Self {
59 elems,
60 _marker: PhantomData,
61 }
62 }
63
64 pub const fn kappa() -> usize {
66 FE::LOG_DEGREE
67 }
68
69 pub const fn byte_size() -> usize {
71 mem::size_of::<FE>() << Self::kappa()
72 }
73
74 pub fn one() -> Self {
76 let mut one = Self::default();
77 one.elems[0] = FE::ONE;
78 one
79 }
80
81 pub fn vertical_elems(&self) -> &[FE] {
83 &self.elems
84 }
85
86 pub fn tensor(vertical: FE, horizontal: FE) -> Self {
88 let elems = horizontal
89 .iter_bases()
90 .map(|base| vertical * base)
91 .collect();
92 Self {
93 elems,
94 _marker: PhantomData,
95 }
96 }
97
98 pub fn from_vertical(x: FE) -> Self {
100 let mut elems = vec![FE::ZERO; FE::DEGREE];
101 elems[0] = x;
102 Self {
103 elems,
104 _marker: PhantomData,
105 }
106 }
107
108 pub fn try_extract_vertical(&self) -> Option<FE> {
110 self.elems
111 .iter()
112 .skip(1)
113 .all(|&elem| elem == FE::ZERO)
114 .then_some(self.elems[0])
115 }
116
117 pub fn scale_vertical(mut self, scalar: FE) -> Self {
119 for elem_i in &mut self.elems {
120 *elem_i *= scalar;
121 }
122 self
123 }
124}
125
126impl<F: Field, FE: ExtensionField<F>> TensorAlgebra<F, FE> {
127 pub fn scale_horizontal(self, scalar: FE) -> Self {
133 self.transpose().scale_vertical(scalar).transpose()
134 }
135
136 pub fn transpose(mut self) -> Self {
140 FE::square_transpose(&mut self.elems);
141 self
142 }
143
144 pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
150 inner_product(self.transpose().elems, coeffs.iter().copied())
151 }
152}
153
154impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
155where
156 F: Field,
157 FE: ExtensionField<F>,
158{
159 type Output = Self;
160
161 fn add(mut self, rhs: &Self) -> Self {
162 self.add_assign(rhs);
163 self
164 }
165}
166
167impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
168where
169 F: Field,
170 FE: ExtensionField<F>,
171{
172 type Output = Self;
173
174 fn sub(mut self, rhs: &Self) -> Self {
175 self.sub_assign(rhs);
176 self
177 }
178}
179
180impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
181where
182 F: Field,
183 FE: ExtensionField<F>,
184{
185 fn add_assign(&mut self, rhs: &Self) {
186 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
187 *self_i += *rhs_i;
188 }
189 }
190}
191
192impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
193where
194 F: Field,
195 FE: ExtensionField<F>,
196{
197 fn sub_assign(&mut self, rhs: &Self) {
198 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
199 *self_i -= *rhs_i;
200 }
201 }
202}
203
204impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
205where
206 F: Field,
207 FE: ExtensionField<F>,
208{
209 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
210 iter.fold(Self::default(), |sum, item| sum + item)
211 }
212}
213
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField1b as B1, Random};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_utils::B128;

    /// Subfield used by every test.
    type F = B1;
    /// Extension field used by every test.
    type FE = B128;

    /// Deterministic RNG shared by the tests (seed 0).
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0)
    }

    /// `tensor(v, h)` must match lifting `v` vertically and then scaling horizontally by `h`.
    #[test]
    fn test_tensor_product() {
        let mut rng = seeded_rng();
        let vertical = FE::random(&mut rng);
        let horizontal = FE::random(&mut rng);

        let via_scaling =
            TensorAlgebra::<F, FE>::from_vertical(vertical).scale_horizontal(horizontal);
        let via_tensor = TensorAlgebra::<F, FE>::tensor(vertical, horizontal);
        assert_eq!(via_tensor, via_scaling);
    }

    /// Checks `try_extract_vertical` across horizontal scalings.
    #[test]
    fn test_try_extract_vertical() {
        let mut rng = seeded_rng();
        let vertical = FE::random(&mut rng);

        // A freshly lifted vertical element extracts back to itself.
        let element = TensorAlgebra::<F, FE>::from_vertical(vertical);
        assert_eq!(element.try_extract_vertical(), Some(vertical));

        // Scaling horizontally by an element outside the subfield leaves the vertical part.
        let horizontal = FE::new(1111);
        let element = element.scale_horizontal(horizontal);
        assert_eq!(element.try_extract_vertical(), None);

        // Scaling back by the inverse restores a purely vertical element.
        let horizontal_inv = horizontal.invert().unwrap();
        let element = element.scale_horizontal(horizontal_inv);
        assert_eq!(element.try_extract_vertical(), Some(vertical));

        // Scaling horizontally by a subfield element keeps the element vertical.
        let subfield_scalar = FE::from(F::ONE);
        let element = element.scale_horizontal(subfield_scalar);
        assert_eq!(element.try_extract_vertical(), Some(vertical * subfield_scalar));
    }
}