binius_math/
tensor_algebra.rs1use std::{
4 iter::Sum,
5 marker::PhantomData,
6 mem,
7 ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{ExtensionField, Field, field::FieldOps};
11
12use crate::inner_product::inner_product_scalars;
13
/// An element of a tensor algebra, stored as a vector of `FE` coordinates.
///
/// Every constructor in this module produces a vector whose length is the
/// degree of `FE` (or of its scalar field) as an extension of `F`. `F` itself
/// is never stored; it only selects which subfield the extension degree is
/// measured against, hence the `PhantomData` marker.
///
/// NOTE(review): `elems` is public, so callers can change its length after
/// construction; the methods in this module do not re-validate it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorAlgebra<F, FE> {
    /// Coordinates of the element, one `FE` value per slot.
    pub elems: Vec<FE>,
    /// Ties the otherwise unused subfield parameter `F` to the type.
    _marker: PhantomData<F>,
}
28
29impl<F, FE> TensorAlgebra<F, FE>
30where
31 F: Field,
32 FE: FieldOps<Scalar: ExtensionField<F>>,
33{
34 pub fn new(mut elems: Vec<FE>) -> Self {
41 elems.resize(<FE::Scalar as ExtensionField<F>>::DEGREE, FE::zero());
42 Self {
43 elems,
44 _marker: PhantomData,
45 }
46 }
47
48 pub const fn kappa() -> usize {
50 <FE::Scalar as ExtensionField<F>>::LOG_DEGREE
51 }
52
53 pub fn one() -> Self {
55 let mut elems = vec![FE::zero(); <FE::Scalar as ExtensionField<F>>::DEGREE];
56 elems[0] = FE::one();
57 Self {
58 elems,
59 _marker: PhantomData,
60 }
61 }
62
63 pub fn vertical_elems(&self) -> &[FE] {
65 &self.elems
66 }
67
68 pub fn from_vertical(x: FE) -> Self {
70 let mut elems = vec![FE::zero(); <FE::Scalar as ExtensionField<F>>::DEGREE];
71 elems[0] = x;
72 Self {
73 elems,
74 _marker: PhantomData,
75 }
76 }
77
78 pub fn scale_vertical(mut self, scalar: FE) -> Self {
80 for elem_i in &mut self.elems {
81 *elem_i *= scalar.clone();
82 }
83 self
84 }
85
86 pub fn scale_horizontal(self, scalar: FE) -> Self {
92 self.transpose().scale_vertical(scalar).transpose()
93 }
94
95 pub fn transpose(mut self) -> Self {
99 FE::square_transpose::<F>(&mut self.elems);
100 self
101 }
102
103 pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
109 inner_product_scalars(self.transpose().elems, coeffs.iter().cloned())
110 }
111}
112
113impl<F, FE> Default for TensorAlgebra<F, FE>
114where
115 F: Field,
116 FE: FieldOps<Scalar: ExtensionField<F>>,
117{
118 fn default() -> Self {
119 Self {
120 elems: vec![FE::zero(); <FE::Scalar as ExtensionField<F>>::DEGREE],
121 _marker: PhantomData,
122 }
123 }
124}
125
126impl<F, FE> TensorAlgebra<F, FE>
127where
128 F: Field,
129 FE: ExtensionField<F>,
130{
131 pub const fn byte_size() -> usize {
133 mem::size_of::<FE>() << <FE as ExtensionField<F>>::LOG_DEGREE
134 }
135
136 pub fn tensor(vertical: FE, horizontal: FE) -> Self {
138 let elems = horizontal
139 .iter_bases()
140 .map(|base| vertical * base)
141 .collect();
142 Self {
143 elems,
144 _marker: PhantomData,
145 }
146 }
147
148 pub fn try_extract_vertical(&self) -> Option<FE> {
150 self.elems
151 .iter()
152 .skip(1)
153 .all(|&elem| elem == FE::ZERO)
154 .then_some(self.elems[0])
155 }
156}
157
158impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
159where
160 F: Field,
161 FE: FieldOps<Scalar: ExtensionField<F>>,
162{
163 type Output = Self;
164
165 fn add(mut self, rhs: &Self) -> Self {
166 self.add_assign(rhs);
167 self
168 }
169}
170
171impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
172where
173 F: Field,
174 FE: FieldOps<Scalar: ExtensionField<F>>,
175{
176 type Output = Self;
177
178 fn sub(mut self, rhs: &Self) -> Self {
179 self.sub_assign(rhs);
180 self
181 }
182}
183
184impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
185where
186 F: Field,
187 FE: FieldOps<Scalar: ExtensionField<F>>,
188{
189 fn add_assign(&mut self, rhs: &Self) {
190 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
191 *self_i += rhs_i;
192 }
193 }
194}
195
196impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
197where
198 F: Field,
199 FE: FieldOps<Scalar: ExtensionField<F>>,
200{
201 fn sub_assign(&mut self, rhs: &Self) {
202 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
203 *self_i -= rhs_i;
204 }
205 }
206}
207
208impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
209where
210 F: Field,
211 FE: FieldOps<Scalar: ExtensionField<F>>,
212{
213 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
214 iter.fold(Self::default(), |sum, item| sum + item)
215 }
216}
217
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField1b as B1, Random};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_utils::B128;

    /// `tensor(v, h)` must agree with lifting `v` via `from_vertical` and
    /// then scaling horizontally by `h`.
    #[test]
    fn test_tensor_product() {
        type F = B1;
        type FE = B128;

        let mut rng = StdRng::seed_from_u64(0);

        let vertical = FE::random(&mut rng);
        let horizontal = FE::random(&mut rng);

        let via_scaling =
            TensorAlgebra::<F, _>::from_vertical(vertical).scale_horizontal(horizontal);
        let via_tensor = TensorAlgebra::<F, _>::tensor(vertical, horizontal);
        assert_eq!(via_tensor, via_scaling);
    }

    /// Walks one element through a sequence of horizontal scalings and checks
    /// after each step whether the vertical value can be extracted.
    #[test]
    fn test_try_extract_vertical() {
        type F = B1;
        type FE = B128;

        let mut rng = StdRng::seed_from_u64(0);

        // A freshly lifted value is extractable as-is.
        let vertical = FE::random(&mut rng);
        let lifted = TensorAlgebra::<F, _>::from_vertical(vertical);
        assert_eq!(lifted.try_extract_vertical(), Some(vertical));

        // After this horizontal scaling the element is expected to no longer
        // have a purely vertical form.
        let horizontal = FE::new(1111);
        let scaled = lifted.scale_horizontal(horizontal);
        assert_eq!(scaled.try_extract_vertical(), None);

        // Scaling by the inverse undoes the previous step.
        let horizontal_inv = horizontal.invert().unwrap();
        let restored = scaled.scale_horizontal(horizontal_inv);
        assert_eq!(restored.try_extract_vertical(), Some(vertical));

        // A scalar that comes from the subfield `F` keeps the element
        // extractable.
        let horizontal_subfield = FE::from(F::ONE);
        let rescaled = restored.scale_horizontal(horizontal_subfield);
        assert_eq!(
            rescaled.try_extract_vertical(),
            Some(vertical * horizontal_subfield)
        );
    }
}