use std::{
fmt::Debug,
iter::{self, Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};
use binius_utils::iter::IterExtensions;
use bytemuck::Zeroable;
use rand::RngCore;
use super::{
arithmetic_traits::{Broadcast, MulAlpha, Square},
binary_field_arithmetic::TowerFieldArithmetic,
Error,
};
use crate::{
arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, ExtensionField, Field,
PackedExtension,
};
pub trait PackedField:
Default
+ Debug
+ Clone
+ Copy
+ Eq
+ Sized
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ AddAssign
+ SubAssign
+ MulAssign
+ Add<Self::Scalar, Output = Self>
+ Sub<Self::Scalar, Output = Self>
+ Mul<Self::Scalar, Output = Self>
+ AddAssign<Self::Scalar>
+ SubAssign<Self::Scalar>
+ MulAssign<Self::Scalar>
+ Sum
+ Product
+ Send
+ Sync
+ Zeroable
+ 'static
{
type Scalar: Field;
const LOG_WIDTH: usize;
const WIDTH: usize = 1 << Self::LOG_WIDTH;
unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
#[inline]
fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
(i < Self::WIDTH)
.then_some(unsafe { self.get_unchecked(i) })
.ok_or(Error::IndexOutOfRange {
index: i,
max: Self::WIDTH,
})
}
#[inline]
fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
(i < Self::WIDTH)
.then(|| unsafe { self.set_unchecked(i, scalar) })
.ok_or(Error::IndexOutOfRange {
index: i,
max: Self::WIDTH,
})
}
#[inline]
fn get(&self, i: usize) -> Self::Scalar {
self.get_checked(i).expect("index must be less than width")
}
#[inline]
fn set(&mut self, i: usize, scalar: Self::Scalar) {
self.set_checked(i, scalar).expect("index must be less than width")
}
#[inline]
fn into_iter(self) -> impl Iterator<Item=Self::Scalar> + Send {
(0..Self::WIDTH).map_skippable(move |i|
unsafe { self.get_unchecked(i) })
}
#[inline]
fn iter(&self) -> impl Iterator<Item=Self::Scalar> + Send + '_ {
(0..Self::WIDTH).map_skippable(move |i|
unsafe { self.get_unchecked(i) })
}
#[inline]
fn iter_slice(slice: &[Self]) -> impl Iterator<Item=Self::Scalar> + Send + '_ {
slice.iter().flat_map(Self::iter)
}
#[inline]
fn zero() -> Self {
Self::broadcast(Self::Scalar::ZERO)
}
#[inline]
fn one() -> Self {
Self::broadcast(Self::Scalar::ONE)
}
#[inline(always)]
fn set_single(scalar: Self::Scalar) -> Self {
let mut result = Self::default();
result.set(0, scalar);
result
}
fn random(rng: impl RngCore) -> Self;
fn broadcast(scalar: Self::Scalar) -> Self;
fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
fn from_scalars(values: impl IntoIterator<Item=Self::Scalar>) -> Self {
let mut result = Self::default();
for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
result.set(i, val);
}
result
}
fn square(self) -> Self;
fn invert_or_zero(self) -> Self;
fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
#[inline]
fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
assert!(log_block_len <= Self::LOG_WIDTH);
assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
unsafe {
self.spread_unchecked(log_block_len, block_idx)
}
}
#[inline]
unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
let block_len = 1 << log_block_len;
let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
Self::from_scalars(
self.iter().skip(block_idx * block_len).take(block_len).flat_map(|elem| iter::repeat_n(elem, repeat))
)
}
}
/// Iterates over the scalars stored in `packed`, starting at scalar index `offset`.
///
/// Scalars are numbered across the whole slice, `P::WIDTH` per element. When
/// `offset` is at or past the total scalar count, the iterator is empty.
pub fn iter_packed_slice_with_offset<P: PackedField>(
    packed: &[P],
    offset: usize,
) -> impl Iterator<Item = P::Scalar> + '_ + Send {
    let total_scalars = packed.len() * P::WIDTH;

    // Drop whole packed elements before the offset, then skip the remaining
    // scalars inside the first element that is kept.
    let (tail, skip_in_first): (&[P], usize) = if offset >= total_scalars {
        (&[], 0)
    } else {
        let first_kept = offset / P::WIDTH;
        (&packed[first_kept..], offset % P::WIDTH)
    };

    P::iter_slice(tail).skip(skip_in_first)
}
/// Returns the scalar at flat index `i` of a slice of packed values.
///
/// # Panics
/// Panics when `i / P::WIDTH` is not a valid index into `packed`.
#[inline]
pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    let (outer, inner) = (i / P::WIDTH, i % P::WIDTH);
    let element = &packed[outer];
    // SAFETY: `inner` is a remainder modulo `P::WIDTH`, hence below `P::WIDTH`.
    unsafe { element.get_unchecked(inner) }
}
/// Returns the scalar at flat index `i` of a slice of packed values, skipping
/// the slice bounds check.
///
/// # Safety
/// The caller must guarantee `i / P::WIDTH < packed.len()`.
#[inline]
pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    let (outer, inner) = (i / P::WIDTH, i % P::WIDTH);
    // SAFETY: `outer` is in bounds per this function's contract, and `inner`
    // is a remainder modulo `P::WIDTH`, hence below `P::WIDTH`.
    unsafe {
        let element = packed.get_unchecked(outer);
        element.get_unchecked(inner)
    }
}
pub fn get_packed_slice_checked<P: PackedField>(
packed: &[P],
i: usize,
) -> Result<P::Scalar, Error> {
packed
.get(i / P::WIDTH)
.map(|el| el.get(i % P::WIDTH))
.ok_or(Error::IndexOutOfRange {
index: i,
max: packed.len() * P::WIDTH,
})
}
/// Stores `scalar` at flat index `i` of a slice of packed values, skipping the
/// slice bounds check.
///
/// # Safety
/// The caller must guarantee `i / P::WIDTH < packed.len()`.
pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) {
    let (outer, inner) = (i / P::WIDTH, i % P::WIDTH);
    // SAFETY: `outer` is in bounds per this function's contract, and `inner`
    // is a remainder modulo `P::WIDTH`, hence below `P::WIDTH`.
    unsafe {
        let element = packed.get_unchecked_mut(outer);
        element.set_unchecked(inner, scalar)
    }
}
/// Stores `scalar` at flat index `i` of a slice of packed values.
///
/// # Panics
/// Panics when `i / P::WIDTH` is not a valid index into `packed`.
pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
    let (outer, inner) = (i / P::WIDTH, i % P::WIDTH);
    let element = &mut packed[outer];
    // SAFETY: `inner` is a remainder modulo `P::WIDTH`, hence below `P::WIDTH`.
    unsafe { element.set_unchecked(inner, scalar) }
}
pub fn set_packed_slice_checked<P: PackedField>(
packed: &mut [P],
i: usize,
scalar: P::Scalar,
) -> Result<(), Error> {
packed
.get_mut(i / P::WIDTH)
.map(|el| el.set(i % P::WIDTH, scalar))
.ok_or(Error::IndexOutOfRange {
index: i,
max: packed.len() * P::WIDTH,
})
}
/// Total number of scalars held by a slice of packed values
/// (`P::WIDTH` scalars per element).
pub fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
    let element_count = packed.len();
    element_count * P::WIDTH
}
/// Multiplies a packed extension-field value by one scalar from the subfield `FS`.
///
/// One of two strategies is chosen from the bit widths of the underliers:
/// * scalar by scalar via `P::from_fn`, when the extension scalar has at
///   least 32 bits, or when the subfield has 1 bit and the extension scalar
///   has more than 8 bits;
/// * otherwise, reinterpret `val` as packed subfield elements
///   (`P::cast_base`), multiply by the broadcast `multiplier`, and
///   reinterpret back (`P::cast_ext`).
///
/// NOTE(review): the thresholds look like a performance heuristic; confirm
/// the rationale with benchmarks or the original author.
pub fn mul_by_subfield_scalar<P, FS>(val: P, multiplier: FS) -> P
where
    P: PackedExtension<FS, Scalar: ExtensionField<FS>>,
    FS: Field,
{
    use crate::underlier::UnderlierType;

    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    let multiply_per_scalar = extension_bits >= 32 || (subfield_bits == 1 && extension_bits > 8);
    if !multiply_per_scalar {
        let as_subfield = P::cast_base(val);
        return P::cast_ext(as_subfield * P::PackedSubfield::broadcast(multiplier));
    }

    P::from_fn(|i| {
        // NOTE(review): relies on `from_fn` only passing indices below
        // `P::WIDTH` — confirm against the implementations.
        let scalar = unsafe { val.get_unchecked(i) };
        scalar * multiplier
    })
}
/// Broadcasting a field element into itself is the identity.
impl<F: Field> Broadcast<F> for F {
    fn broadcast(scalar: F) -> F {
        scalar
    }
}
impl<T: TowerFieldArithmetic> MulAlpha for T {
#[inline]
fn mul_alpha(self) -> Self {
<Self as TowerFieldArithmetic>::multiply_alpha(self)
}
}
/// Every field element is a packed field of width one whose only scalar is itself.
impl<F: Field> PackedField for F {
    type Scalar = F;

