binius_core/
tensor_algebra.rs1use std::{
4 iter::Sum,
5 marker::PhantomData,
6 mem,
7 ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{
11 square_transpose, util::inner_product_unchecked, ExtensionField, Field, PackedExtension,
12};
13
14#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct TensorAlgebra<F, FE>
25where
26 F: Field,
27 FE: ExtensionField<F>,
28{
29 pub elems: Vec<FE>,
30 _marker: PhantomData<F>,
31}
32
33impl<F, FE> Default for TensorAlgebra<F, FE>
34where
35 F: Field,
36 FE: ExtensionField<F>,
37{
38 fn default() -> Self {
39 Self {
40 elems: vec![FE::default(); FE::DEGREE],
41 _marker: PhantomData,
42 }
43 }
44}
45
46impl<F, FE> TensorAlgebra<F, FE>
47where
48 F: Field,
49 FE: ExtensionField<F>,
50{
51 pub fn new(mut elems: Vec<FE>) -> Self {
57 elems.resize(FE::DEGREE, FE::ZERO);
58 Self {
59 elems,
60 _marker: PhantomData,
61 }
62 }
63
64 pub const fn kappa() -> usize {
66 FE::LOG_DEGREE
67 }
68
69 pub const fn byte_size() -> usize {
71 mem::size_of::<FE>() << Self::kappa()
72 }
73
74 pub fn one() -> Self {
76 let mut one = Self::default();
77 one.elems[0] = FE::ONE;
78 one
79 }
80
81 pub fn vertical_elems(&self) -> &[FE] {
83 &self.elems
84 }
85
86 pub fn tensor(vertical: FE, horizontal: FE) -> Self {
88 let elems = horizontal
89 .iter_bases()
90 .map(|base| vertical * base)
91 .collect();
92 Self {
93 elems,
94 _marker: PhantomData,
95 }
96 }
97
98 pub fn from_vertical(x: FE) -> Self {
100 let mut elems = vec![FE::ZERO; FE::DEGREE];
101 elems[0] = x;
102 Self {
103 elems,
104 _marker: PhantomData,
105 }
106 }
107
108 pub fn try_extract_vertical(&self) -> Option<FE> {
110 self.elems
111 .iter()
112 .skip(1)
113 .all(|&elem| elem == FE::ZERO)
114 .then_some(self.elems[0])
115 }
116
117 pub fn scale_vertical(mut self, scalar: FE) -> Self {
119 for elem_i in &mut self.elems {
120 *elem_i *= scalar;
121 }
122 self
123 }
124}
125
126impl<F: Field, FE: ExtensionField<F> + PackedExtension<F>> TensorAlgebra<F, FE> {
127 pub fn scale_horizontal(self, scalar: FE) -> Self {
133 self.transpose().scale_vertical(scalar).transpose()
134 }
135
136 pub fn transpose(mut self) -> Self {
140 square_transpose(Self::kappa(), FE::cast_bases_mut(&mut self.elems))
141 .expect("transpose dimensions are square by struct invariant");
142 self
143 }
144
145 pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
151 inner_product_unchecked::<FE, _>(self.transpose().elems, coeffs.iter().copied())
152 }
153}
154
155impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
156where
157 F: Field,
158 FE: ExtensionField<F>,
159{
160 type Output = Self;
161
162 fn add(mut self, rhs: &Self) -> Self {
163 self.add_assign(rhs);
164 self
165 }
166}
167
168impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
169where
170 F: Field,
171 FE: ExtensionField<F>,
172{
173 type Output = Self;
174
175 fn sub(mut self, rhs: &Self) -> Self {
176 self.sub_assign(rhs);
177 self
178 }
179}
180
181impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
182where
183 F: Field,
184 FE: ExtensionField<F>,
185{
186 fn add_assign(&mut self, rhs: &Self) {
187 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
188 *self_i += *rhs_i;
189 }
190 }
191}
192
193impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
194where
195 F: Field,
196 FE: ExtensionField<F>,
197{
198 fn sub_assign(&mut self, rhs: &Self) {
199 for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
200 *self_i -= *rhs_i;
201 }
202 }
203}
204
205impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
206where
207 F: Field,
208 FE: ExtensionField<F>,
209{
210 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
211 iter.fold(Self::default(), |sum, item| sum + item)
212 }
213}
214
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128b, BinaryField8b};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    type F = BinaryField8b;
    type FE = BinaryField128b;
    type Algebra = TensorAlgebra<F, FE>;

    /// `tensor(v, h)` must agree with embedding `v` vertically and then scaling horizontally
    /// by `h`.
    #[test]
    fn test_tensor_product() {
        let mut rng = StdRng::seed_from_u64(0);
        let vertical = FE::random(&mut rng);
        let horizontal = FE::random(&mut rng);

        let via_scaling = Algebra::from_vertical(vertical).scale_horizontal(horizontal);
        assert_eq!(Algebra::tensor(vertical, horizontal), via_scaling);
    }

    /// Checks when a vertically embedded element can be extracted again.
    ///
    /// It stops being extractable once scaled horizontally by a non-subfield element. It
    /// becomes extractable again after scaling by that element's inverse. Scaling by a subfield
    /// element keeps it extractable, with the scalar moved onto the vertical side.
    #[test]
    fn test_try_extract_vertical() {
        let mut rng = StdRng::seed_from_u64(0);
        let vertical = FE::random(&mut rng);

        let elem = Algebra::from_vertical(vertical);
        assert_eq!(elem.try_extract_vertical(), Some(vertical));

        // A horizontal scalar that is not in `F` moves the value off the vertical axis.
        let outside = FE::new(1111);
        let elem = elem.scale_horizontal(outside);
        assert_eq!(elem.try_extract_vertical(), None);

        // Scaling by the inverse undoes the previous step.
        let outside_inv = outside.invert().unwrap();
        let elem = elem.scale_horizontal(outside_inv);
        assert_eq!(elem.try_extract_vertical(), Some(vertical));

        // A horizontal scalar from the subfield acts like a vertical scalar.
        let inside = FE::from(F::new(7));
        let elem = elem.scale_horizontal(inside);
        assert_eq!(elem.try_extract_vertical(), Some(vertical * inside));
    }
}