binius_math/
tensor_algebra.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::Sum,
5	marker::PhantomData,
6	mem,
7	ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{ExtensionField, Field, util::inner_product_unchecked};
11
12/// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields.
13///
14/// A tensor algebra element is a length $D$ vector of `FE` field elements, where $D$ is the degree
15/// of `FE` as an extension of `F`. The algebra has a "vertical subring" and a "horizontal subring",
16/// which are both isomorphic to `FE` as a field.
17///
18/// See [DP24] Section 2 for further details.
19///
20/// [DP24]: <https://eprint.iacr.org/2024/504>
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct TensorAlgebra<F, FE>
23where
24	F: Field,
25	FE: ExtensionField<F>,
26{
27	pub elems: Vec<FE>,
28	_marker: PhantomData<F>,
29}
30
31impl<F, FE> Default for TensorAlgebra<F, FE>
32where
33	F: Field,
34	FE: ExtensionField<F>,
35{
36	fn default() -> Self {
37		Self {
38			elems: vec![FE::default(); FE::DEGREE],
39			_marker: PhantomData,
40		}
41	}
42}
43
44impl<F, FE> TensorAlgebra<F, FE>
45where
46	F: Field,
47	FE: ExtensionField<F>,
48{
49	/// Constructs an element from a vector of vertical subring elements.
50	///
51	/// ## Preconditions
52	///
53	/// * `elems` must have length `FE::DEGREE`, otherwise this will pad or truncate.
54	pub fn new(mut elems: Vec<FE>) -> Self {
55		elems.resize(FE::DEGREE, FE::ZERO);
56		Self {
57			elems,
58			_marker: PhantomData,
59		}
60	}
61
62	/// Returns $\kappa$, the base-2 logarithm of the extension degree.
63	pub const fn kappa() -> usize {
64		FE::LOG_DEGREE
65	}
66
67	/// Returns the byte size of an element.
68	pub const fn byte_size() -> usize {
69		mem::size_of::<FE>() << Self::kappa()
70	}
71
72	/// Returns the multiplicative identity element, one.
73	pub fn one() -> Self {
74		let mut one = Self::default();
75		one.elems[0] = FE::ONE;
76		one
77	}
78
79	/// Returns a slice of the vertical subfield elements composing the tensor algebra element.
80	pub fn vertical_elems(&self) -> &[FE] {
81		&self.elems
82	}
83
84	/// Tensor product of a vertical subring element and a horizontal subring element.
85	pub fn tensor(vertical: FE, horizontal: FE) -> Self {
86		let elems = horizontal
87			.iter_bases()
88			.map(|base| vertical * base)
89			.collect();
90		Self {
91			elems,
92			_marker: PhantomData,
93		}
94	}
95
96	/// Constructs a [`TensorAlgebra`] in the vertical subring.
97	pub fn from_vertical(x: FE) -> Self {
98		let mut elems = vec![FE::ZERO; FE::DEGREE];
99		elems[0] = x;
100		Self {
101			elems,
102			_marker: PhantomData,
103		}
104	}
105
106	/// If the algebra element lives in the vertical subring, this returns it as a field element.
107	pub fn try_extract_vertical(&self) -> Option<FE> {
108		self.elems
109			.iter()
110			.skip(1)
111			.all(|&elem| elem == FE::ZERO)
112			.then_some(self.elems[0])
113	}
114
115	/// Multiply by an element from the vertical subring.
116	pub fn scale_vertical(mut self, scalar: FE) -> Self {
117		for elem_i in &mut self.elems {
118			*elem_i *= scalar;
119		}
120		self
121	}
122}
123
124impl<F: Field, FE: ExtensionField<F>> TensorAlgebra<F, FE> {
125	/// Multiply by an element from the vertical subring.
126	///
127	/// Internally, this performs a transpose, vertical scaling, then transpose sequence. If
128	/// multiple horizontal scaling operations are required and performance is a concern, it may be
129	/// better for the caller to do the transposes directly and amortize their cost.
130	pub fn scale_horizontal(self, scalar: FE) -> Self {
131		self.transpose().scale_vertical(scalar).transpose()
132	}
133
134	/// Transposes the algebra element.
135	///
136	/// A transpose flips the vertical and horizontal subring elements.
137	pub fn transpose(mut self) -> Self {
138		FE::square_transpose(&mut self.elems)
139			.expect("transpose dimensions are square by struct invariant");
140		self
141	}
142
143	/// Fold the tensor algebra element into a field element by scaling the rows and accumulating.
144	///
145	/// ## Preconditions
146	///
147	/// * `coeffs` must have length $2^\kappa$
148	pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
149		inner_product_unchecked::<FE, _>(self.transpose().elems, coeffs.iter().copied())
150	}
151}
152
153impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
154where
155	F: Field,
156	FE: ExtensionField<F>,
157{
158	type Output = Self;
159
160	fn add(mut self, rhs: &Self) -> Self {
161		self.add_assign(rhs);
162		self
163	}
164}
165
166impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
167where
168	F: Field,
169	FE: ExtensionField<F>,
170{
171	type Output = Self;
172
173	fn sub(mut self, rhs: &Self) -> Self {
174		self.sub_assign(rhs);
175		self
176	}
177}
178
179impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
180where
181	F: Field,
182	FE: ExtensionField<F>,
183{
184	fn add_assign(&mut self, rhs: &Self) {
185		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
186			*self_i += *rhs_i;
187		}
188	}
189}
190
191impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
192where
193	F: Field,
194	FE: ExtensionField<F>,
195{
196	fn sub_assign(&mut self, rhs: &Self) {
197		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
198			*self_i -= *rhs_i;
199		}
200	}
201}
202
203impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
204where
205	F: Field,
206	FE: ExtensionField<F>,
207{
208	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
209		iter.fold(Self::default(), |sum, item| sum + item)
210	}
211}
212
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b as B1, Random};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::B128;

	// Base field and extension field shared by every test in this module.
	type F = B1;
	type FE = B128;

	/// `tensor` must agree with embedding into the vertical subring and then scaling
	/// horizontally.
	#[test]
	fn test_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let vertical = FE::random(&mut rng);
		let horizontal = FE::random(&mut rng);

		let via_tensor = TensorAlgebra::<F, FE>::tensor(vertical, horizontal);
		let via_scaling = TensorAlgebra::<F, FE>::from_vertical(vertical).scale_horizontal(horizontal);
		assert_eq!(via_tensor, via_scaling);
	}

	/// `try_extract_vertical` succeeds exactly when the element is in the vertical subring.
	#[test]
	fn test_try_extract_vertical() {
		let mut rng = StdRng::seed_from_u64(0);

		let vertical = FE::random(&mut rng);
		let embedded = TensorAlgebra::<F, FE>::from_vertical(vertical);
		assert_eq!(embedded.try_extract_vertical(), Some(vertical));

		// Horizontal scaling by an extension element leaves the vertical subring.
		let factor = FE::new(1111);
		let scaled = embedded.scale_horizontal(factor);
		assert_eq!(scaled.try_extract_vertical(), None);

		// Horizontal scaling by the inverse returns to the vertical subring.
		let factor_inv = factor.invert().unwrap();
		let restored = scaled.scale_horizontal(factor_inv);
		assert_eq!(restored.try_extract_vertical(), Some(vertical));

		// Horizontal scaling by an element of F stays inside the vertical subring.
		let subfield_factor = FE::from(F::ONE);
		let rescaled = restored.scale_horizontal(subfield_factor);
		assert_eq!(rescaled.try_extract_vertical(), Some(vertical * subfield_factor));
	}
}
261}