    // `WIDTH = 1 << 0 = 1`.
    const LOG_WIDTH: usize = 0;

    /// Returns the element itself; the index is ignored.
    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        *self
    }

    /// Overwrites the element with `scalar`; the index is ignored.
    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        *self = scalar;
    }

    /// Consumes the element and yields it once.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send {
        iter::once(self)
    }

    /// Yields a copy of the element once.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + '_ {
        iter::once(*self)
    }

    /// With one scalar per element, this is just a copying slice iterator.
    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + '_ {
        slice.iter().copied()
    }

    /// Delegates to `Field::random`.
    fn random(rng: impl RngCore) -> Self {
        <F as Field>::random(rng)
    }

    /// Broadcasting into a single position is the identity.
    fn broadcast(scalar: Self::Scalar) -> Self {
        scalar
    }

    /// Calls `f` once, with index 0.
    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        f(0)
    }

    /// Delegates to `Square::square`.
    fn square(self) -> Self {
        <F as Square>::square(self)
    }

    /// Delegates to `InvertOrZero::invert_or_zero`.
    fn invert_or_zero(self) -> Self {
        <F as InvertOrZero>::invert_or_zero(self)
    }

    /// Not supported for a single scalar.
    ///
    /// # Panics
    /// Always panics.
    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    /// Returns the element unchanged; both arguments are ignored.
    #[inline]
    unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
        self
    }
}
/// Marker trait for packed fields whose scalar type is a `BinaryField`.
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

