use super::{
arithmetic_traits::{Broadcast, MulAlpha, Square},
binary_field_arithmetic::TowerFieldArithmetic,
Error,
};
use crate::{
arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, ExtensionField, Field,
PackedExtension,
};
use binius_utils::iter::IterExtensions;
use bytemuck::Zeroable;
use rand::RngCore;
use std::{
fmt::Debug,
iter::{self, Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};
/// A packed vector of [`Field`] elements.
///
/// A packed element holds [`Self::WIDTH`] scalars of type [`Self::Scalar`], addressed by lane
/// index. Arithmetic is available both between two packed elements and between a packed element
/// and a single scalar (see the supertrait bounds).
pub trait PackedField:
    Default
    + Debug
    + Clone
    + Copy
    + Eq
    + Sized
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + Add<Self::Scalar, Output = Self>
    + Sub<Self::Scalar, Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + AddAssign<Self::Scalar>
    + SubAssign<Self::Scalar>
    + MulAssign<Self::Scalar>
    + Sum
    + Product
    + Send
    + Sync
    + Zeroable
    + 'static
{
    /// The scalar field type stored in each lane.
    type Scalar: Field;

    /// Base-2 logarithm of the number of lanes.
    const LOG_WIDTH: usize;

    /// Number of scalar lanes (defaults to `1 << LOG_WIDTH`).
    const WIDTH: usize = 1 << Self::LOG_WIDTH;

    /// Returns the scalar stored in lane `i`.
    ///
    /// # Safety
    ///
    /// `i` must be less than [`Self::WIDTH`].
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

    /// Overwrites the scalar stored in lane `i` with `scalar`.
    ///
    /// # Safety
    ///
    /// `i` must be less than [`Self::WIDTH`].
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

    /// Returns the scalar in lane `i`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IndexOutOfRange`] if `i >= Self::WIDTH`.
    #[inline]
    fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
        // `then` (not `then_some`) is required here: `then_some` evaluates its argument eagerly,
        // which would perform the unchecked read *before* the bounds check is consulted.
        (i < Self::WIDTH)
            // SAFETY: this closure only runs when `i < Self::WIDTH`.
            .then(|| unsafe { self.get_unchecked(i) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Overwrites the scalar in lane `i` with `scalar`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IndexOutOfRange`] if `i >= Self::WIDTH`; `self` is left unchanged.
    #[inline]
    fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
        (i < Self::WIDTH)
            // SAFETY: this closure only runs when `i < Self::WIDTH`.
            .then(|| unsafe { self.set_unchecked(i, scalar) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Returns the scalar in lane `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::WIDTH`.
    #[inline]
    fn get(&self, i: usize) -> Self::Scalar {
        self.get_checked(i).expect("index must be less than width")
    }

    /// Overwrites the scalar in lane `i` with `scalar`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::WIDTH`.
    #[inline]
    fn set(&mut self, i: usize, scalar: Self::Scalar) {
        self.set_checked(i, scalar).expect("index must be less than width")
    }

    /// Consumes `self` and yields the scalars of lanes `0..WIDTH` in index order.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> {
        // SAFETY: `i` ranges over `0..Self::WIDTH`.
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    /// Yields the scalars of lanes `0..WIDTH` in index order without consuming `self`.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send {
        // SAFETY: `i` ranges over `0..Self::WIDTH`.
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    /// Returns `broadcast(Scalar::ZERO)`.
    #[inline]
    fn zero() -> Self {
        Self::broadcast(Self::Scalar::ZERO)
    }

    /// Returns `broadcast(Scalar::ONE)`.
    #[inline]
    fn one() -> Self {
        Self::broadcast(Self::Scalar::ONE)
    }

    /// Returns `Self::default()` with lane 0 overwritten by `scalar`.
    #[inline(always)]
    fn set_single(scalar: Self::Scalar) -> Self {
        let mut result = Self::default();
        result.set(0, scalar);
        result
    }

    /// Samples a packed element using `rng`.
    fn random(rng: impl RngCore) -> Self;

    /// Returns a packed element whose lanes all hold `scalar`.
    fn broadcast(scalar: Self::Scalar) -> Self;

    /// Builds a packed element by calling `f` with each lane index.
    fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

    /// Builds a packed element from up to `WIDTH` scalars, written to lanes `0, 1, ...` in order.
    ///
    /// Values beyond the first `WIDTH` are ignored; lanes with no corresponding value keep the
    /// value they have in `Self::default()`.
    fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
        let mut result = Self::default();
        for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
            result.set(i, val);
        }
        result
    }

    /// Returns the square of `self`.
    fn square(self) -> Self;

    /// Returns the multiplicative inverse of `self`, with zero mapping to zero
    /// (see [`InvertOrZero`] for the scalar case).
    fn invert_or_zero(self) -> Self;

    /// Interleaves `self` with `other` at a block granularity selected by `log_block_len`.
    ///
    /// NOTE(review): the exact lane layout of the returned pair is defined by each implementor.
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

    /// Selects block `block_idx` of `2^log_block_len` consecutive scalars of `self` and returns a
    /// packed element in which each of those scalars is repeated `2^(LOG_WIDTH - log_block_len)`
    /// times, in order.
    ///
    /// # Panics
    ///
    /// Panics if `log_block_len > LOG_WIDTH` or `block_idx >= 2^(LOG_WIDTH - log_block_len)`.
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert!(log_block_len <= Self::LOG_WIDTH);
        let block_len = 1 << log_block_len;
        let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
        assert!(block_idx < repeat);
        // `block_len * repeat == WIDTH`, so exactly WIDTH scalars are produced.
        Self::from_scalars(
            self.iter()
                .skip(block_idx * block_len)
                .take(block_len)
                .flat_map(|elem| iter::repeat_n(elem, repeat)),
        )
    }
}
/// Iterates over every scalar stored in `packed`: the output of [`PackedField::iter`] for
/// `packed[0]`, followed by that of `packed[1]`, and so on.
pub fn iter_packed_slice<P: PackedField>(
    packed: &[P],
) -> impl Iterator<Item = P::Scalar> + '_ + Send {
    packed.iter().flat_map(P::iter)
}
/// Returns the scalar at flat index `i`, treating `packed` as a contiguous sequence of
/// `packed.len() * P::WIDTH` scalars.
///
/// # Panics
///
/// Panics if `i / P::WIDTH >= packed.len()`.
#[inline]
pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    let (elem_idx, lane_idx) = (i / P::WIDTH, i % P::WIDTH);
    let elem = &packed[elem_idx];
    // SAFETY: `lane_idx` is a remainder modulo `P::WIDTH`, so `lane_idx < P::WIDTH`.
    unsafe { elem.get_unchecked(lane_idx) }
}
/// Returns the scalar at flat index `i` of `packed` without any bounds checks.
///
/// # Safety
///
/// `i / P::WIDTH` must be less than `packed.len()` (equivalently,
/// `i < packed.len() * P::WIDTH`).
#[inline]
pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    // SAFETY: the caller guarantees `i / P::WIDTH < packed.len()`, and `i % P::WIDTH` is always
    // `< P::WIDTH`, satisfying `PackedField::get_unchecked`'s contract.
    unsafe {
        packed
            .get_unchecked(i / P::WIDTH)
            .get_unchecked(i % P::WIDTH)
    }
}
pub fn get_packed_slice_checked<P: PackedField>(
packed: &[P],
i: usize,
) -> Result<P::Scalar, Error> {
packed
.get(i / P::WIDTH)
.map(|el| el.get(i % P::WIDTH))
.ok_or(Error::IndexOutOfRange {
index: i,
max: packed.len() * P::WIDTH,
})
}
/// Overwrites the scalar at flat index `i` of `packed` without any bounds checks.
///
/// # Safety
///
/// `i / P::WIDTH` must be less than `packed.len()` (equivalently,
/// `i < packed.len() * P::WIDTH`).
pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) {
    let (elem_idx, lane_idx) = (i / P::WIDTH, i % P::WIDTH);
    // SAFETY: the caller guarantees `elem_idx < packed.len()`; `lane_idx < P::WIDTH` because it
    // is a remainder modulo `P::WIDTH`.
    unsafe {
        let elem = packed.get_unchecked_mut(elem_idx);
        elem.set_unchecked(lane_idx, scalar)
    }
}
/// Overwrites the scalar at flat index `i`, treating `packed` as a contiguous sequence of
/// `packed.len() * P::WIDTH` scalars.
///
/// # Panics
///
/// Panics if `i / P::WIDTH >= packed.len()`.
pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
    let lane_idx = i % P::WIDTH;
    let elem = &mut packed[i / P::WIDTH];
    // SAFETY: `lane_idx` is a remainder modulo `P::WIDTH`, so `lane_idx < P::WIDTH`.
    unsafe { elem.set_unchecked(lane_idx, scalar) }
}
pub fn set_packed_slice_checked<P: PackedField>(
packed: &mut [P],
i: usize,
scalar: P::Scalar,
) -> Result<(), Error> {
packed
.get_mut(i / P::WIDTH)
.map(|el| el.set(i % P::WIDTH, scalar))
.ok_or(Error::IndexOutOfRange {
index: i,
max: packed.len() * P::WIDTH,
})
}
/// Total number of scalars stored in `packed`: `packed.len()` packed elements of `P::WIDTH`
/// lanes each.
pub fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
    P::WIDTH * packed.len()
}
/// Multiplies the packed extension-field element `val` by the subfield element `multiplier`.
///
/// One of two code paths is taken, chosen from the bit widths of the subfield and extension
/// underliers:
/// - per lane: each scalar of `val` is multiplied by `multiplier` via `P::from_fn`;
/// - packed: `val` is reinterpreted as a packed subfield element (`cast_base`), multiplied by a
///   broadcast `multiplier`, and reinterpreted back (`cast_ext`).
///
/// NOTE(review): the width thresholds look like a performance heuristic; confirm with
/// benchmarks before changing them.
pub fn mul_by_subfield_scalar<P, FS>(val: P, multiplier: FS) -> P
where
    P: PackedExtension<FS, Scalar: ExtensionField<FS>>,
    FS: Field,
{
    use crate::underlier::UnderlierType;

    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    let bit_subfield_over_wide_ext = subfield_bits == 1 && extension_bits > 8;
    let wide_ext = extension_bits >= 32;
    let use_per_lane = bit_subfield_over_wide_ext || wide_ext;

    if !use_per_lane {
        let as_subfield = P::cast_base(val);
        let product = as_subfield * P::PackedSubfield::broadcast(multiplier);
        return P::cast_ext(product);
    }

    P::from_fn(|lane| {
        // SAFETY: assumes `from_fn` only passes lane indices `< P::WIDTH`
        // (NOTE(review): confirm for every implementor).
        let scalar = unsafe { val.get_unchecked(lane) };
        scalar * multiplier
    })
}
/// Broadcasting a field element into the field itself is the identity.
impl<F> Broadcast<F> for F
where
    F: Field,
{
    fn broadcast(value: F) -> Self {
        value
    }
}
/// Every [`TowerFieldArithmetic`] type implements [`MulAlpha`] by forwarding to
/// [`TowerFieldArithmetic::multiply_alpha`].
impl<T> MulAlpha for T
where
    T: TowerFieldArithmetic,
{
    #[inline]
    fn mul_alpha(self) -> Self {
        let result = <T as TowerFieldArithmetic>::multiply_alpha(self);
        result
    }
}
/// A bare field element is a packed field with a single lane (`LOG_WIDTH = 0`).
impl<F: Field> PackedField for F {
    type Scalar = F;

