binius_math/
tensor_algebra.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::Sum,
5	marker::PhantomData,
6	mem,
7	ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{ExtensionField, Field};
11
12use crate::inner_product::inner_product;
13
14/// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields.
15///
16/// A tensor algebra element is a length $D$ vector of `FE` field elements, where $D$ is the degree
17/// of `FE` as an extension of `F`. The algebra has a "vertical subring" and a "horizontal subring",
18/// which are both isomorphic to `FE` as a field.
19///
20/// See [DP24] Section 2 for further details.
21///
22/// [DP24]: <https://eprint.iacr.org/2024/504>
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct TensorAlgebra<F, FE>
25where
26	F: Field,
27	FE: ExtensionField<F>,
28{
29	pub elems: Vec<FE>,
30	_marker: PhantomData<F>,
31}
32
33impl<F, FE> Default for TensorAlgebra<F, FE>
34where
35	F: Field,
36	FE: ExtensionField<F>,
37{
38	fn default() -> Self {
39		Self {
40			elems: vec![FE::default(); FE::DEGREE],
41			_marker: PhantomData,
42		}
43	}
44}
45
46impl<F, FE> TensorAlgebra<F, FE>
47where
48	F: Field,
49	FE: ExtensionField<F>,
50{
51	/// Constructs an element from a vector of vertical subring elements.
52	///
53	/// ## Preconditions
54	///
55	/// * `elems` must have length `FE::DEGREE`, otherwise this will pad or truncate.
56	pub fn new(mut elems: Vec<FE>) -> Self {
57		elems.resize(FE::DEGREE, FE::ZERO);
58		Self {
59			elems,
60			_marker: PhantomData,
61		}
62	}
63
64	/// Returns $\kappa$, the base-2 logarithm of the extension degree.
65	pub const fn kappa() -> usize {
66		FE::LOG_DEGREE
67	}
68
69	/// Returns the byte size of an element.
70	pub const fn byte_size() -> usize {
71		mem::size_of::<FE>() << Self::kappa()
72	}
73
74	/// Returns the multiplicative identity element, one.
75	pub fn one() -> Self {
76		let mut one = Self::default();
77		one.elems[0] = FE::ONE;
78		one
79	}
80
81	/// Returns a slice of the vertical subfield elements composing the tensor algebra element.
82	pub fn vertical_elems(&self) -> &[FE] {
83		&self.elems
84	}
85
86	/// Tensor product of a vertical subring element and a horizontal subring element.
87	pub fn tensor(vertical: FE, horizontal: FE) -> Self {
88		let elems = horizontal
89			.iter_bases()
90			.map(|base| vertical * base)
91			.collect();
92		Self {
93			elems,
94			_marker: PhantomData,
95		}
96	}
97
98	/// Constructs a [`TensorAlgebra`] in the vertical subring.
99	pub fn from_vertical(x: FE) -> Self {
100		let mut elems = vec![FE::ZERO; FE::DEGREE];
101		elems[0] = x;
102		Self {
103			elems,
104			_marker: PhantomData,
105		}
106	}
107
108	/// If the algebra element lives in the vertical subring, this returns it as a field element.
109	pub fn try_extract_vertical(&self) -> Option<FE> {
110		self.elems
111			.iter()
112			.skip(1)
113			.all(|&elem| elem == FE::ZERO)
114			.then_some(self.elems[0])
115	}
116
117	/// Multiply by an element from the vertical subring.
118	pub fn scale_vertical(mut self, scalar: FE) -> Self {
119		for elem_i in &mut self.elems {
120			*elem_i *= scalar;
121		}
122		self
123	}
124}
125
126impl<F: Field, FE: ExtensionField<F>> TensorAlgebra<F, FE> {
127	/// Multiply by an element from the vertical subring.
128	///
129	/// Internally, this performs a transpose, vertical scaling, then transpose sequence. If
130	/// multiple horizontal scaling operations are required and performance is a concern, it may be
131	/// better for the caller to do the transposes directly and amortize their cost.
132	pub fn scale_horizontal(self, scalar: FE) -> Self {
133		self.transpose().scale_vertical(scalar).transpose()
134	}
135
136	/// Transposes the algebra element.
137	///
138	/// A transpose flips the vertical and horizontal subring elements.
139	pub fn transpose(mut self) -> Self {
140		FE::square_transpose(&mut self.elems);
141		self
142	}
143
144	/// Fold the tensor algebra element into a field element by scaling the rows and accumulating.
145	///
146	/// ## Preconditions
147	///
148	/// * `coeffs` must have length $2^\kappa$
149	pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
150		inner_product(self.transpose().elems, coeffs.iter().copied())
151	}
152}
153
154impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
155where
156	F: Field,
157	FE: ExtensionField<F>,
158{
159	type Output = Self;
160
161	fn add(mut self, rhs: &Self) -> Self {
162		self.add_assign(rhs);
163		self
164	}
165}
166
167impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
168where
169	F: Field,
170	FE: ExtensionField<F>,
171{
172	type Output = Self;
173
174	fn sub(mut self, rhs: &Self) -> Self {
175		self.sub_assign(rhs);
176		self
177	}
178}
179
180impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
181where
182	F: Field,
183	FE: ExtensionField<F>,
184{
185	fn add_assign(&mut self, rhs: &Self) {
186		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
187			*self_i += *rhs_i;
188		}
189	}
190}
191
192impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
193where
194	F: Field,
195	FE: ExtensionField<F>,
196{
197	fn sub_assign(&mut self, rhs: &Self) {
198		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
199			*self_i -= *rhs_i;
200		}
201	}
202}
203
204impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
205where
206	F: Field,
207	FE: ExtensionField<F>,
208{
209	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
210		iter.fold(Self::default(), |sum, item| sum + item)
211	}
212}
213
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b as B1, Random};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::B128;

	// Base field and extension field shared by every test in this module.
	type F = B1;
	type FE = B128;

	/// `tensor(v, h)` must agree with embedding `v` vertically and then scaling horizontally by
	/// `h`.
	#[test]
	fn test_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let vertical = FE::random(&mut rng);
		let horizontal = FE::random(&mut rng);

		let expected =
			TensorAlgebra::<F, FE>::from_vertical(vertical).scale_horizontal(horizontal);
		let actual = TensorAlgebra::<F, FE>::tensor(vertical, horizontal);
		assert_eq!(actual, expected);
	}

	/// Exercises `try_extract_vertical` on elements inside and outside the vertical subring.
	#[test]
	fn test_try_extract_vertical() {
		let mut rng = StdRng::seed_from_u64(0);

		// A freshly embedded vertical element is extractable as-is.
		let vertical = FE::random(&mut rng);
		let embedded = TensorAlgebra::<F, FE>::from_vertical(vertical);
		assert_eq!(embedded.try_extract_vertical(), Some(vertical));

		// Horizontal scaling by an extension element takes us out of the vertical subring.
		let factor = FE::new(1111);
		let scaled = embedded.scale_horizontal(factor);
		assert_eq!(scaled.try_extract_vertical(), None);

		// Scaling by the inverse factor brings the element back into the vertical subring.
		let factor_inv = factor.invert().unwrap();
		let restored = scaled.scale_horizontal(factor_inv);
		assert_eq!(restored.try_extract_vertical(), Some(vertical));

		// Horizontal scaling by an element of F keeps the element in the vertical subring.
		let subfield_factor = FE::from(F::ONE);
		let rescaled = restored.scale_horizontal(subfield_factor);
		assert_eq!(
			rescaled.try_extract_vertical(),
			Some(vertical * subfield_factor)
		);
	}
}
262}