1use std::{
8 fmt::Debug,
9 iter::{self, Product, Sum},
10 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
11};
12
13use binius_utils::{
14 iter::IterExtensions,
15 random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
16};
17use bytemuck::Zeroable;
18
19use super::{
20 Error, Random,
21 arithmetic_traits::{Broadcast, MulAlpha, Square},
22 binary_field_arithmetic::TowerFieldArithmetic,
23};
24use crate::{
25 BinaryField, Field, PackedExtension, arithmetic_traits::InvertOrZero,
26 is_packed_field_indexable, underlier::WithUnderlier,
27};
28
29pub trait PackedField:
35 Default
36 + Debug
37 + Clone
38 + Copy
39 + Eq
40 + Sized
41 + Add<Output = Self>
42 + Sub<Output = Self>
43 + Mul<Output = Self>
44 + AddAssign
45 + SubAssign
46 + MulAssign
47 + Add<Self::Scalar, Output = Self>
48 + Sub<Self::Scalar, Output = Self>
49 + Mul<Self::Scalar, Output = Self>
50 + AddAssign<Self::Scalar>
51 + SubAssign<Self::Scalar>
52 + MulAssign<Self::Scalar>
53 + Sum
55 + Product
56 + Send
57 + Sync
58 + Zeroable
59 + Random
60 + 'static
61{
62 type Scalar: Field;
63
64 const LOG_WIDTH: usize;
66
67 const WIDTH: usize = 1 << Self::LOG_WIDTH;
71
72 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
76
77 unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
81
82 #[inline]
84 fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
85 (i < Self::WIDTH)
86 .then_some(unsafe { self.get_unchecked(i) })
87 .ok_or(Error::IndexOutOfRange {
88 index: i,
89 max: Self::WIDTH,
90 })
91 }
92
93 #[inline]
95 fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
96 (i < Self::WIDTH)
97 .then(|| unsafe { self.set_unchecked(i, scalar) })
98 .ok_or(Error::IndexOutOfRange {
99 index: i,
100 max: Self::WIDTH,
101 })
102 }
103
104 #[inline]
106 fn get(&self, i: usize) -> Self::Scalar {
107 self.get_checked(i).expect("index must be less than width")
108 }
109
110 #[inline]
112 fn set(&mut self, i: usize, scalar: Self::Scalar) {
113 self.set_checked(i, scalar).expect("index must be less than width")
114 }
115
116 #[inline]
117 fn into_iter(self) -> impl Iterator<Item=Self::Scalar> + Send + Clone {
118 (0..Self::WIDTH).map_skippable(move |i|
119 unsafe { self.get_unchecked(i) })
121 }
122
123 #[inline]
124 fn iter(&self) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
125 (0..Self::WIDTH).map_skippable(move |i|
126 unsafe { self.get_unchecked(i) })
128 }
129
130 #[inline]
131 fn iter_slice(slice: &[Self]) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
132 slice.iter().flat_map(Self::iter)
133 }
134
135 #[inline]
136 fn zero() -> Self {
137 Self::broadcast(Self::Scalar::ZERO)
138 }
139
140 #[inline]
141 fn one() -> Self {
142 Self::broadcast(Self::Scalar::ONE)
143 }
144
145 #[inline(always)]
147 fn set_single(scalar: Self::Scalar) -> Self {
148 let mut result = Self::default();
149 result.set(0, scalar);
150
151 result
152 }
153
154 fn broadcast(scalar: Self::Scalar) -> Self;
155
156 fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
158
159 fn try_from_fn<E>(
161 mut f: impl FnMut(usize) -> Result<Self::Scalar, E>,
162 ) -> Result<Self, E> {
163 let mut result = Self::default();
164 for i in 0..Self::WIDTH {
165 let scalar = f(i)?;
166 unsafe {
167 result.set_unchecked(i, scalar);
168 };
169 }
170 Ok(result)
171 }
172
173 #[inline]
179 fn from_scalars(values: impl IntoIterator<Item=Self::Scalar>) -> Self {
180 let mut result = Self::default();
181 for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
182 result.set(i, val);
183 }
184 result
185 }
186
187 fn square(self) -> Self;
189
190 fn pow(self, exp: u64) -> Self {
192 let mut res = Self::one();
193 for i in (0..64).rev() {
194 res = res.square();
195 if ((exp >> i) & 1) == 1 {
196 res.mul_assign(self)
197 }
198 }
199 res
200 }
201
202 fn invert_or_zero(self) -> Self;
204
205 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
221
222 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
235
236 #[inline]
268 fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
269 assert!(log_block_len <= Self::LOG_WIDTH);
270 assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
271
272 unsafe {
274 self.spread_unchecked(log_block_len, block_idx)
275 }
276 }
277
278 #[inline]
283 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
284 let block_len = 1 << log_block_len;
285 let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
286
287 Self::from_scalars(
288 self.iter().skip(block_idx * block_len).take(block_len).flat_map(|elem| iter::repeat_n(elem, repeat))
289 )
290 }
291}
292
293#[inline]
298pub fn iter_packed_slice_with_offset<P: PackedField>(
299 packed: &[P],
300 offset: usize,
301) -> impl Iterator<Item = P::Scalar> + '_ + Send {
302 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
303 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
304 } else {
305 (&[], 0)
306 };
307
308 P::iter_slice(packed).skip(offset)
309}
310
311#[inline(always)]
312pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
313 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
314
315 unsafe { get_packed_slice_unchecked(packed, i) }
316}
317
318#[inline(always)]
322pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
323 if is_packed_field_indexable::<P>() {
324 unsafe { *(packed.as_ptr() as *const P::Scalar).add(i) }
328 } else {
329 unsafe {
334 packed
335 .get_unchecked(i >> P::LOG_WIDTH)
336 .get_unchecked(i % P::WIDTH)
337 }
338 }
339}
340
341#[inline]
342pub fn get_packed_slice_checked<P: PackedField>(
343 packed: &[P],
344 i: usize,
345) -> Result<P::Scalar, Error> {
346 if i >> P::LOG_WIDTH < packed.len() {
347 Ok(unsafe { get_packed_slice_unchecked(packed, i) })
349 } else {
350 Err(Error::IndexOutOfRange {
351 index: i,
352 max: len_packed_slice(packed),
353 })
354 }
355}
356
357#[inline]
361pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
362 packed: &mut [P],
363 i: usize,
364 scalar: P::Scalar,
365) {
366 if is_packed_field_indexable::<P>() {
367 unsafe {
371 *(packed.as_mut_ptr() as *mut P::Scalar).add(i) = scalar;
372 }
373 } else {
374 unsafe {
378 packed
379 .get_unchecked_mut(i >> P::LOG_WIDTH)
380 .set_unchecked(i % P::WIDTH, scalar)
381 }
382 }
383}
384
385#[inline]
386pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
387 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
388
389 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
390}
391
392#[inline]
393pub fn set_packed_slice_checked<P: PackedField>(
394 packed: &mut [P],
395 i: usize,
396 scalar: P::Scalar,
397) -> Result<(), Error> {
398 if i >> P::LOG_WIDTH < packed.len() {
399 unsafe { set_packed_slice_unchecked(packed, i, scalar) };
401 Ok(())
402 } else {
403 Err(Error::IndexOutOfRange {
404 index: i,
405 max: len_packed_slice(packed),
406 })
407 }
408}
409
410#[inline(always)]
411pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
412 packed.len() << P::LOG_WIDTH
413}
414
415#[inline]
419pub fn packed_from_fn_with_offset<P: PackedField>(
420 offset: usize,
421 mut f: impl FnMut(usize) -> P::Scalar,
422) -> P {
423 P::from_fn(|i| f(i + offset * P::WIDTH))
424}
425
426pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
428 use crate::underlier::UnderlierType;
429
430 let subfield_bits = FS::Underlier::BITS;
433 let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;
434
435 if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
436 P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
437 } else {
438 P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
439 }
440}
441
442pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
444 scalars
445 .chunks(P::WIDTH)
446 .map(|chunk| P::from_scalars(chunk.iter().copied()))
447 .collect()
448}
449
/// A read-only view over a slice of packed values that exposes the individual scalars as a
/// random-access sequence.
#[derive(Clone)]
pub struct PackedSlice<'a, P: PackedField> {
    /// The underlying packed storage.
    slice: &'a [P],
    /// Number of scalars exposed by the view; never exceeds the scalar capacity of `slice`
    /// when built through the constructors.
    len: usize,
}
456
457impl<'a, P: PackedField> PackedSlice<'a, P> {
458 #[inline(always)]
459 pub fn new(slice: &'a [P]) -> Self {
460 Self {
461 slice,
462 len: len_packed_slice(slice),
463 }
464 }
465
466 #[inline(always)]
467 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
468 assert!(len <= len_packed_slice(slice));
469
470 Self { slice, len }
471 }
472}
473
474impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
475 #[inline(always)]
476 fn len(&self) -> usize {
477 self.len
478 }
479
480 #[inline(always)]
481 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
482 unsafe { get_packed_slice_unchecked(self.slice, index) }
483 }
484}
485
/// A mutable view over a slice of packed values that exposes the individual scalars as a
/// random-access sequence.
pub struct PackedSliceMut<'a, P: PackedField> {
    /// The underlying packed storage.
    slice: &'a mut [P],
    /// Number of scalars exposed by the view; never exceeds the scalar capacity of `slice`
    /// when built through the constructors.
    len: usize,
}
491
492impl<'a, P: PackedField> PackedSliceMut<'a, P> {
493 #[inline(always)]
494 pub fn new(slice: &'a mut [P]) -> Self {
495 let len = len_packed_slice(slice);
496 Self { slice, len }
497 }
498
499 #[inline(always)]
500 pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
501 assert!(len <= len_packed_slice(slice));
502
503 Self { slice, len }
504 }
505}
506
507impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
508 #[inline(always)]
509 fn len(&self) -> usize {
510 self.len
511 }
512
513 #[inline(always)]
514 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
515 unsafe { get_packed_slice_unchecked(self.slice, index) }
516 }
517}
518impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
519 #[inline(always)]
520 unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
521 unsafe { set_packed_slice_unchecked(self.slice, index, value) }
522 }
523}
524
/// Every field is its own width-1 broadcast target.
impl<F: Field> Broadcast<F> for F {
    /// Returns `scalar` unchanged: with a single element there is nothing to replicate.
    #[inline]
    fn broadcast(scalar: F) -> Self {
        scalar
    }
}
531
532impl<T: TowerFieldArithmetic> MulAlpha for T {
533 #[inline]
534 fn mul_alpha(self) -> Self {
535 <Self as TowerFieldArithmetic>::multiply_alpha(self)
536 }
537}
538
539impl<F: Field> PackedField for F {
540 type Scalar = F;
541
542 const LOG_WIDTH: usize = 0;
543
544 #[inline]
545 unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
546 *self
547 }
548
549 #[inline]
550 unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
551 *self = scalar;
552 }
553
554 #[inline]
555 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
556 iter::once(*self)
557 }
558
559 #[inline]
560 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
561 iter::once(self)
562 }
563
564 #[inline]
565 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
566 slice.iter().copied()
567 }
568
569 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
570 panic!("cannot interleave when WIDTH = 1");
571 }
572
573 fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
574 panic!("cannot transpose when WIDTH = 1");
575 }
576
577 #[inline]
578 fn broadcast(scalar: Self::Scalar) -> Self {
579 scalar
580 }
581
582 #[inline]
583 fn zero() -> Self {
584 Self::ZERO
585 }
586
587 #[inline]
588 fn one() -> Self {
589 Self::ONE
590 }
591
592 #[inline]
593 fn square(self) -> Self {
594 <Self as Square>::square(self)
595 }
596
597 #[inline]
598 fn invert_or_zero(self) -> Self {
599 <Self as InvertOrZero>::invert_or_zero(self)
600 }
601
602 #[inline]
603 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
604 f(0)
605 }
606
607 #[inline]
608 unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
609 self
610 }
611}
612
/// Marker trait for packed fields whose scalar type is a [`BinaryField`].
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
615
/// Blanket implementation: every packed field over a binary-field scalar qualifies.
impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
617
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedBinaryGhash1x128b,
        PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
        arch::{
            packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*, packed_32::*,
            packed_64::*, packed_128::*, packed_256::*, packed_512::*, packed_aes_8::*,
            packed_aes_16::*, packed_aes_32::*, packed_aes_64::*, packed_aes_128::*,
            packed_aes_256::*, packed_aes_512::*,
        },
    };

    /// A test that can be instantiated for any packed field type.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` against the scalar and packed types of each field family listed below.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // 1-bit binary field
        test.run::<BinaryField1b>();
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField512x1b>();

        // 8-bit AES tower field
        test.run::<AESTowerField8b>();
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField64x8b>();

        // 128-bit GHASH field
        test.run::<BinaryField128bGhash>();
        test.run::<PackedBinaryGhash1x128b>();
        test.run::<PackedBinaryGhash2x128b>();
        test.run::<PackedBinaryGhash4x128b>();
    }

    /// Checks that the by-value iterator (`into_iter`) yields exactly the `WIDTH` elements
    /// reported by `get`, in index order.
    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks that the borrowing iterator (`iter`) yields exactly the `WIDTH` elements
    /// reported by `get`, in index order.
    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks `iter_packed_slice_with_offset` against element-wise `get_packed_slice` for
    /// several slice lengths and offsets, including the boundary offsets.
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    /// Asserts that `collection` has the same length and elements as `expected`, through both
    /// the checked and the unchecked accessor.
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    /// Writes a fresh random value at every index and reads it back through both accessors.
    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = random();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedAESBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedAESBinaryField16x8b] = &[
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut random = || AESTowerField8b::random(&mut rng);

        let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedAESBinaryField16x8b] = &mut [
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut random);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut random);
    }
}