    const LOG_WIDTH: usize = 0;

    /// The only lane is `self`; the index is ignored.
    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        *self
    }

    /// The only lane is `self`; the index is ignored.
    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        *self = scalar;
    }

    /// Yields `self` exactly once.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send {
        iter::once(*self)
    }

    /// Delegates to [`Field::random`].
    #[inline]
    fn random(rng: impl RngCore) -> Self {
        <Self as Field>::random(rng)
    }

    /// # Panics
    ///
    /// Always panics: a single-lane element has nothing to interleave.
    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    /// With a single lane, broadcasting is the identity.
    #[inline]
    fn broadcast(scalar: Self::Scalar) -> Self {
        scalar
    }

    /// Delegates to [`Square::square`].
    #[inline]
    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    /// Delegates to [`InvertOrZero::invert_or_zero`].
    #[inline]
    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }

    /// Calls `f` once, for lane 0.
    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        f(0)
    }

    /// With a single lane, the only valid arguments are `log_block_len == 0` and
    /// `block_idx == 0`, which return `self` unchanged.
    ///
    /// # Panics
    ///
    /// Panics if either argument is non-zero, in both debug and release builds. This matches the
    /// precondition `assert!`s of the default [`PackedField::spread`] implementation.
    #[inline]
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert_eq!(log_block_len, 0);
        assert_eq!(block_idx, 0);
        self
    }
}
/// Marker for packed fields whose scalar type is a [`BinaryField`].
///
/// It is blanket-implemented for every qualifying type, so it behaves as a trait alias.
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

impl<P> PackedBinaryField for P where P: PackedField<Scalar: BinaryField> {}