/// Blanket implementation: every qualifying packed field is a `PackedBinaryField`.
impl<P: PackedField<Scalar: BinaryField>> PackedBinaryField for P {}
#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Distribution, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::{
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
        BinaryField128b, BinaryField128bPolyval, BinaryField16b, BinaryField1b, BinaryField2b,
        BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, ByteSlicedAES32x128b,
        ByteSlicedAES32x16b, ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b,
        PackedBinaryField128x1b, PackedBinaryField128x2b, PackedBinaryField128x4b,
        PackedBinaryField16x16b, PackedBinaryField16x1b, PackedBinaryField16x2b,
        PackedBinaryField16x32b, PackedBinaryField16x4b, PackedBinaryField16x8b,
        PackedBinaryField1x128b, PackedBinaryField1x16b, PackedBinaryField1x1b,
        PackedBinaryField1x2b, PackedBinaryField1x32b, PackedBinaryField1x4b,
        PackedBinaryField1x64b, PackedBinaryField1x8b, PackedBinaryField256x1b,
        PackedBinaryField256x2b, PackedBinaryField2x128b, PackedBinaryField2x16b,
        PackedBinaryField2x1b, PackedBinaryField2x2b, PackedBinaryField2x32b,
        PackedBinaryField2x4b, PackedBinaryField2x64b, PackedBinaryField2x8b,
        PackedBinaryField32x16b, PackedBinaryField32x1b, PackedBinaryField32x2b,
        PackedBinaryField32x4b, PackedBinaryField32x8b, PackedBinaryField4x128b,
        PackedBinaryField4x16b, PackedBinaryField4x1b, PackedBinaryField4x2b,
        PackedBinaryField4x32b, PackedBinaryField4x4b, PackedBinaryField4x64b,
        PackedBinaryField4x8b, PackedBinaryField512x1b, PackedBinaryField64x1b,
        PackedBinaryField64x2b, PackedBinaryField64x4b, PackedBinaryField64x8b,
        PackedBinaryField8x16b, PackedBinaryField8x1b, PackedBinaryField8x2b,
        PackedBinaryField8x32b, PackedBinaryField8x4b, PackedBinaryField8x64b,
        PackedBinaryField8x8b, PackedBinaryPolyval1x128b, PackedBinaryPolyval2x128b,
        PackedBinaryPolyval4x128b, PackedField,
    };

    /// A test that is generic over the packed field type it exercises.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` once for every packed field type listed below.
    fn run_for_all_packed_fields(test: impl PackedFieldTest) {
        // Canonical tower scalars.
        test.run::<BinaryField1b>();
        test.run::<BinaryField2b>();
        test.run::<BinaryField4b>();
        test.run::<BinaryField8b>();
        test.run::<BinaryField16b>();
        test.run::<BinaryField32b>();
        test.run::<BinaryField64b>();
        test.run::<BinaryField128b>();

        // Packed canonical tower fields.
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField1x2b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField2x2b>();
        test.run::<PackedBinaryField1x4b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField4x2b>();
        test.run::<PackedBinaryField2x4b>();
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField8x2b>();
        test.run::<PackedBinaryField4x4b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField16x2b>();
        test.run::<PackedBinaryField8x4b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField32x2b>();
        test.run::<PackedBinaryField16x4b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField64x2b>();
        test.run::<PackedBinaryField32x4b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField128x2b>();
        test.run::<PackedBinaryField64x4b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField512x1b>();
        test.run::<PackedBinaryField256x2b>();
        test.run::<PackedBinaryField128x4b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        // AES tower scalars.
        test.run::<AESTowerField8b>();
        test.run::<AESTowerField16b>();
        test.run::<AESTowerField32b>();
        test.run::<AESTowerField64b>();
        test.run::<AESTowerField128b>();

        // NOTE(review): every type in this group was already run in the
        // canonical packed section above. Presumably the packed AES tower
        // types were intended here — confirm and replace if so.
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        // Byte-sliced AES fields.
        test.run::<ByteSlicedAES32x8b>();
        test.run::<ByteSlicedAES32x64b>();
        test.run::<ByteSlicedAES32x16b>();
        test.run::<ByteSlicedAES32x32b>();
        test.run::<ByteSlicedAES32x128b>();

        // Polyval fields.
        test.run::<BinaryField128bPolyval>();
        test.run::<PackedBinaryPolyval1x128b>();
        test.run::<PackedBinaryPolyval2x128b>();
        test.run::<PackedBinaryPolyval4x128b>();
    }

    /// The consuming (by-value) iterator must yield exactly `get(0..WIDTH)`.
    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        // `P` is `Copy`, so `packed` stays usable for the `get` calls below.
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// The borrowing (by-reference) iterator must yield exactly `get(0..WIDTH)`.
    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// `iter_packed_slice_with_offset` must agree with element-wise
    /// `get_packed_slice` for several slice lengths and offsets, including the
    /// empty slice and an offset equal to the scalar count.
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                // `.max(1)` keeps the sampled range non-empty when `len == 0`.
                Uniform::new(0, elements_count.max(1)).sample(&mut rng),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    /// Bundles the three iteration checks with a fixed RNG seed per type.
    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(PackedFieldIterationTest);
    }
}