binius_core/ring_switch/
tower_tensor_algebra.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use super::error::Error;
4use crate::{
5	tensor_algebra::TensorAlgebra,
6	tower::{PackedTop, TowerFamily},
7};
8
9type FExt<Tower> = <Tower as TowerFamily>::B128;
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub enum TowerTensorAlgebra<Tower: TowerFamily> {
13	B1(TensorAlgebra<Tower::B1, Tower::B128>),
14	B8(TensorAlgebra<Tower::B8, Tower::B128>),
15	B16(TensorAlgebra<Tower::B16, Tower::B128>),
16	B32(TensorAlgebra<Tower::B32, Tower::B128>),
17	B64(TensorAlgebra<Tower::B64, Tower::B128>),
18	B128(TensorAlgebra<Tower::B128, Tower::B128>),
19}
20
21impl<Tower: TowerFamily> TowerTensorAlgebra<Tower> {
22	/// Constructs an element from a vector of vertical subring elements.
23	///
24	/// ## Preconditions
25	///
26	/// * `elems` must have length `FE::DEGREE`, otherwise this will pad or truncate.
27	pub fn new(kappa: usize, elems: Vec<FExt<Tower>>) -> Result<Self, Error> {
28		match kappa {
29			7 => Ok(Self::B1(TensorAlgebra::new(elems))),
30			4 => Ok(Self::B8(TensorAlgebra::new(elems))),
31			3 => Ok(Self::B16(TensorAlgebra::new(elems))),
32			2 => Ok(Self::B32(TensorAlgebra::new(elems))),
33			1 => Ok(Self::B64(TensorAlgebra::new(elems))),
34			0 => Ok(Self::B128(TensorAlgebra::new(elems))),
35			_ => Err(Error::PackingDegreeNotSupported { kappa }),
36		}
37	}
38
39	/// Returns the additive identity element, zero.
40	pub fn zero(kappa: usize) -> Result<Self, Error> {
41		match kappa {
42			7 => Ok(Self::B1(TensorAlgebra::default())),
43			4 => Ok(Self::B8(TensorAlgebra::default())),
44			3 => Ok(Self::B16(TensorAlgebra::default())),
45			2 => Ok(Self::B32(TensorAlgebra::default())),
46			1 => Ok(Self::B64(TensorAlgebra::default())),
47			0 => Ok(Self::B128(TensorAlgebra::default())),
48			_ => Err(Error::PackingDegreeNotSupported { kappa }),
49		}
50	}
51
52	/// Returns $\kappa$, the base-2 logarithm of the extension degree.
53	pub const fn kappa(&self) -> usize {
54		match self {
55			Self::B1(_) => 7,
56			Self::B8(_) => 4,
57			Self::B16(_) => 3,
58			Self::B32(_) => 2,
59			Self::B64(_) => 1,
60			Self::B128(_) => 0,
61		}
62	}
63
64	/// Returns a slice of the vertical subfield elements composing the tensor algebra element.
65	pub fn vertical_elems(&self) -> &[FExt<Tower>] {
66		match self {
67			Self::B1(elem) => elem.vertical_elems(),
68			Self::B8(elem) => elem.vertical_elems(),
69			Self::B16(elem) => elem.vertical_elems(),
70			Self::B32(elem) => elem.vertical_elems(),
71			Self::B64(elem) => elem.vertical_elems(),
72			Self::B128(elem) => elem.vertical_elems(),
73		}
74	}
75
76	/// Multiply by an element from the vertical subring.
77	pub fn scale_vertical(self, scalar: FExt<Tower>) -> Self {
78		match self {
79			Self::B1(elem) => Self::B1(elem.scale_vertical(scalar)),
80			Self::B8(elem) => Self::B8(elem.scale_vertical(scalar)),
81			Self::B16(elem) => Self::B16(elem.scale_vertical(scalar)),
82			Self::B32(elem) => Self::B32(elem.scale_vertical(scalar)),
83			Self::B64(elem) => Self::B64(elem.scale_vertical(scalar)),
84			Self::B128(elem) => Self::B128(elem.scale_vertical(scalar)),
85		}
86	}
87
88	/// Adds the right hand size into the current value.
89	///
90	/// ## Throws
91	///
92	/// * [`Error::TowerLevelMismatch`] if the arguments' underlying tower level do not match
93	pub fn add_assign(&mut self, rhs: &Self) -> Result<(), Error> {
94		match (self, rhs) {
95			(Self::B1(lhs), Self::B1(rhs)) => *lhs += rhs,
96			(Self::B8(lhs), Self::B8(rhs)) => *lhs += rhs,
97			(Self::B16(lhs), Self::B16(rhs)) => *lhs += rhs,
98			(Self::B32(lhs), Self::B32(rhs)) => *lhs += rhs,
99			(Self::B64(lhs), Self::B64(rhs)) => *lhs += rhs,
100			(Self::B128(lhs), Self::B128(rhs)) => *lhs += rhs,
101			_ => return Err(Error::TowerLevelMismatch),
102		}
103		Ok(())
104	}
105}
106
107impl<Tower> TowerTensorAlgebra<Tower>
108where
109	Tower: TowerFamily,
110	FExt<Tower>: PackedTop<Tower>,
111{
112	/// Fold the tensor algebra element into a field element by scaling the rows and accumulating.
113	///
114	/// ## Preconditions
115	///
116	/// * `coeffs` must have length $2^\kappa$
117	pub fn fold_vertical(self, coeffs: &[FExt<Tower>]) -> FExt<Tower> {
118		match self {
119			Self::B1(elem) => elem.fold_vertical(coeffs),
120			Self::B8(elem) => elem.fold_vertical(coeffs),
121			Self::B16(elem) => elem.fold_vertical(coeffs),
122			Self::B32(elem) => elem.fold_vertical(coeffs),
123			Self::B64(elem) => elem.fold_vertical(coeffs),
124			Self::B128(elem) => elem.fold_vertical(coeffs),
125		}
126	}
127}