binius_core/
tensor_algebra.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::Sum,
5	marker::PhantomData,
6	mem,
7	ops::{Add, AddAssign, Sub, SubAssign},
8};
9
10use binius_field::{
11	square_transpose, util::inner_product_unchecked, ExtensionField, Field, PackedExtension,
12};
13
14/// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields.
15///
16/// A tensor algebra element is a length $D$ vector of `FE` field elements, where $D$ is the degree
17/// of `FE` as an extension of `F`. The algebra has a "vertical subring" and a "horizontal subring",
18/// which are both isomorphic to `FE` as a field.
19///
20/// See [DP24] Section 2 for further details.
21///
22/// [DP24]: <https://eprint.iacr.org/2024/504>
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct TensorAlgebra<F, FE>
25where
26	F: Field,
27	FE: ExtensionField<F>,
28{
29	pub elems: Vec<FE>,
30	_marker: PhantomData<F>,
31}
32
33impl<F, FE> Default for TensorAlgebra<F, FE>
34where
35	F: Field,
36	FE: ExtensionField<F>,
37{
38	fn default() -> Self {
39		Self {
40			elems: vec![FE::default(); FE::DEGREE],
41			_marker: PhantomData,
42		}
43	}
44}
45
46impl<F, FE> TensorAlgebra<F, FE>
47where
48	F: Field,
49	FE: ExtensionField<F>,
50{
51	/// Constructs an element from a vector of vertical subring elements.
52	///
53	/// ## Preconditions
54	///
55	/// * `elems` must have length `FE::DEGREE`, otherwise this will pad or truncate.
56	pub fn new(mut elems: Vec<FE>) -> Self {
57		elems.resize(FE::DEGREE, FE::ZERO);
58		Self {
59			elems,
60			_marker: PhantomData,
61		}
62	}
63
64	/// Returns $\kappa$, the base-2 logarithm of the extension degree.
65	pub const fn kappa() -> usize {
66		FE::LOG_DEGREE
67	}
68
69	/// Returns the byte size of an element.
70	pub const fn byte_size() -> usize {
71		mem::size_of::<FE>() << Self::kappa()
72	}
73
74	/// Returns the multiplicative identity element, one.
75	pub fn one() -> Self {
76		let mut one = Self::default();
77		one.elems[0] = FE::ONE;
78		one
79	}
80
81	/// Returns a slice of the vertical subfield elements composing the tensor algebra element.
82	pub fn vertical_elems(&self) -> &[FE] {
83		&self.elems
84	}
85
86	/// Tensor product of a vertical subring element and a horizontal subring element.
87	pub fn tensor(vertical: FE, horizontal: FE) -> Self {
88		let elems = horizontal
89			.iter_bases()
90			.map(|base| vertical * base)
91			.collect();
92		Self {
93			elems,
94			_marker: PhantomData,
95		}
96	}
97
98	/// Constructs a [`TensorAlgebra`] in the vertical subring.
99	pub fn from_vertical(x: FE) -> Self {
100		let mut elems = vec![FE::ZERO; FE::DEGREE];
101		elems[0] = x;
102		Self {
103			elems,
104			_marker: PhantomData,
105		}
106	}
107
108	/// If the algebra element lives in the vertical subring, this returns it as a field element.
109	pub fn try_extract_vertical(&self) -> Option<FE> {
110		self.elems
111			.iter()
112			.skip(1)
113			.all(|&elem| elem == FE::ZERO)
114			.then_some(self.elems[0])
115	}
116
117	/// Multiply by an element from the vertical subring.
118	pub fn scale_vertical(mut self, scalar: FE) -> Self {
119		for elem_i in &mut self.elems {
120			*elem_i *= scalar;
121		}
122		self
123	}
124}
125
126impl<F: Field, FE: ExtensionField<F> + PackedExtension<F>> TensorAlgebra<F, FE> {
127	/// Multiply by an element from the vertical subring.
128	///
129	/// Internally, this performs a transpose, vertical scaling, then transpose sequence. If
130	/// multiple horizontal scaling operations are required and performance is a concern, it may be
131	/// better for the caller to do the transposes directly and amortize their cost.
132	pub fn scale_horizontal(self, scalar: FE) -> Self {
133		self.transpose().scale_vertical(scalar).transpose()
134	}
135
136	/// Transposes the algebra element.
137	///
138	/// A transpose flips the vertical and horizontal subring elements.
139	pub fn transpose(mut self) -> Self {
140		square_transpose(Self::kappa(), FE::cast_bases_mut(&mut self.elems))
141			.expect("transpose dimensions are square by struct invariant");
142		self
143	}
144
145	/// Fold the tensor algebra element into a field element by scaling the rows and accumulating.
146	///
147	/// ## Preconditions
148	///
149	/// * `coeffs` must have length $2^\kappa$
150	pub fn fold_vertical(self, coeffs: &[FE]) -> FE {
151		inner_product_unchecked::<FE, _>(self.transpose().elems, coeffs.iter().copied())
152	}
153}
154
155impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
156where
157	F: Field,
158	FE: ExtensionField<F>,
159{
160	type Output = Self;
161
162	fn add(mut self, rhs: &Self) -> Self {
163		self.add_assign(rhs);
164		self
165	}
166}
167
168impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
169where
170	F: Field,
171	FE: ExtensionField<F>,
172{
173	type Output = Self;
174
175	fn sub(mut self, rhs: &Self) -> Self {
176		self.sub_assign(rhs);
177		self
178	}
179}
180
181impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
182where
183	F: Field,
184	FE: ExtensionField<F>,
185{
186	fn add_assign(&mut self, rhs: &Self) {
187		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
188			*self_i += *rhs_i;
189		}
190	}
191}
192
193impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
194where
195	F: Field,
196	FE: ExtensionField<F>,
197{
198	fn sub_assign(&mut self, rhs: &Self) {
199		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
200			*self_i -= *rhs_i;
201		}
202	}
203}
204
205impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
206where
207	F: Field,
208	FE: ExtensionField<F>,
209{
210	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
211		iter.fold(Self::default(), |sum, item| sum + item)
212	}
213}
214
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField8b};
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	// Base field and extension field shared by every test in this module.
	type F = BinaryField8b;
	type FE = BinaryField128b;

	/// `tensor` must agree with embedding the vertical factor and scaling horizontally.
	#[test]
	fn test_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let vertical = FE::random(&mut rng);
		let horizontal = FE::random(&mut rng);

		let via_scaling =
			TensorAlgebra::<F, _>::from_vertical(vertical).scale_horizontal(horizontal);
		assert_eq!(TensorAlgebra::tensor(vertical, horizontal), via_scaling);
	}

	/// `try_extract_vertical` succeeds exactly when the element is in the vertical subring.
	#[test]
	fn test_try_extract_vertical() {
		let mut rng = StdRng::seed_from_u64(0);

		let vertical = FE::random(&mut rng);
		let embedded = TensorAlgebra::<F, _>::from_vertical(vertical);
		assert_eq!(embedded.try_extract_vertical(), Some(vertical));

		// A horizontal scaling by an extension element leaves the vertical subring.
		let factor = FE::new(1111);
		let scaled = embedded.scale_horizontal(factor);
		assert_eq!(scaled.try_extract_vertical(), None);

		// Undoing the scaling with the inverse factor lands back in the vertical subring.
		let factor_inv = factor.invert().unwrap();
		let restored = scaled.scale_horizontal(factor_inv);
		assert_eq!(restored.try_extract_vertical(), Some(vertical));

		// Horizontal scaling by an `F` element keeps the element in the vertical subring.
		let subfield_factor = FE::from(F::new(7));
		let rescaled = restored.scale_horizontal(subfield_factor);
		assert_eq!(rescaled.try_extract_vertical(), Some(vertical * subfield_factor));
	}
}