// binius_core/ring_switch/tower_tensor_algebra.rs
use super::error::Error;
4use crate::{
5 tensor_algebra::TensorAlgebra,
6 tower::{PackedTop, TowerFamily},
7};
8
9type FExt<Tower> = <Tower as TowerFamily>::B128;
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub enum TowerTensorAlgebra<Tower: TowerFamily> {
13 B1(TensorAlgebra<Tower::B1, Tower::B128>),
14 B8(TensorAlgebra<Tower::B8, Tower::B128>),
15 B16(TensorAlgebra<Tower::B16, Tower::B128>),
16 B32(TensorAlgebra<Tower::B32, Tower::B128>),
17 B64(TensorAlgebra<Tower::B64, Tower::B128>),
18 B128(TensorAlgebra<Tower::B128, Tower::B128>),
19}
20
21impl<Tower: TowerFamily> TowerTensorAlgebra<Tower> {
22 pub fn new(kappa: usize, elems: Vec<FExt<Tower>>) -> Result<Self, Error> {
28 match kappa {
29 7 => Ok(Self::B1(TensorAlgebra::new(elems))),
30 4 => Ok(Self::B8(TensorAlgebra::new(elems))),
31 3 => Ok(Self::B16(TensorAlgebra::new(elems))),
32 2 => Ok(Self::B32(TensorAlgebra::new(elems))),
33 1 => Ok(Self::B64(TensorAlgebra::new(elems))),
34 0 => Ok(Self::B128(TensorAlgebra::new(elems))),
35 _ => Err(Error::PackingDegreeNotSupported { kappa }),
36 }
37 }
38
39 pub fn zero(kappa: usize) -> Result<Self, Error> {
41 match kappa {
42 7 => Ok(Self::B1(TensorAlgebra::default())),
43 4 => Ok(Self::B8(TensorAlgebra::default())),
44 3 => Ok(Self::B16(TensorAlgebra::default())),
45 2 => Ok(Self::B32(TensorAlgebra::default())),
46 1 => Ok(Self::B64(TensorAlgebra::default())),
47 0 => Ok(Self::B128(TensorAlgebra::default())),
48 _ => Err(Error::PackingDegreeNotSupported { kappa }),
49 }
50 }
51
52 pub const fn kappa(&self) -> usize {
54 match self {
55 Self::B1(_) => 7,
56 Self::B8(_) => 4,
57 Self::B16(_) => 3,
58 Self::B32(_) => 2,
59 Self::B64(_) => 1,
60 Self::B128(_) => 0,
61 }
62 }
63
64 pub fn vertical_elems(&self) -> &[FExt<Tower>] {
66 match self {
67 Self::B1(elem) => elem.vertical_elems(),
68 Self::B8(elem) => elem.vertical_elems(),
69 Self::B16(elem) => elem.vertical_elems(),
70 Self::B32(elem) => elem.vertical_elems(),
71 Self::B64(elem) => elem.vertical_elems(),
72 Self::B128(elem) => elem.vertical_elems(),
73 }
74 }
75
76 pub fn scale_vertical(self, scalar: FExt<Tower>) -> Self {
78 match self {
79 Self::B1(elem) => Self::B1(elem.scale_vertical(scalar)),
80 Self::B8(elem) => Self::B8(elem.scale_vertical(scalar)),
81 Self::B16(elem) => Self::B16(elem.scale_vertical(scalar)),
82 Self::B32(elem) => Self::B32(elem.scale_vertical(scalar)),
83 Self::B64(elem) => Self::B64(elem.scale_vertical(scalar)),
84 Self::B128(elem) => Self::B128(elem.scale_vertical(scalar)),
85 }
86 }
87
88 pub fn add_assign(&mut self, rhs: &Self) -> Result<(), Error> {
94 match (self, rhs) {
95 (Self::B1(lhs), Self::B1(rhs)) => *lhs += rhs,
96 (Self::B8(lhs), Self::B8(rhs)) => *lhs += rhs,
97 (Self::B16(lhs), Self::B16(rhs)) => *lhs += rhs,
98 (Self::B32(lhs), Self::B32(rhs)) => *lhs += rhs,
99 (Self::B64(lhs), Self::B64(rhs)) => *lhs += rhs,
100 (Self::B128(lhs), Self::B128(rhs)) => *lhs += rhs,
101 _ => return Err(Error::TowerLevelMismatch),
102 }
103 Ok(())
104 }
105}
106
107impl<Tower> TowerTensorAlgebra<Tower>
108where
109 Tower: TowerFamily,
110 FExt<Tower>: PackedTop<Tower>,
111{
112 pub fn fold_vertical(self, coeffs: &[FExt<Tower>]) -> FExt<Tower> {
118 match self {
119 Self::B1(elem) => elem.fold_vertical(coeffs),
120 Self::B8(elem) => elem.fold_vertical(coeffs),
121 Self::B16(elem) => elem.fold_vertical(coeffs),
122 Self::B32(elem) => elem.fold_vertical(coeffs),
123 Self::B64(elem) => elem.fold_vertical(coeffs),
124 Self::B128(elem) => elem.fold_vertical(coeffs),
125 }
126 }
127}