1use std::{
8 fmt::Debug,
9 iter::{self, Product, Sum},
10 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
11};
12
13use binius_utils::{
14 iter::IterExtensions,
15 random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
16};
17use bytemuck::Zeroable;
18
19use super::{
20 Error, PackedExtension, Random,
21 arithmetic_traits::{Broadcast, MulAlpha, Square},
22 binary_field_arithmetic::TowerFieldArithmetic,
23};
24use crate::{BinaryField, Field, arithmetic_traits::InvertOrZero};
25
26pub trait PackedField:
32 Default
33 + Debug
34 + Clone
35 + Copy
36 + Eq
37 + Sized
38 + Add<Output = Self>
39 + Sub<Output = Self>
40 + Mul<Output = Self>
41 + AddAssign
42 + SubAssign
43 + MulAssign
44 + Add<Self::Scalar, Output = Self>
45 + Sub<Self::Scalar, Output = Self>
46 + Mul<Self::Scalar, Output = Self>
47 + AddAssign<Self::Scalar>
48 + SubAssign<Self::Scalar>
49 + MulAssign<Self::Scalar>
50 + Sum
52 + Product
53 + Send
54 + Sync
55 + Zeroable
56 + Random
57 + 'static
58{
59 type Scalar: Field;
60
61 const LOG_WIDTH: usize;
63
64 const WIDTH: usize = 1 << Self::LOG_WIDTH;
68
69 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
73
74 unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
78
79 #[inline]
81 fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
82 (i < Self::WIDTH)
83 .then_some(unsafe { self.get_unchecked(i) })
84 .ok_or(Error::IndexOutOfRange {
85 index: i,
86 max: Self::WIDTH,
87 })
88 }
89
90 #[inline]
92 fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
93 (i < Self::WIDTH)
94 .then(|| unsafe { self.set_unchecked(i, scalar) })
95 .ok_or(Error::IndexOutOfRange {
96 index: i,
97 max: Self::WIDTH,
98 })
99 }
100
101 #[inline]
103 fn get(&self, i: usize) -> Self::Scalar {
104 self.get_checked(i).expect("index must be less than width")
105 }
106
107 #[inline]
109 fn set(&mut self, i: usize, scalar: Self::Scalar) {
110 self.set_checked(i, scalar).expect("index must be less than width")
111 }
112
113 #[inline]
114 fn into_iter(self) -> impl Iterator<Item=Self::Scalar> + Send + Clone {
115 (0..Self::WIDTH).map_skippable(move |i|
116 unsafe { self.get_unchecked(i) })
118 }
119
120 #[inline]
121 fn iter(&self) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
122 (0..Self::WIDTH).map_skippable(move |i|
123 unsafe { self.get_unchecked(i) })
125 }
126
127 #[inline]
128 fn iter_slice(slice: &[Self]) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
129 slice.iter().flat_map(Self::iter)
130 }
131
132 #[inline]
133 fn zero() -> Self {
134 Self::broadcast(Self::Scalar::ZERO)
135 }
136
137 #[inline]
138 fn one() -> Self {
139 Self::broadcast(Self::Scalar::ONE)
140 }
141
142 #[inline(always)]
144 fn set_single(scalar: Self::Scalar) -> Self {
145 let mut result = Self::default();
146 result.set(0, scalar);
147
148 result
149 }
150
151 fn broadcast(scalar: Self::Scalar) -> Self;
152
153 fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
155
156 fn try_from_fn<E>(
158 mut f: impl FnMut(usize) -> Result<Self::Scalar, E>,
159 ) -> Result<Self, E> {
160 let mut result = Self::default();
161 for i in 0..Self::WIDTH {
162 let scalar = f(i)?;
163 unsafe {
164 result.set_unchecked(i, scalar);
165 };
166 }
167 Ok(result)
168 }
169
170 #[inline]
176 fn from_scalars(values: impl IntoIterator<Item=Self::Scalar>) -> Self {
177 let mut result = Self::default();
178 for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
179 result.set(i, val);
180 }
181 result
182 }
183
184 fn square(self) -> Self;
186
187 fn pow(self, exp: u64) -> Self {
189 let mut res = Self::one();
190 for i in (0..64).rev() {
191 res = res.square();
192 if ((exp >> i) & 1) == 1 {
193 res.mul_assign(self)
194 }
195 }
196 res
197 }
198
199 fn invert_or_zero(self) -> Self;
201
202 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
218
219 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
232
233 #[inline]
265 fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
266 assert!(log_block_len <= Self::LOG_WIDTH);
267 assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
268
269 unsafe {
271 self.spread_unchecked(log_block_len, block_idx)
272 }
273 }
274
275 #[inline]
280 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
281 let block_len = 1 << log_block_len;
282 let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
283
284 Self::from_scalars(
285 self.iter().skip(block_idx * block_len).take(block_len).flat_map(|elem| iter::repeat_n(elem, repeat))
286 )
287 }
288}
289
290#[inline]
295pub fn iter_packed_slice_with_offset<P: PackedField>(
296 packed: &[P],
297 offset: usize,
298) -> impl Iterator<Item = P::Scalar> + '_ + Send {
299 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
300 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
301 } else {
302 (&[], 0)
303 };
304
305 P::iter_slice(packed).skip(offset)
306}
307
308#[inline(always)]
309pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
310 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
311
312 unsafe { get_packed_slice_unchecked(packed, i) }
313}
314
315#[inline(always)]
319pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
320 unsafe {
327 packed
328 .get_unchecked(i >> P::LOG_WIDTH)
329 .get_unchecked(i % P::WIDTH)
330 }
331}
332
333#[inline]
334pub fn get_packed_slice_checked<P: PackedField>(
335 packed: &[P],
336 i: usize,
337) -> Result<P::Scalar, Error> {
338 if i >> P::LOG_WIDTH < packed.len() {
339 Ok(unsafe { get_packed_slice_unchecked(packed, i) })
341 } else {
342 Err(Error::IndexOutOfRange {
343 index: i,
344 max: len_packed_slice(packed),
345 })
346 }
347}
348
349#[inline]
353pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
354 packed: &mut [P],
355 i: usize,
356 scalar: P::Scalar,
357) {
358 unsafe {
364 packed
365 .get_unchecked_mut(i >> P::LOG_WIDTH)
366 .set_unchecked(i % P::WIDTH, scalar)
367 }
368}
369
370#[inline]
371pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
372 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
373
374 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
375}
376
377#[inline]
378pub fn set_packed_slice_checked<P: PackedField>(
379 packed: &mut [P],
380 i: usize,
381 scalar: P::Scalar,
382) -> Result<(), Error> {
383 if i >> P::LOG_WIDTH < packed.len() {
384 unsafe { set_packed_slice_unchecked(packed, i, scalar) };
386 Ok(())
387 } else {
388 Err(Error::IndexOutOfRange {
389 index: i,
390 max: len_packed_slice(packed),
391 })
392 }
393}
394
395#[inline(always)]
396pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
397 packed.len() << P::LOG_WIDTH
398}
399
400#[inline]
404pub fn packed_from_fn_with_offset<P: PackedField>(
405 offset: usize,
406 mut f: impl FnMut(usize) -> P::Scalar,
407) -> P {
408 P::from_fn(|i| f(i + offset * P::WIDTH))
409}
410
411pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
413 P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
414}
415
416pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
418 scalars
419 .chunks(P::WIDTH)
420 .map(|chunk| P::from_scalars(chunk.iter().copied()))
421 .collect()
422}
423
424#[derive(Clone)]
426pub struct PackedSlice<'a, P: PackedField> {
427 slice: &'a [P],
428 len: usize,
429}
430
431impl<'a, P: PackedField> PackedSlice<'a, P> {
432 #[inline(always)]
433 pub fn new(slice: &'a [P]) -> Self {
434 Self {
435 slice,
436 len: len_packed_slice(slice),
437 }
438 }
439
440 #[inline(always)]
441 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
442 assert!(len <= len_packed_slice(slice));
443
444 Self { slice, len }
445 }
446}
447
448impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
449 #[inline(always)]
450 fn len(&self) -> usize {
451 self.len
452 }
453
454 #[inline(always)]
455 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
456 unsafe { get_packed_slice_unchecked(self.slice, index) }
457 }
458}
459
460pub struct PackedSliceMut<'a, P: PackedField> {
462 slice: &'a mut [P],
463 len: usize,
464}
465
466impl<'a, P: PackedField> PackedSliceMut<'a, P> {
467 #[inline(always)]
468 pub fn new(slice: &'a mut [P]) -> Self {
469 let len = len_packed_slice(slice);
470 Self { slice, len }
471 }
472
473 #[inline(always)]
474 pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
475 assert!(len <= len_packed_slice(slice));
476
477 Self { slice, len }
478 }
479}
480
481impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
482 #[inline(always)]
483 fn len(&self) -> usize {
484 self.len
485 }
486
487 #[inline(always)]
488 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
489 unsafe { get_packed_slice_unchecked(self.slice, index) }
490 }
491}
492impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
493 #[inline(always)]
494 unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
495 unsafe { set_packed_slice_unchecked(self.slice, index, value) }
496 }
497}
498
499impl<F: Field> Broadcast<F> for F {
500 #[inline]
501 fn broadcast(scalar: F) -> Self {
502 scalar
503 }
504}
505
506impl<T: TowerFieldArithmetic> MulAlpha for T {
507 #[inline]
508 fn mul_alpha(self) -> Self {
509 <Self as TowerFieldArithmetic>::multiply_alpha(self)
510 }
511}
512
513impl<F: Field> PackedField for F {
514 type Scalar = F;
515
516 const LOG_WIDTH: usize = 0;
517
518 #[inline]
519 unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
520 *self
521 }
522
523 #[inline]
524 unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
525 *self = scalar;
526 }
527
528 #[inline]
529 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
530 iter::once(*self)
531 }
532
533 #[inline]
534 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
535 iter::once(self)
536 }
537
538 #[inline]
539 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
540 slice.iter().copied()
541 }
542
543 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
544 panic!("cannot interleave when WIDTH = 1");
545 }
546
547 fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
548 panic!("cannot transpose when WIDTH = 1");
549 }
550
551 #[inline]
552 fn broadcast(scalar: Self::Scalar) -> Self {
553 scalar
554 }
555
556 #[inline]
557 fn zero() -> Self {
558 Self::ZERO
559 }
560
561 #[inline]
562 fn one() -> Self {
563 Self::ONE
564 }
565
566 #[inline]
567 fn square(self) -> Self {
568 <Self as Square>::square(self)
569 }
570
571 #[inline]
572 fn invert_or_zero(self) -> Self {
573 <Self as InvertOrZero>::invert_or_zero(self)
574 }
575
576 #[inline]
577 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
578 f(0)
579 }
580
581 #[inline]
582 unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
583 self
584 }
585}
586
587pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
589
590impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
591
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedBinaryGhash1x128b,
        PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
        arch::{
            packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*, packed_32::*,
            packed_64::*, packed_128::*, packed_256::*, packed_512::*, packed_aes_8::*,
            packed_aes_16::*, packed_aes_32::*, packed_aes_64::*, packed_aes_128::*,
            packed_aes_256::*, packed_aes_512::*,
        },
    };

    /// A check that is generic over the packed field type under test.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` once for each packed field type listed below.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // 1-bit scalars: the bare field and its packings.
        test.run::<BinaryField1b>();
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField512x1b>();

        // AES 8-bit scalars: the bare field and its packings.
        test.run::<AESTowerField8b>();
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField64x8b>();

        // GHASH 128-bit scalars: the bare field and its packings.
        test.run::<BinaryField128bGhash>();
        test.run::<PackedBinaryGhash1x128b>();
        test.run::<PackedBinaryGhash2x128b>();
        test.run::<PackedBinaryGhash4x128b>();
    }

    /// Checks that by-value iteration (`into_iter`) yields exactly `get(0..WIDTH)`.
    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks that by-reference iteration (`iter`) yields exactly `get(0..WIDTH)`.
    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks `iter_packed_slice_with_offset` against element-wise `get_packed_slice`.
    ///
    /// Covers empty slices and offsets at the start, middle, last element and end.
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    /// Asserts that `collection` has exactly the contents of `expected`, via both `get` and
    /// `get_unchecked`.
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    /// Writes a fresh random value at every index and checks that it reads back.
    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = random();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedAESBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedAESBinaryField16x8b] = &[
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut random = || AESTowerField8b::random(&mut rng);

        let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedAESBinaryField16x8b] = &mut [
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut random);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut random);
    }
}