use std::{
    fmt::Debug,
    iter::{self, Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::{
    iter::IterExtensions,
    random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
};
use bytemuck::Zeroable;
use rand::RngCore;

use super::{
    Error,
    arithmetic_traits::{Broadcast, MulAlpha, Square},
    binary_field_arithmetic::TowerFieldArithmetic,
};
use crate::{
    BinaryField, Field, PackedExtension, arithmetic_traits::InvertOrZero,
    is_packed_field_indexable, underlier::WithUnderlier, unpack_if_possible_mut,
};

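/// A packed field is a fixed-width vector of scalar field elements that backends can
/// operate on as a single value, e.g. with SIMD instructions.
///
/// A minimal usage sketch (illustrative; `PackedBinaryField4x32b` is one of the packed
/// types exported by this crate):
///
/// ```
/// use binius_field::{BinaryField32b, Field, PackedBinaryField4x32b, PackedField};
///
/// // Broadcast one scalar into every position, then overwrite a single position.
/// let mut p = PackedBinaryField4x32b::broadcast(BinaryField32b::ONE);
/// p.set(2, BinaryField32b::ZERO);
/// assert_eq!(p.get(2), BinaryField32b::ZERO);
/// assert_eq!(p.iter().filter(|&x| x == BinaryField32b::ONE).count(), 3);
/// ```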
pub trait PackedField:
    Default
    + Debug
    + Clone
    + Copy
    + Eq
    + Sized
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + Add<Self::Scalar, Output = Self>
    + Sub<Self::Scalar, Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + AddAssign<Self::Scalar>
    + SubAssign<Self::Scalar>
    + MulAssign<Self::Scalar>
    + Sum
    + Product
    + Send
    + Sync
    + Zeroable
    + 'static
{
    type Scalar: Field;

    /// Base-2 logarithm of the number of scalar values packed into one value.
    const LOG_WIDTH: usize;

    /// The number of scalar values packed into one value, i.e. `2^LOG_WIDTH`.
    const WIDTH: usize = 1 << Self::LOG_WIDTH;

    /// Returns the scalar at position `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

    /// Sets the scalar at position `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

    /// Returns the scalar at position `i`, or an error if `i` is out of range.
    #[inline]
    fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
        (i < Self::WIDTH)
            .then_some(unsafe { self.get_unchecked(i) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Sets the scalar at position `i`, or returns an error if `i` is out of range.
    #[inline]
    fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
        (i < Self::WIDTH)
            .then(|| unsafe { self.set_unchecked(i, scalar) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Returns the scalar at position `i`. Panics if `i` is out of range.
    #[inline]
    fn get(&self, i: usize) -> Self::Scalar {
        self.get_checked(i).expect("index must be less than width")
    }

    /// Sets the scalar at position `i`. Panics if `i` is out of range.
    #[inline]
    fn set(&mut self, i: usize, scalar: Self::Scalar) {
        self.set_checked(i, scalar).expect("index must be less than width")
    }

    /// Returns an iterator over the scalars, consuming the packed value.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    /// Returns an iterator over the scalars by reference.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    /// Returns an iterator over all scalars in a slice of packed values.
    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().flat_map(Self::iter)
    }

    #[inline]
    fn zero() -> Self {
        Self::broadcast(Self::Scalar::ZERO)
    }

    #[inline]
    fn one() -> Self {
        Self::broadcast(Self::Scalar::ONE)
    }

    /// Initializes a value where the scalar at index 0 is `scalar` and the remaining
    /// positions keep the default (zero) value.
    #[inline(always)]
    fn set_single(scalar: Self::Scalar) -> Self {
        let mut result = Self::default();
        result.set(0, scalar);

        result
    }

    /// Returns a packed value with all scalars sampled at random.
    fn random(rng: impl RngCore) -> Self;

    /// Returns a packed value with `scalar` in every position.
    fn broadcast(scalar: Self::Scalar) -> Self;

    /// Constructs a packed value by evaluating `f` at each index in `0..WIDTH`.
    fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

    /// Fallible version of [`Self::from_fn`]: stops at the first error returned by `f`.
    fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
        let mut result = Self::default();
        for i in 0..Self::WIDTH {
            let scalar = f(i)?;
            unsafe {
                result.set_unchecked(i, scalar);
            };
        }
        Ok(result)
    }

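    /// Constructs a packed value from an iterator of scalars. Values beyond `WIDTH` are
    /// ignored; if the iterator yields fewer than `WIDTH` values, the remaining positions
    /// keep the default (zero) value.
    ///
    /// An illustrative doc-test:
    ///
    /// ```
    /// use binius_field::{BinaryField32b, Field, PackedBinaryField4x32b, PackedField};
    ///
    /// let p = PackedBinaryField4x32b::from_scalars([BinaryField32b::ONE; 2]);
    /// assert_eq!(p.get(0), BinaryField32b::ONE);
    /// assert_eq!(p.get(3), BinaryField32b::ZERO); // unfilled positions stay zero
    /// ```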
    fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
        let mut result = Self::default();
        for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
            result.set(i, val);
        }
        result
    }

    /// Returns the value multiplied by itself.
    fn square(self) -> Self;

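    /// Exponentiation by square-and-multiply: scans the bits of `exp` from most to least
    /// significant, squaring the accumulator at each step and multiplying in `self` when
    /// the bit is set.
    ///
    /// A small illustrative doc-test:
    ///
    /// ```
    /// use binius_field::{PackedBinaryField4x32b, PackedField};
    /// use rand::{SeedableRng, rngs::StdRng};
    ///
    /// let mut rng = StdRng::seed_from_u64(0);
    /// let x = PackedBinaryField4x32b::random(&mut rng);
    /// assert_eq!(x.pow(3), x * x * x);
    /// ```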
    fn pow(self, exp: u64) -> Self {
        let mut res = Self::one();
        for i in (0..64).rev() {
            res = res.square();
            if ((exp >> i) & 1) == 1 {
                res.mul_assign(self)
            }
        }
        res
    }

    /// Returns the packed inverses: each scalar is replaced by its multiplicative
    /// inverse, with zero mapped to zero.
    fn invert_or_zero(self) -> Self;

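    /// Interleaves blocks of this packed vector with another packed vector.
    ///
    /// The operation can be seen as stacking the two vectors, dividing the stack into
    /// 2x2 matrices of blocks of `2^log_block_len` scalars each, and transposing those
    /// matrices. For example, with `LOG_WIDTH = 3` and `log_block_len = 1`:
    ///
    /// ```text
    /// A = [a0, a1, a2, a3, a4, a5, a6, a7]
    /// B = [b0, b1, b2, b3, b4, b5, b6, b7]
    ///
    /// interleaved result:
    /// A' = [a0, a1, b0, b1, a4, a5, b4, b5]
    /// B' = [a2, a3, b2, b3, a6, a7, b6, b7]
    /// ```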
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

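    /// Unzips interleaved blocks of this packed vector with another packed vector:
    /// even-indexed blocks of `2^log_block_len` scalars are gathered into the first
    /// returned value and odd-indexed blocks into the second. For example, with
    /// `LOG_WIDTH = 3` and `log_block_len = 1`:
    ///
    /// ```text
    /// A = [a0, a1, b0, b1, a2, a3, b2, b3]
    /// B = [a4, a5, b4, b5, a6, a7, b6, b7]
    ///
    /// unzipped result:
    /// A' = [a0, a1, a2, a3, a4, a5, a6, a7]
    /// B' = [b0, b1, b2, b3, b4, b5, b6, b7]
    /// ```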
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);

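    /// Spreads the block of `2^log_block_len` scalars at index `block_idx` across the
    /// full width, repeating each scalar `WIDTH >> log_block_len` times. For example,
    /// with `LOG_WIDTH = 2` and `log_block_len = 1`:
    ///
    /// ```text
    /// [s0, s1, s2, s3].spread(1, 1) == [s2, s2, s3, s3]
    /// ```
    ///
    /// Panics if `log_block_len > LOG_WIDTH` or `block_idx >= 2^(LOG_WIDTH - log_block_len)`.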
    #[inline]
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert!(log_block_len <= Self::LOG_WIDTH);
        assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));

        // SAFETY: the bounds are checked by the asserts above.
        unsafe { self.spread_unchecked(log_block_len, block_idx) }
    }

    /// Unchecked version of [`Self::spread`].
    ///
    /// # Safety
    /// `log_block_len` must not exceed `LOG_WIDTH`, and `block_idx` must be less than
    /// `2^(LOG_WIDTH - log_block_len)`.
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        let block_len = 1 << log_block_len;
        let repeat = 1 << (Self::LOG_WIDTH - log_block_len);

        Self::from_scalars(
            self.iter()
                .skip(block_idx * block_len)
                .take(block_len)
                .flat_map(|elem| iter::repeat_n(elem, repeat)),
        )
    }
}

/// Iterates over the scalars in a packed slice starting from the given scalar offset.
///
/// Offsets at or past the end of the slice yield an empty iterator.
#[inline]
pub fn iter_packed_slice_with_offset<P: PackedField>(
    packed: &[P],
    offset: usize,
) -> impl Iterator<Item = P::Scalar> + '_ + Send {
    let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
        (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
    } else {
        (&[], 0)
    };

    P::iter_slice(packed).skip(offset)
}

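/// Returns the scalar at index `i` in a slice of packed values, indexing across packed
/// element boundaries. Panics if `i` is out of range.
///
/// An illustrative doc-test:
///
/// ```
/// use binius_field::{BinaryField8b, Field, PackedBinaryField16x8b, PackedField};
/// use binius_field::packed::{get_packed_slice, len_packed_slice};
///
/// let packed = vec![PackedBinaryField16x8b::one(); 2];
/// assert_eq!(len_packed_slice(&packed), 32);
/// assert_eq!(get_packed_slice(&packed, 20), BinaryField8b::ONE);
/// ```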
311#[inline(always)]
312pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
313 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
314
315 unsafe { get_packed_slice_unchecked(packed, i) }
316}
317
318#[inline(always)]
322pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
323 if is_packed_field_indexable::<P>() {
324 unsafe { *(packed.as_ptr() as *const P::Scalar).add(i) }
328 } else {
329 unsafe {
334 packed
335 .get_unchecked(i >> P::LOG_WIDTH)
336 .get_unchecked(i % P::WIDTH)
337 }
338 }
339}
340
341#[inline]
342pub fn get_packed_slice_checked<P: PackedField>(
343 packed: &[P],
344 i: usize,
345) -> Result<P::Scalar, Error> {
346 if i >> P::LOG_WIDTH < packed.len() {
347 Ok(unsafe { get_packed_slice_unchecked(packed, i) })
349 } else {
350 Err(Error::IndexOutOfRange {
351 index: i,
352 max: len_packed_slice(packed),
353 })
354 }
355}
356
357#[inline]
361pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
362 packed: &mut [P],
363 i: usize,
364 scalar: P::Scalar,
365) {
366 if is_packed_field_indexable::<P>() {
367 unsafe {
371 *(packed.as_mut_ptr() as *mut P::Scalar).add(i) = scalar;
372 }
373 } else {
374 unsafe {
378 packed
379 .get_unchecked_mut(i >> P::LOG_WIDTH)
380 .set_unchecked(i % P::WIDTH, scalar)
381 }
382 }
383}
384
385#[inline]
386pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
387 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
388
389 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
390}
391
392#[inline]
393pub fn set_packed_slice_checked<P: PackedField>(
394 packed: &mut [P],
395 i: usize,
396 scalar: P::Scalar,
397) -> Result<(), Error> {
398 if i >> P::LOG_WIDTH < packed.len() {
399 unsafe { set_packed_slice_unchecked(packed, i, scalar) };
401 Ok(())
402 } else {
403 Err(Error::IndexOutOfRange {
404 index: i,
405 max: len_packed_slice(packed),
406 })
407 }
408}
409
410#[inline(always)]
411pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
412 packed.len() << P::LOG_WIDTH
413}
414
415#[inline]
419pub fn packed_from_fn_with_offset<P: PackedField>(
420 offset: usize,
421 mut f: impl FnMut(usize) -> P::Scalar,
422) -> P {
423 P::from_fn(|i| f(i + offset * P::WIDTH))
424}
425
/// Multiplies a packed extension-field value by a scalar from the subfield `FS`.
///
/// Depending on the underlier sizes, this either multiplies each scalar individually or
/// casts to the packed subfield and multiplies by a broadcast of `multiplier`.
pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
    use crate::underlier::UnderlierType;

    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
        // Per-scalar multiplication path.
        P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
    } else {
        // Packed path: cast to the packed subfield and multiply by a broadcast multiplier.
        P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
    }
}

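/// Packs a slice of scalars into a vector of packed values. If the number of scalars is
/// not a multiple of `P::WIDTH`, the trailing positions of the last packed value keep
/// the default (zero) value.
///
/// An illustrative doc-test:
///
/// ```
/// use binius_field::{BinaryField8b, Field, PackedBinaryField16x8b, PackedField};
/// use binius_field::packed::pack_slice;
///
/// let scalars = vec![BinaryField8b::ONE; 20];
/// let packed = pack_slice::<PackedBinaryField16x8b>(&scalars);
/// assert_eq!(packed.len(), 2); // 16 + 4 scalars -> two packed values
/// assert_eq!(packed[1].get(3), BinaryField8b::ONE);
/// assert_eq!(packed[1].get(4), BinaryField8b::ZERO);
/// ```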
pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
    scalars
        .chunks(P::WIDTH)
        .map(|chunk| P::from_scalars(chunk.iter().copied()))
        .collect()
}

/// Copies the scalars from `src` into the beginning of `dst`, which must hold at least
/// `src.len()` scalars.
pub fn copy_packed_from_scalars_slice<P: PackedField>(src: &[P::Scalar], dst: &mut [P]) {
    unpack_if_possible_mut(
        dst,
        |scalars| {
            scalars[0..src.len()].copy_from_slice(src);
        },
        |packed| {
            let chunks = src.chunks_exact(P::WIDTH);
            let remainder = chunks.remainder();
            for (chunk, packed) in chunks.zip(packed.iter_mut()) {
                *packed = P::from_scalars(chunk.iter().copied());
            }

            if !remainder.is_empty() {
                // The remainder goes into the packed element right after the full chunks.
                let packed = &mut packed[src.len() >> P::LOG_WIDTH];
                for (i, scalar) in remainder.iter().enumerate() {
                    unsafe { packed.set_unchecked(i, *scalar) };
                }
            }
        },
    );
}

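/// A read-only view of a packed slice as a [`RandomAccessSequence`] of scalars,
/// optionally truncated to a shorter length.
///
/// An illustrative doc-test:
///
/// ```
/// use binius_field::{BinaryField8b, Field, PackedBinaryField16x8b, PackedField};
/// use binius_field::packed::PackedSlice;
/// use binius_utils::random_access_sequence::RandomAccessSequence;
///
/// let packed = [PackedBinaryField16x8b::one()];
/// let view = PackedSlice::new(&packed);
/// assert_eq!(view.len(), 16);
/// assert_eq!(view.get(7), BinaryField8b::ONE);
/// ```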
476#[derive(Clone)]
478pub struct PackedSlice<'a, P: PackedField> {
479 slice: &'a [P],
480 len: usize,
481}
482
483impl<'a, P: PackedField> PackedSlice<'a, P> {
484 #[inline(always)]
485 pub fn new(slice: &'a [P]) -> Self {
486 Self {
487 slice,
488 len: len_packed_slice(slice),
489 }
490 }
491
492 #[inline(always)]
493 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
494 assert!(len <= len_packed_slice(slice));
495
496 Self { slice, len }
497 }
498}
499
impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
        unsafe { get_packed_slice_unchecked(self.slice, index) }
    }
}

/// A mutable view of a packed slice as a sequence of scalars, optionally truncated to a
/// given length.
pub struct PackedSliceMut<'a, P: PackedField> {
    slice: &'a mut [P],
    len: usize,
}

impl<'a, P: PackedField> PackedSliceMut<'a, P> {
    #[inline(always)]
    pub fn new(slice: &'a mut [P]) -> Self {
        let len = len_packed_slice(slice);
        Self { slice, len }
    }

    #[inline(always)]
    pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
        assert!(len <= len_packed_slice(slice));

        Self { slice, len }
    }
}

impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
        unsafe { get_packed_slice_unchecked(self.slice, index) }
    }
}

impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
    #[inline(always)]
    unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
        unsafe { set_packed_slice_unchecked(self.slice, index, value) }
    }
}

impl<F: Field> Broadcast<F> for F {
    fn broadcast(scalar: F) -> Self {
        scalar
    }
}

impl<T: TowerFieldArithmetic> MulAlpha for T {
    #[inline]
    fn mul_alpha(self) -> Self {
        <Self as TowerFieldArithmetic>::multiply_alpha(self)
    }
}

impl<F: Field> PackedField for F {
    type Scalar = F;

    const LOG_WIDTH: usize = 0;

    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        *self
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        *self = scalar;
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        iter::once(*self)
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        iter::once(self)
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().copied()
    }

    fn random(rng: impl RngCore) -> Self {
        <Self as Field>::random(rng)
    }

    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot unzip when WIDTH = 1");
    }

    fn broadcast(scalar: Self::Scalar) -> Self {
        scalar
    }

    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }

    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        f(0)
    }

    #[inline]
    unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
        self
    }
}

/// A [`PackedField`] whose scalars are binary field elements.
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}

634#[cfg(test)]
635mod tests {
636 use itertools::Itertools;
637 use rand::{Rng, SeedableRng, rngs::StdRng};
638
639 use super::*;
640 use crate::{
641 AESTowerField8b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField128b,
642 BinaryField1b, BinaryField2b, BinaryField4b, BinaryField8b, BinaryField16b, BinaryField32b,
643 BinaryField64b, BinaryField128b, BinaryField128bPolyval, PackedField,
644 arch::{
645 byte_sliced::*, packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*,
646 packed_32::*, packed_64::*, packed_128::*, packed_256::*, packed_512::*,
647 packed_aes_8::*, packed_aes_16::*, packed_aes_32::*, packed_aes_64::*,
648 packed_aes_128::*, packed_aes_256::*, packed_aes_512::*, packed_polyval_128::*,
649 packed_polyval_256::*, packed_polyval_512::*,
650 },
651 };
652
653 trait PackedFieldTest {
654 fn run<P: PackedField>(&self);
655 }
656
    /// Runs the given test for every packed field type covered by this suite.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // Scalar binary tower fields
        test.run::<BinaryField1b>();
        test.run::<BinaryField2b>();
        test.run::<BinaryField4b>();
        test.run::<BinaryField8b>();
        test.run::<BinaryField16b>();
        test.run::<BinaryField32b>();
        test.run::<BinaryField64b>();
        test.run::<BinaryField128b>();

        // Packed binary tower fields
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField1x2b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField2x2b>();
        test.run::<PackedBinaryField1x4b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField4x2b>();
        test.run::<PackedBinaryField2x4b>();
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField8x2b>();
        test.run::<PackedBinaryField4x4b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField16x2b>();
        test.run::<PackedBinaryField8x4b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField32x2b>();
        test.run::<PackedBinaryField16x4b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField64x2b>();
        test.run::<PackedBinaryField32x4b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField128x2b>();
        test.run::<PackedBinaryField64x4b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField512x1b>();
        test.run::<PackedBinaryField256x2b>();
        test.run::<PackedBinaryField128x4b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        // Scalar AES tower fields
        test.run::<AESTowerField8b>();
        test.run::<AESTowerField16b>();
        test.run::<AESTowerField32b>();
        test.run::<AESTowerField64b>();
        test.run::<AESTowerField128b>();

        // Packed AES tower fields
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField1x16b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField2x16b>();
        test.run::<PackedAESBinaryField1x32b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField4x16b>();
        test.run::<PackedAESBinaryField2x32b>();
        test.run::<PackedAESBinaryField1x64b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField8x16b>();
        test.run::<PackedAESBinaryField4x32b>();
        test.run::<PackedAESBinaryField2x64b>();
        test.run::<PackedAESBinaryField1x128b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField16x16b>();
        test.run::<PackedAESBinaryField8x32b>();
        test.run::<PackedAESBinaryField4x64b>();
        test.run::<PackedAESBinaryField2x128b>();
        test.run::<PackedAESBinaryField64x8b>();
        test.run::<PackedAESBinaryField32x16b>();
        test.run::<PackedAESBinaryField16x32b>();
        test.run::<PackedAESBinaryField8x64b>();
        test.run::<PackedAESBinaryField4x128b>();

        // Byte-sliced fields
        test.run::<ByteSlicedAES16x128b>();
        test.run::<ByteSlicedAES16x64b>();
        test.run::<ByteSlicedAES2x16x64b>();
        test.run::<ByteSlicedAES16x32b>();
        test.run::<ByteSlicedAES4x16x32b>();
        test.run::<ByteSlicedAES16x16b>();
        test.run::<ByteSlicedAES8x16x16b>();
        test.run::<ByteSlicedAES16x8b>();
        test.run::<ByteSlicedAES16x16x8b>();

        test.run::<ByteSliced16x128x1b>();
        test.run::<ByteSliced8x128x1b>();
        test.run::<ByteSliced4x128x1b>();
        test.run::<ByteSliced2x128x1b>();
        test.run::<ByteSliced1x128x1b>();

        test.run::<ByteSlicedAES32x128b>();
        test.run::<ByteSlicedAES32x64b>();
        test.run::<ByteSlicedAES2x32x64b>();
        test.run::<ByteSlicedAES32x32b>();
        test.run::<ByteSlicedAES4x32x32b>();
        test.run::<ByteSlicedAES32x16b>();
        test.run::<ByteSlicedAES8x32x16b>();
        test.run::<ByteSlicedAES32x8b>();
        test.run::<ByteSlicedAES16x32x8b>();

        test.run::<ByteSliced16x256x1b>();
        test.run::<ByteSliced8x256x1b>();
        test.run::<ByteSliced4x256x1b>();
        test.run::<ByteSliced2x256x1b>();
        test.run::<ByteSliced1x256x1b>();

        test.run::<ByteSlicedAES64x128b>();
        test.run::<ByteSlicedAES64x64b>();
        test.run::<ByteSlicedAES2x64x64b>();
        test.run::<ByteSlicedAES64x32b>();
        test.run::<ByteSlicedAES4x64x32b>();
        test.run::<ByteSlicedAES64x16b>();
        test.run::<ByteSlicedAES8x64x16b>();
        test.run::<ByteSlicedAES64x8b>();
        test.run::<ByteSlicedAES16x64x8b>();

        test.run::<ByteSliced16x512x1b>();
        test.run::<ByteSliced8x512x1b>();
        test.run::<ByteSliced4x512x1b>();
        test.run::<ByteSliced2x512x1b>();
        test.run::<ByteSliced1x512x1b>();

        // POLYVAL fields
        test.run::<BinaryField128bPolyval>();

        test.run::<PackedBinaryPolyval1x128b>();
        test.run::<PackedBinaryPolyval2x128b>();
        test.run::<PackedBinaryPolyval4x128b>();
    }

    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        // Iteration by value consumes the packed field.
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        // Iteration by reference borrows the packed field.
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

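    // A minimal sanity check for `spread`: each scalar of the selected block should be
    // repeated `WIDTH >> log_block_len` times, matching the default `spread_unchecked`
    // implementation above. A sketch, exercised on two representative packed types.
    fn check_spread<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        for log_block_len in 0..=P::LOG_WIDTH {
            let block_len = 1 << log_block_len;
            let repeat = P::WIDTH >> log_block_len;
            for block_idx in 0..P::WIDTH / block_len {
                let spread = packed.spread(log_block_len, block_idx);
                for i in 0..P::WIDTH {
                    assert_eq!(spread.get(i), packed.get(block_idx * block_len + i / repeat));
                }
            }
        }
    }

    #[test]
    fn test_spread() {
        let mut rng = StdRng::seed_from_u64(0);

        check_spread::<PackedBinaryField16x8b>(&mut rng);
        check_spread::<PackedBinaryField32x4b>(&mut rng);
    }
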
    fn check_copy_from_scalars<P: PackedField>(mut rng: impl RngCore) {
        let scalars = (0..100)
            .map(|_| <<P as PackedField>::Scalar as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        let mut packed_copy = vec![P::zero(); 100];

        for len in [0, 2, 4, 8, 12, 16] {
            copy_packed_from_scalars_slice(&scalars[0..len], &mut packed_copy);

            for (i, &scalar) in scalars[0..len].iter().enumerate() {
                assert_eq!(get_packed_slice(&packed_copy, i), scalar);
            }
            for i in len..100 {
                assert_eq!(get_packed_slice(&packed_copy, i), P::Scalar::ZERO);
            }
        }
    }

    #[test]
    fn test_copy_from_scalars() {
        let mut rng = StdRng::seed_from_u64(0);

        check_copy_from_scalars::<PackedBinaryField16x8b>(&mut rng);
        check_copy_from_scalars::<PackedBinaryField32x4b>(&mut rng);
    }

    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = random();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedBinaryField16x8b] = &[
            PackedBinaryField16x8b::random(&mut rng),
            PackedBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut random = || <BinaryField8b as Field>::random(&mut rng);

        let slice: &mut [PackedBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedBinaryField16x8b] = &mut [
            PackedBinaryField16x8b::random(&mut rng),
            PackedBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut random);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut random);
    }
}