Skip to main content

binius_math/
tensor_algebra.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::Sum,
5	marker::PhantomData,
6	mem,
7	ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{ExtensionField, Field, field::FieldOps};
11
12use crate::inner_product::inner_product_scalars;
13
/// An element of the tensor algebra $FE \otimes_F FE$, i.e. the tensor product of `FE` with
/// itself over the base field `F`.
///
/// A tensor algebra element is stored as a length-$D$ vector of `FE` field elements, where $D$ is
/// the degree of `FE` as an extension of `F`. The algebra has a "vertical subring" and a
/// "horizontal subring", which are both isomorphic to `FE` as a field.
///
/// See [DP24] Section 2 for further details.
///
/// [DP24]: <https://eprint.iacr.org/2024/504>
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorAlgebra<F, FE> {
	/// The vertical-subring components of the element.
	///
	/// `new`, `one`, `from_vertical` and `Default` always produce exactly $D$ entries. The other
	/// operations presumably rely on that length, so code that mutates this field directly should
	/// preserve it.
	pub elems: Vec<FE>,
	_marker: PhantomData<F>,
}
28
29impl<F, FE> TensorAlgebra<F, FE>
30where
31	F: Field,
32	FE: FieldOps<Scalar: ExtensionField<F>>,
33{
34	/// Constructs an element from a vector of vertical subring elements.
35	///
36	/// ## Preconditions
37	///
38	/// * `elems` must have length equal to the extension degree, otherwise this will pad or
39	///   truncate.
40	pub fn new(mut elems: Vec<FE>) -> Self {
41		elems.resize(<FE::Scalar as ExtensionField<F>>::DEGREE, FE::zero());
42		Self {
43			elems,
44			_marker: PhantomData,
45		}
46	}
47
48	/// Returns $\kappa$, the base-2 logarithm of the extension degree.
49	pub const fn kappa() -> usize {
50		<FE::Scalar as ExtensionField<F>>::LOG_DEGREE
51	}
52
53	/// Returns the multiplicative identity element, one.
54	pub fn one() -> Self {
55		let mut elems = vec![FE::zero(); <FE::Scalar as ExtensionField<F>>::DEGREE];
56		elems[0] = FE::one();
57		Self {
58			elems,
59			_marker: PhantomData,
60		}
61	}
62
63	/// Returns a slice of the vertical subfield elements composing the tensor algebra element.
64	pub fn vertical_elems(&self) -> &[FE] {
65		&self.elems
66	}
67
68	/// Constructs a [`TensorAlgebra`] in the vertical subring.
69	pub fn from_vertical(x: FE) -> Self {
70		let mut elems = vec![FE::zero(); <FE::Scalar as ExtensionField<F>>::DEGREE];
71		elems[0] = x;
72		Self {
73			elems,
74			_marker: PhantomData,
75		}
76	}
77
78	/// Multiply by an element from the vertical subring.
79	pub fn scale_vertical(mut self, scalar: FE) -> Self {
80		for elem_i in &mut self.elems {
81			*elem_i *= scalar.clone();
82		}
83		self
84	}
85
86	/// Multiply by an element from the horizontal subring.
87	///
88	/// Internally, this performs a transpose, vertical scaling, then transpose sequence. If
89	/// multiple horizontal scaling operations are required and performance is a concern, it may be
90	/// better for the caller to do the transposes directly and amortize their cost.
91	pub fn scale_horizontal(self, scalar: FE) -> Self {
92		self.transpose().scale_vertical(scalar).transpose()
93	}
94
95	/// Transposes the algebra element.
96	///
97	/// A transpose flips the vertical and horizontal subring elements.
98	pub fn transpose(mut self) -> Self {
99		FE::square_transpose::<F>(&mut self.elems);
100		self
101	}
102
103	/// Fold the tensor algebra element into a field element by scaling the rows and accumulating.
104	///
105	/// ## Preconditions
106	///
107	/// * `coeffs` must have length $2^\kappa$
108	pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
109		inner_product_scalars(self.transpose().elems, coeffs.iter().cloned())
110	}
111}
112
113impl<F, FE> Default for TensorAlgebra<F, FE>
114where
115	F: Field,
116	FE: FieldOps<Scalar: ExtensionField<F>>,
117{
118	fn default() -> Self {
119		Self {
120			elems: vec![FE::zero(); <FE::Scalar as ExtensionField<F>>::DEGREE],
121			_marker: PhantomData,
122		}
123	}
124}
125
126impl<F, FE> TensorAlgebra<F, FE>
127where
128	F: Field,
129	FE: ExtensionField<F>,
130{
131	/// Returns the byte size of an element.
132	pub const fn byte_size() -> usize {
133		mem::size_of::<FE>() << <FE as ExtensionField<F>>::LOG_DEGREE
134	}
135
136	/// Tensor product of a vertical subring element and a horizontal subring element.
137	pub fn tensor(vertical: FE, horizontal: FE) -> Self {
138		let elems = horizontal
139			.iter_bases()
140			.map(|base| vertical * base)
141			.collect();
142		Self {
143			elems,
144			_marker: PhantomData,
145		}
146	}
147
148	/// If the algebra element lives in the vertical subring, this returns it as a field element.
149	pub fn try_extract_vertical(&self) -> Option<FE> {
150		self.elems
151			.iter()
152			.skip(1)
153			.all(|&elem| elem == FE::ZERO)
154			.then_some(self.elems[0])
155	}
156}
157
158impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
159where
160	F: Field,
161	FE: FieldOps<Scalar: ExtensionField<F>>,
162{
163	type Output = Self;
164
165	fn add(mut self, rhs: &Self) -> Self {
166		self.add_assign(rhs);
167		self
168	}
169}
170
171impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
172where
173	F: Field,
174	FE: FieldOps<Scalar: ExtensionField<F>>,
175{
176	type Output = Self;
177
178	fn sub(mut self, rhs: &Self) -> Self {
179		self.sub_assign(rhs);
180		self
181	}
182}
183
184impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
185where
186	F: Field,
187	FE: FieldOps<Scalar: ExtensionField<F>>,
188{
189	fn add_assign(&mut self, rhs: &Self) {
190		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
191			*self_i += rhs_i;
192		}
193	}
194}
195
196impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
197where
198	F: Field,
199	FE: FieldOps<Scalar: ExtensionField<F>>,
200{
201	fn sub_assign(&mut self, rhs: &Self) {
202		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
203			*self_i -= rhs_i;
204		}
205	}
206}
207
208impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
209where
210	F: Field,
211	FE: FieldOps<Scalar: ExtensionField<F>>,
212{
213	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
214		iter.fold(Self::default(), |sum, item| sum + item)
215	}
216}
217
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b as B1, Random};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::B128;

	#[test]
	fn test_tensor_product() {
		type F = B1;
		type FE = B128;

		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let hztl = FE::random(&mut rng);

		let expected = TensorAlgebra::<F, _>::from_vertical(vert).scale_horizontal(hztl);
		assert_eq!(TensorAlgebra::tensor(vert, hztl), expected);
	}

	#[test]
	fn test_try_extract_vertical() {
		type F = B1;
		type FE = B128;

		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let elem = TensorAlgebra::<F, _>::from_vertical(vert);
		assert_eq!(elem.try_extract_vertical(), Some(vert));

		// Scale horizontally by an extension element, and we should no longer be vertical.
		let hztl = FE::new(1111);
		let elem = elem.scale_horizontal(hztl);
		assert_eq!(elem.try_extract_vertical(), None);

		// Scale back by the inverse to get back to the vertical subring.
		let hztl_inv = hztl.invert().unwrap();
		let elem = elem.scale_horizontal(hztl_inv);
		assert_eq!(elem.try_extract_vertical(), Some(vert));

		// If we scale horizontally by an F element, we should remain in the vertical subring.
		let hztl_subfield = FE::from(F::ONE);
		let elem = elem.scale_horizontal(hztl_subfield);
		assert_eq!(elem.try_extract_vertical(), Some(vert * hztl_subfield));
	}

	#[test]
	fn test_transpose_is_involution() {
		type F = B1;
		type FE = B128;

		let mut rng = StdRng::seed_from_u64(0);

		let elem = TensorAlgebra::<F, FE>::tensor(FE::random(&mut rng), FE::random(&mut rng));
		assert_eq!(elem.clone().transpose().transpose(), elem);
	}

	#[test]
	fn test_transpose_swaps_tensor_factors() {
		type F = B1;
		type FE = B128;

		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let hztl = FE::random(&mut rng);

		// Transposing exchanges the roles of the vertical and horizontal factors.
		assert_eq!(
			TensorAlgebra::<F, FE>::tensor(vert, hztl).transpose(),
			TensorAlgebra::<F, FE>::tensor(hztl, vert)
		);
	}

	#[test]
	fn test_one_is_vertical_identity() {
		type F = B1;
		type FE = B128;

		let one = TensorAlgebra::<F, FE>::one();
		assert_eq!(one.try_extract_vertical(), Some(FE::ONE));
	}

	#[test]
	fn test_add_sub_and_sum() {
		type F = B1;
		type FE = B128;

		let mut rng = StdRng::seed_from_u64(0);

		let a = TensorAlgebra::<F, FE>::tensor(FE::random(&mut rng), FE::random(&mut rng));
		let b = TensorAlgebra::<F, FE>::tensor(FE::random(&mut rng), FE::random(&mut rng));

		// Subtraction undoes addition.
		let a_plus_b = a.clone() + &b;
		assert_eq!(a_plus_b.clone() - &b, a);

		// `Sum` agrees with repeated `Add`.
		let total: TensorAlgebra<F, FE> = [a, b].iter().sum();
		assert_eq!(total, a_plus_b);

		// The empty sum is the additive identity.
		let empty: TensorAlgebra<F, FE> = Vec::<TensorAlgebra<F, FE>>::new().iter().sum();
		assert_eq!(empty, TensorAlgebra::default());
	}
}
266}