use binius_field::{square_transpose, ExtensionField, Field, PackedExtension};
use binius_utils::checked_arithmetics::checked_log_2;
use std::{
iter::Sum,
marker::PhantomData,
mem,
ops::{Add, AddAssign, Sub, SubAssign},
};
/// An element of a tensor algebra built from an extension field `FE` over a base field `F`.
///
/// The element is stored as a vector of `FE` values. Every constructor in this file
/// produces exactly `FE::DEGREE` entries, and the arithmetic impls combine two values
/// entry by entry.
///
/// NOTE(review): presumably this models the tensor product of `FE` with itself over `F`,
/// with entry `i` paired with the `i`-th `F`-basis element — confirm against the
/// accompanying design documentation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorAlgebra<F, FE>
where
F: Field,
FE: ExtensionField<F>,
{
/// The stored entries (exposed read-only by `vertical_elems`); expected to have length
/// `FE::DEGREE`. The field is public, so code that mutates it directly is responsible
/// for keeping that length.
pub elems: Vec<FE>,
// Zero-sized marker tying the otherwise unused base-field parameter `F` to the type.
_marker: PhantomData<F>,
}
impl<F, FE> Default for TensorAlgebra<F, FE>
where
    F: Field,
    FE: ExtensionField<F>,
{
    /// Builds a value whose `FE::DEGREE` entries are all `FE::default()`.
    fn default() -> Self {
        let elems: Vec<FE> = (0..FE::DEGREE).map(|_| FE::default()).collect();
        Self {
            elems,
            _marker: PhantomData,
        }
    }
}
impl<F, FE> TensorAlgebra<F, FE>
where
    F: Field,
    FE: ExtensionField<F>,
{
    /// Wraps a vector of entries as a tensor algebra element.
    ///
    /// The vector is resized to exactly `FE::DEGREE` entries: shorter inputs are padded
    /// with `FE::ZERO`, longer inputs are truncated.
    pub fn new(elems: Vec<FE>) -> Self {
        let mut normalized = elems;
        normalized.resize(FE::DEGREE, FE::ZERO);
        Self {
            elems: normalized,
            _marker: PhantomData,
        }
    }

    /// Returns `checked_log_2(FE::DEGREE)`, the base-2 logarithm of the extension degree.
    ///
    /// NOTE(review): presumably `checked_log_2` rejects degrees that are not powers of
    /// two — confirm against `binius_utils`.
    pub const fn kappa() -> usize {
        let degree = FE::DEGREE;
        checked_log_2(degree)
    }

    /// Returns `size_of::<FE>()` shifted left by `kappa()`, i.e. the in-memory size of
    /// `2^kappa` entries of type `FE`.
    pub fn byte_size() -> usize {
        let elem_bytes = mem::size_of::<FE>();
        elem_bytes << Self::kappa()
    }

    /// Returns the default value with its first entry replaced by `FE::ONE`.
    pub fn one() -> Self {
        let mut unit = Self::default();
        unit.elems[0] = FE::ONE;
        unit
    }

    /// Borrows the stored entries as a slice.
    pub fn vertical_elems(&self) -> &[FE] {
        self.elems.as_slice()
    }

    /// Builds the element whose entries are `vertical` multiplied by each item yielded by
    /// `horizontal.iter_bases()`, in iteration order.
    pub fn tensor(vertical: FE, horizontal: FE) -> Self {
        let mut elems: Vec<FE> = Vec::with_capacity(FE::DEGREE);
        for coeff in horizontal.iter_bases() {
            elems.push(vertical * coeff);
        }
        Self {
            elems,
            _marker: PhantomData,
        }
    }

    /// Builds the element whose first entry is `x` and whose remaining entries are
    /// `FE::ZERO`.
    pub fn from_vertical(x: FE) -> Self {
        let mut lifted = Self {
            elems: vec![FE::ZERO; FE::DEGREE],
            _marker: PhantomData,
        };
        lifted.elems[0] = x;
        lifted
    }

    /// Returns `Some(first entry)` when every entry after the first equals `FE::ZERO`,
    /// and `None` otherwise.
    pub fn try_extract_vertical(&self) -> Option<FE> {
        let tail_is_zero = self.elems.iter().skip(1).all(|elem| *elem == FE::ZERO);
        let head = self.elems[0];
        if tail_is_zero {
            Some(head)
        } else {
            None
        }
    }

    /// Multiplies every entry by `scalar` and returns the updated element.
    pub fn scale_vertical(mut self, scalar: FE) -> Self {
        self.elems.iter_mut().for_each(|entry| *entry *= scalar);
        self
    }
}
impl<F, FE> TensorAlgebra<F, FE>
where
    F: Field,
    FE: ExtensionField<F> + PackedExtension<F>,
    FE::Scalar: ExtensionField<F>,
{
    /// Scales the element "horizontally" by `scalar`.
    ///
    /// Implemented as transpose, then `scale_vertical(scalar)`, then transpose back.
    pub fn scale_horizontal(self, scalar: FE) -> Self {
        let flipped = self.transpose();
        let scaled = flipped.scale_vertical(scalar);
        scaled.transpose()
    }

    /// Transposes the element in place and returns it.
    ///
    /// The entries are reinterpreted through `FE::cast_bases_mut` and handed to
    /// `square_transpose` together with `kappa()`.
    ///
    /// # Panics
    ///
    /// Panics if `square_transpose` reports an error; the `expect` message records the
    /// assumption that the struct invariant makes the dimensions square.
    pub fn transpose(mut self) -> Self {
        let log_dim = Self::kappa();
        let coeffs = FE::cast_bases_mut(&mut self.elems);
        square_transpose(log_dim, coeffs)
            .expect("transpose dimensions are square by struct invariant");
        self
    }
}
impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
where
F: Field,
FE: ExtensionField<F>,
{
type Output = Self;
fn add(mut self, rhs: &Self) -> Self {
self.add_assign(rhs);
self
}
}
impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
where
F: Field,
FE: ExtensionField<F>,
{
type Output = Self;
fn sub(mut self, rhs: &Self) -> Self {
self.sub_assign(rhs);
self
}
}
impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
where
    F: Field,
    FE: ExtensionField<F>,
{
    /// Adds each entry of `rhs` into the entry of `self` at the same position.
    ///
    /// The entries are paired with `zip`, so pairing stops at the shorter of the two
    /// vectors.
    fn add_assign(&mut self, rhs: &Self) {
        self.elems
            .iter_mut()
            .zip(&rhs.elems)
            .for_each(|(acc, term)| *acc += *term);
    }
}
impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
where
    F: Field,
    FE: ExtensionField<F>,
{
    /// Subtracts each entry of `rhs` from the entry of `self` at the same position.
    ///
    /// The entries are paired with `zip`, so pairing stops at the shorter of the two
    /// vectors.
    fn sub_assign(&mut self, rhs: &Self) {
        self.elems
            .iter_mut()
            .zip(&rhs.elems)
            .for_each(|(acc, term)| *acc -= *term);
    }
}
impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
where
    F: Field,
    FE: ExtensionField<F>,
{
    /// Sums borrowed elements, starting from `Self::default()` and adding each item in
    /// iteration order. An empty iterator yields the default value.
    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
        let mut total = Self::default();
        for term in iter {
            total = total + term;
        }
        total
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use binius_field::{BinaryField128b, BinaryField8b};
    use rand::{rngs::StdRng, SeedableRng};

    // Base field / extension field pair shared by every test in this module.
    type F = BinaryField8b;
    type FE = BinaryField128b;

    /// `tensor(v, h)` must equal `from_vertical(v)` scaled horizontally by `h`.
    #[test]
    fn test_tensor_product() {
        let mut rng = StdRng::seed_from_u64(0);
        let vertical = FE::random(&mut rng);
        let horizontal = FE::random(&mut rng);

        let via_scaling =
            TensorAlgebra::<F, FE>::from_vertical(vertical).scale_horizontal(horizontal);
        let via_tensor = TensorAlgebra::<F, FE>::tensor(vertical, horizontal);

        assert_eq!(via_tensor, via_scaling);
    }

    /// Walks one element through several horizontal scalings and checks after each step
    /// whether `try_extract_vertical` yields a value.
    #[test]
    fn test_try_extract_vertical() {
        let mut rng = StdRng::seed_from_u64(0);
        let vertical = FE::random(&mut rng);

        // A freshly lifted element extracts back to the value it was built from.
        let lifted = TensorAlgebra::<F, FE>::from_vertical(vertical);
        assert_eq!(lifted.try_extract_vertical(), Some(vertical));

        // After scaling horizontally by this extension-field element, extraction fails.
        let generic = FE::new(1111);
        let scaled = lifted.scale_horizontal(generic);
        assert_eq!(scaled.try_extract_vertical(), None);

        // Scaling by the inverse undoes the previous step.
        let generic_inv = generic.invert().unwrap();
        let restored = scaled.scale_horizontal(generic_inv);
        assert_eq!(restored.try_extract_vertical(), Some(vertical));

        // Scaling by an embedded base-field element still extracts, giving the product.
        let subfield_scalar = FE::from(F::new(7));
        let rescaled = restored.scale_horizontal(subfield_scalar);
        assert_eq!(
            rescaled.try_extract_vertical(),
            Some(vertical * subfield_scalar)
        );
    }
}