use std::{
    fmt::Debug,
    iter::{self, Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::{
    iter::IterExtensions,
    random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
};
use bytemuck::Zeroable;
use rand::RngCore;

use super::{
    Error,
    arithmetic_traits::{Broadcast, MulAlpha, Square},
    binary_field_arithmetic::TowerFieldArithmetic,
};
use crate::{
    BinaryField, Field, PackedExtension, arithmetic_traits::InvertOrZero,
    is_packed_field_indexable, underlier::WithUnderlier, unpack_if_possible_mut,
};

/// A packed field is a fixed-width vector of scalar field elements supporting element-wise
/// field arithmetic. Implementations may accelerate these operations with SIMD instructions.
pub trait PackedField:
    Default
    + Debug
    + Clone
    + Copy
    + Eq
    + Sized
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + Add<Self::Scalar, Output = Self>
    + Sub<Self::Scalar, Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + AddAssign<Self::Scalar>
    + SubAssign<Self::Scalar>
    + MulAssign<Self::Scalar>
    + Sum
    + Product
    + Send
    + Sync
    + Zeroable
    + 'static
{
    /// The scalar field element type packed into this vector.
    type Scalar: Field;

    /// Base-2 logarithm of the number of scalars packed into one value.
    const LOG_WIDTH: usize;

    /// The number of scalars packed into one value.
    const WIDTH: usize = 1 << Self::LOG_WIDTH;

    /// Returns the scalar at position `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

    /// Sets the scalar at position `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

    /// Returns the scalar at position `i`, or an error if `i` is out of range.
    #[inline]
    fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
        (i < Self::WIDTH)
            .then(|| unsafe { self.get_unchecked(i) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Sets the scalar at position `i`, or returns an error if `i` is out of range.
    #[inline]
    fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
        (i < Self::WIDTH)
            .then(|| unsafe { self.set_unchecked(i, scalar) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Returns the scalar at position `i`.
    ///
    /// # Panics
    /// Panics if `i >= Self::WIDTH`.
    #[inline]
    fn get(&self, i: usize) -> Self::Scalar {
        self.get_checked(i).expect("index must be less than width")
    }

    /// Sets the scalar at position `i`.
    ///
    /// # Panics
    /// Panics if `i >= Self::WIDTH`.
    #[inline]
    fn set(&mut self, i: usize, scalar: Self::Scalar) {
        self.set_checked(i, scalar).expect("index must be less than width")
    }

    /// Returns an iterator over the packed scalars, consuming `self`.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    /// Returns an iterator over the packed scalars by reference.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    /// Returns an iterator over all scalars in a slice of packed values.
    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().flat_map(Self::iter)
    }

    /// Returns the packed value with all positions set to zero.
    #[inline]
    fn zero() -> Self {
        Self::broadcast(Self::Scalar::ZERO)
    }

    /// Returns the packed value with all positions set to one.
    #[inline]
    fn one() -> Self {
        Self::broadcast(Self::Scalar::ONE)
    }

    /// Returns a packed value with `scalar` at position 0 and the remaining positions taken
    /// from `Self::default()`.
    #[inline(always)]
    fn set_single(scalar: Self::Scalar) -> Self {
        let mut result = Self::default();
        result.set(0, scalar);

        result
    }

    /// Returns a random packed value.
    fn random(rng: impl RngCore) -> Self;

    /// Returns a packed value with `scalar` broadcast to every position.
    fn broadcast(scalar: Self::Scalar) -> Self;

    /// Constructs a packed value by evaluating `f` at each index in `0..Self::WIDTH`.
    fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

    /// Fallible analogue of [`Self::from_fn`]: constructs a packed value by evaluating `f`
    /// at each index in `0..Self::WIDTH`, propagating the first error.
    fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
        let mut result = Self::default();
        for i in 0..Self::WIDTH {
            let scalar = f(i)?;
            unsafe {
                result.set_unchecked(i, scalar);
            }
        }
        Ok(result)
    }

    /// Constructs a packed value from an iterator of scalars. Excess values are ignored, and
    /// missing positions are taken from `Self::default()`.
    fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
        let mut result = Self::default();
        for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
            result.set(i, val);
        }
        result
    }

    /// Squares each packed scalar.
    fn square(self) -> Self;

    /// Raises each packed scalar to the power `exp` using left-to-right square-and-multiply.
    fn pow(self, exp: u64) -> Self {
        let mut res = Self::one();
        for i in (0..64).rev() {
            res = res.square();
            if ((exp >> i) & 1) == 1 {
                res.mul_assign(self)
            }
        }
        res
    }

    /// Returns the multiplicative inverse of each scalar, or zero where the scalar is zero.
    fn invert_or_zero(self) -> Self;

    /// Interleaves blocks of this packed vector with another packed vector.
    ///
    /// The operation stacks the two vectors, divides them into 2x2 matrices of blocks of
    /// `2^log_block_len` scalars each, and transposes each matrix. For example, with
    /// `LOG_WIDTH = 2` and `log_block_len = 0`:
    ///
    ///     A = [a0, a1, a2, a3]
    ///     B = [b0, b1, b2, b3]
    ///
    ///     A' = [a0, b0, a2, b2]
    ///     B' = [a1, b1, a3, b3]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

    /// Unzips interleaved blocks of this packed vector with another packed vector: even
    /// blocks of the concatenation go to the first result, odd blocks to the second. For
    /// example, with `LOG_WIDTH = 2` and `log_block_len = 0`:
    ///
    ///     A = [a0, b0, a1, b1]
    ///     B = [a2, b2, a3, b3]
    ///
    ///     A' = [a0, a1, a2, a3]
    ///     B' = [b0, b1, b2, b3]
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);

    /// Spreads a block of scalars across the full width: selects block `block_idx` of
    /// `2^log_block_len` scalars and repeats each of its scalars
    /// `WIDTH / 2^log_block_len` times.
    ///
    /// # Panics
    /// Panics if `log_block_len > Self::LOG_WIDTH` or
    /// `block_idx >= 2^(Self::LOG_WIDTH - log_block_len)`.
    #[inline]
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert!(log_block_len <= Self::LOG_WIDTH);
        assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));

        unsafe { self.spread_unchecked(log_block_len, block_idx) }
    }

    /// Unchecked version of [`Self::spread`].
    ///
    /// # Safety
    /// The caller must ensure that `log_block_len <= Self::LOG_WIDTH` and
    /// `block_idx < 2^(Self::LOG_WIDTH - log_block_len)`.
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        let block_len = 1 << log_block_len;
        let repeat = 1 << (Self::LOG_WIDTH - log_block_len);

        Self::from_scalars(
            self.iter()
                .skip(block_idx * block_len)
                .take(block_len)
                .flat_map(|elem| iter::repeat_n(elem, repeat)),
        )
    }
}

293#[inline]
298pub fn iter_packed_slice_with_offset<P: PackedField>(
299 packed: &[P],
300 offset: usize,
301) -> impl Iterator<Item = P::Scalar> + '_ + Send {
302 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
303 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
304 } else {
305 (&[], 0)
306 };
307
308 P::iter_slice(packed).skip(offset)
309}
310
311#[inline(always)]
312pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
313 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
314
315 unsafe { get_packed_slice_unchecked(packed, i) }
316}
317
318#[inline(always)]
322pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
323 if is_packed_field_indexable::<P>() {
324 unsafe { *(packed.as_ptr() as *const P::Scalar).add(i) }
328 } else {
329 unsafe {
334 packed
335 .get_unchecked(i >> P::LOG_WIDTH)
336 .get_unchecked(i % P::WIDTH)
337 }
338 }
339}
340
341#[inline]
342pub fn get_packed_slice_checked<P: PackedField>(
343 packed: &[P],
344 i: usize,
345) -> Result<P::Scalar, Error> {
346 if i >> P::LOG_WIDTH < packed.len() {
347 Ok(unsafe { get_packed_slice_unchecked(packed, i) })
349 } else {
350 Err(Error::IndexOutOfRange {
351 index: i,
352 max: len_packed_slice(packed),
353 })
354 }
355}
356
357#[inline]
361pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
362 packed: &mut [P],
363 i: usize,
364 scalar: P::Scalar,
365) {
366 if is_packed_field_indexable::<P>() {
367 unsafe {
371 *(packed.as_mut_ptr() as *mut P::Scalar).add(i) = scalar;
372 }
373 } else {
374 unsafe {
378 packed
379 .get_unchecked_mut(i >> P::LOG_WIDTH)
380 .set_unchecked(i % P::WIDTH, scalar)
381 }
382 }
383}
384
385#[inline]
386pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
387 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
388
389 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
390}
391
392#[inline]
393pub fn set_packed_slice_checked<P: PackedField>(
394 packed: &mut [P],
395 i: usize,
396 scalar: P::Scalar,
397) -> Result<(), Error> {
398 if i >> P::LOG_WIDTH < packed.len() {
399 unsafe { set_packed_slice_unchecked(packed, i, scalar) };
401 Ok(())
402 } else {
403 Err(Error::IndexOutOfRange {
404 index: i,
405 max: len_packed_slice(packed),
406 })
407 }
408}
409
410#[inline(always)]
411pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
412 packed.len() << P::LOG_WIDTH
413}
414
415#[inline]
419pub fn packed_from_fn_with_offset<P: PackedField>(
420 offset: usize,
421 mut f: impl FnMut(usize) -> P::Scalar,
422) -> P {
423 P::from_fn(|i| f(i + offset * P::WIDTH))
424}
425
/// Multiplies a packed extension field value by a scalar from a subfield.
///
/// Depending on the underlier widths, this either multiplies scalar by scalar or casts to
/// the packed subfield and multiplies by a broadcast value.
pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
    use crate::underlier::UnderlierType;

    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
        P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
    } else {
        P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
    }
}

/// Packs a slice of scalars into a vector of packed values. If the number of scalars is not
/// a multiple of `P::WIDTH`, the trailing positions of the last packed value are defaults.
pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
    scalars
        .chunks(P::WIDTH)
        .map(|chunk| P::from_scalars(chunk.iter().copied()))
        .collect()
}

/// Copies the scalars of `src` into the leading positions of the packed slice `dst`.
///
/// # Panics
/// Panics if `dst` cannot hold `src.len()` scalars.
pub fn copy_packed_from_scalars_slice<P: PackedField>(src: &[P::Scalar], dst: &mut [P]) {
    unpack_if_possible_mut(
        dst,
        |scalars| {
            scalars[0..src.len()].copy_from_slice(src);
        },
        |packed| {
            let chunks = src.chunks_exact(P::WIDTH);
            let remainder = chunks.remainder();
            for (chunk, packed) in chunks.zip(packed.iter_mut()) {
                *packed = P::from_scalars(chunk.iter().copied());
            }

            if !remainder.is_empty() {
                // The partial chunk lands in the packed element at index
                // `src.len() / P::WIDTH`, not at the scalar offset.
                let packed = &mut packed[src.len() >> P::LOG_WIDTH];
                for (i, scalar) in remainder.iter().enumerate() {
                    unsafe { packed.set_unchecked(i, *scalar) };
                }
            }
        },
    );
}

476#[derive(Clone)]
478pub struct PackedSlice<'a, P: PackedField> {
479 slice: &'a [P],
480 len: usize,
481}
482
483impl<'a, P: PackedField> PackedSlice<'a, P> {
484 #[inline(always)]
485 pub fn new(slice: &'a [P]) -> Self {
486 Self {
487 slice,
488 len: len_packed_slice(slice),
489 }
490 }
491
492 #[inline(always)]
493 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
494 assert!(len <= len_packed_slice(slice));
495
496 Self { slice, len }
497 }
498}
499
impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
        unsafe { get_packed_slice_unchecked(self.slice, index) }
    }
}

/// A mutable view of a slice of packed values as a sequence of scalars, with an optional
/// logical length.
pub struct PackedSliceMut<'a, P: PackedField> {
    slice: &'a mut [P],
    len: usize,
}

impl<'a, P: PackedField> PackedSliceMut<'a, P> {
    #[inline(always)]
    pub fn new(slice: &'a mut [P]) -> Self {
        let len = len_packed_slice(slice);
        Self { slice, len }
    }

    #[inline(always)]
    pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
        assert!(len <= len_packed_slice(slice));

        Self { slice, len }
    }
}

impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
        unsafe { get_packed_slice_unchecked(self.slice, index) }
    }
}

impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
    #[inline(always)]
    unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
        unsafe { set_packed_slice_unchecked(self.slice, index, value) }
    }
}

impl<F: Field> Broadcast<F> for F {
    fn broadcast(scalar: F) -> Self {
        scalar
    }
}

impl<T: TowerFieldArithmetic> MulAlpha for T {
    #[inline]
    fn mul_alpha(self) -> Self {
        <Self as TowerFieldArithmetic>::multiply_alpha(self)
    }
}

/// Any scalar field is a packed field of width 1.
impl<F: Field> PackedField for F {
    type Scalar = F;

    const LOG_WIDTH: usize = 0;

    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        *self
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        *self = scalar;
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        iter::once(*self)
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        iter::once(self)
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().copied()
    }

    fn random(rng: impl RngCore) -> Self {
        <Self as Field>::random(rng)
    }

    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot unzip when WIDTH = 1");
    }

    fn broadcast(scalar: Self::Scalar) -> Self {
        scalar
    }

    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }

    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        f(0)
    }

    #[inline]
    unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
        self
    }
}

/// A [`PackedField`] whose scalars are binary field elements.
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}

634#[cfg(test)]
635mod tests {
636 use itertools::Itertools;
637 use rand::{
638 SeedableRng,
639 distributions::{Distribution, Uniform},
640 rngs::StdRng,
641 };
642
643 use super::*;
644 use crate::{
645 AESTowerField8b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField128b,
646 BinaryField1b, BinaryField2b, BinaryField4b, BinaryField8b, BinaryField16b, BinaryField32b,
647 BinaryField64b, BinaryField128b, BinaryField128bPolyval, PackedField,
648 arch::{
649 byte_sliced::*, packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*,
650 packed_32::*, packed_64::*, packed_128::*, packed_256::*, packed_512::*,
651 packed_aes_8::*, packed_aes_16::*, packed_aes_32::*, packed_aes_64::*,
652 packed_aes_128::*, packed_aes_256::*, packed_aes_512::*, packed_polyval_128::*,
653 packed_polyval_256::*, packed_polyval_512::*,
654 },
655 };
656
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        test.run::<BinaryField1b>();
        test.run::<BinaryField2b>();
        test.run::<BinaryField4b>();
        test.run::<BinaryField8b>();
        test.run::<BinaryField16b>();
        test.run::<BinaryField32b>();
        test.run::<BinaryField64b>();
        test.run::<BinaryField128b>();

        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField1x2b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField2x2b>();
        test.run::<PackedBinaryField1x4b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField4x2b>();
        test.run::<PackedBinaryField2x4b>();
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField8x2b>();
        test.run::<PackedBinaryField4x4b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField16x2b>();
        test.run::<PackedBinaryField8x4b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField32x2b>();
        test.run::<PackedBinaryField16x4b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField64x2b>();
        test.run::<PackedBinaryField32x4b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField128x2b>();
        test.run::<PackedBinaryField64x4b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField512x1b>();
        test.run::<PackedBinaryField256x2b>();
        test.run::<PackedBinaryField128x4b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        test.run::<AESTowerField8b>();
        test.run::<AESTowerField16b>();
        test.run::<AESTowerField32b>();
        test.run::<AESTowerField64b>();
        test.run::<AESTowerField128b>();

        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField1x16b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField2x16b>();
        test.run::<PackedAESBinaryField1x32b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField4x16b>();
        test.run::<PackedAESBinaryField2x32b>();
        test.run::<PackedAESBinaryField1x64b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField8x16b>();
        test.run::<PackedAESBinaryField4x32b>();
        test.run::<PackedAESBinaryField2x64b>();
        test.run::<PackedAESBinaryField1x128b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField16x16b>();
        test.run::<PackedAESBinaryField8x32b>();
        test.run::<PackedAESBinaryField4x64b>();
        test.run::<PackedAESBinaryField2x128b>();
        test.run::<PackedAESBinaryField64x8b>();
        test.run::<PackedAESBinaryField32x16b>();
        test.run::<PackedAESBinaryField16x32b>();
        test.run::<PackedAESBinaryField8x64b>();
        test.run::<PackedAESBinaryField4x128b>();

        test.run::<ByteSlicedAES16x128b>();
        test.run::<ByteSlicedAES16x64b>();
        test.run::<ByteSlicedAES2x16x64b>();
        test.run::<ByteSlicedAES16x32b>();
        test.run::<ByteSlicedAES4x16x32b>();
        test.run::<ByteSlicedAES16x16b>();
        test.run::<ByteSlicedAES8x16x16b>();
        test.run::<ByteSlicedAES16x8b>();
        test.run::<ByteSlicedAES16x16x8b>();

        test.run::<ByteSliced16x128x1b>();
        test.run::<ByteSliced8x128x1b>();
        test.run::<ByteSliced4x128x1b>();
        test.run::<ByteSliced2x128x1b>();
        test.run::<ByteSliced1x128x1b>();

        test.run::<ByteSlicedAES32x128b>();
        test.run::<ByteSlicedAES32x64b>();
        test.run::<ByteSlicedAES2x32x64b>();
        test.run::<ByteSlicedAES32x32b>();
        test.run::<ByteSlicedAES4x32x32b>();
        test.run::<ByteSlicedAES32x16b>();
        test.run::<ByteSlicedAES8x32x16b>();
        test.run::<ByteSlicedAES32x8b>();
        test.run::<ByteSlicedAES16x32x8b>();

        test.run::<ByteSliced16x256x1b>();
        test.run::<ByteSliced8x256x1b>();
        test.run::<ByteSliced4x256x1b>();
        test.run::<ByteSliced2x256x1b>();
        test.run::<ByteSliced1x256x1b>();

        test.run::<ByteSlicedAES64x128b>();
        test.run::<ByteSlicedAES64x64b>();
        test.run::<ByteSlicedAES2x64x64b>();
        test.run::<ByteSlicedAES64x32b>();
        test.run::<ByteSlicedAES4x64x32b>();
        test.run::<ByteSlicedAES64x16b>();
        test.run::<ByteSlicedAES8x64x16b>();
        test.run::<ByteSlicedAES64x8b>();
        test.run::<ByteSlicedAES16x64x8b>();

        test.run::<ByteSliced16x512x1b>();
        test.run::<ByteSliced8x512x1b>();
        test.run::<ByteSliced4x512x1b>();
        test.run::<ByteSliced2x512x1b>();
        test.run::<ByteSliced1x512x1b>();

        test.run::<BinaryField128bPolyval>();

        test.run::<PackedBinaryPolyval1x128b>();
        test.run::<PackedBinaryPolyval2x128b>();
        test.run::<PackedBinaryPolyval4x128b>();
    }

    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

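    // A minimal sanity check (added example) of the checked accessors: in-range indices
    // succeed and out-of-range indices report an error.
    #[test]
    fn test_checked_accessors_out_of_range() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut packed = PackedBinaryField16x8b::random(&mut rng);

        let scalar = <BinaryField8b as Field>::random(&mut rng);
        assert!(packed.set_checked(0, scalar).is_ok());
        assert_eq!(packed.get(0), scalar);

        assert!(packed.get_checked(PackedBinaryField16x8b::WIDTH).is_err());
        assert!(packed.set_checked(PackedBinaryField16x8b::WIDTH, scalar).is_err());
    }
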
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                Uniform::new(0, elements_count.max(1)).sample(&mut rng),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

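    // A property check (added example): `interleave` acts as a transpose of 2x2 block
    // matrices, so applying it twice with the same `log_block_len` should return the
    // original pair. This assumes the transpose semantics described on `interleave`.
    fn check_interleave_involution<P: PackedField>(mut rng: impl RngCore) {
        let a = P::random(&mut rng);
        let b = P::random(&mut rng);
        for log_block_len in 0..P::LOG_WIDTH {
            let (c, d) = a.interleave(b, log_block_len);
            assert_eq!(c.interleave(d, log_block_len), (a, b));
        }
    }

    #[test]
    fn test_interleave_involution() {
        let mut rng = StdRng::seed_from_u64(0);
        check_interleave_involution::<PackedBinaryField16x8b>(&mut rng);
        check_interleave_involution::<PackedBinaryField4x32b>(&mut rng);
    }
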
    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

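    // A sanity check (added example) that `spread` agrees with the reference behavior of
    // the default `spread_unchecked`: select one block and repeat each of its scalars to
    // fill the full width.
    fn check_spread<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        for log_block_len in 0..=P::LOG_WIDTH {
            let block_len = 1 << log_block_len;
            for block_idx in 0..P::WIDTH / block_len {
                let expected = P::from_scalars(
                    packed
                        .iter()
                        .skip(block_idx * block_len)
                        .take(block_len)
                        .flat_map(|elem| std::iter::repeat_n(elem, P::WIDTH / block_len)),
                );
                assert_eq!(packed.spread(log_block_len, block_idx), expected);
            }
        }
    }

    #[test]
    fn test_spread_matches_reference() {
        let mut rng = StdRng::seed_from_u64(0);
        check_spread::<PackedBinaryField16x8b>(&mut rng);
        check_spread::<PackedBinaryField4x32b>(&mut rng);
    }
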
    fn check_copy_from_scalars<P: PackedField>(mut rng: impl RngCore) {
        let scalars = (0..100)
            .map(|_| <<P as PackedField>::Scalar as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        let mut packed_copy = vec![P::zero(); 100];

        for len in [0, 2, 4, 8, 12, 16] {
            copy_packed_from_scalars_slice(&scalars[0..len], &mut packed_copy);

            for (i, &scalar) in scalars[0..len].iter().enumerate() {
                assert_eq!(get_packed_slice(&packed_copy, i), scalar);
            }
            for i in len..100 {
                assert_eq!(get_packed_slice(&packed_copy, i), P::Scalar::ZERO);
            }
        }
    }

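    // A sanity check (added example) for `mul_by_subfield_scalar`: both of its strategies
    // must agree with multiplying each scalar individually. Assumes, as elsewhere in this
    // crate, that `PackedBinaryField4x32b: PackedExtension<BinaryField8b>`.
    #[test]
    fn test_mul_by_subfield_scalar_matches_scalar_mul() {
        let mut rng = StdRng::seed_from_u64(0);
        let val = PackedBinaryField4x32b::random(&mut rng);
        let multiplier = <BinaryField8b as Field>::random(&mut rng);

        let expected = PackedBinaryField4x32b::from_fn(|i| val.get(i) * multiplier);
        assert_eq!(mul_by_subfield_scalar(val, multiplier), expected);
    }
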
    #[test]
    fn test_copy_from_scalars() {
        let mut rng = StdRng::seed_from_u64(0);

        check_copy_from_scalars::<PackedBinaryField16x8b>(&mut rng);
        check_copy_from_scalars::<PackedBinaryField32x4b>(&mut rng);
    }

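    // A sanity check (added example) that the square-and-multiply `pow` agrees with
    // repeated multiplication for small exponents.
    fn check_pow<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut expected = P::one();
        for exp in 0u64..8 {
            assert_eq!(packed.pow(exp), expected);
            expected *= packed;
        }
    }

    #[test]
    fn test_pow_matches_repeated_mul() {
        let mut rng = StdRng::seed_from_u64(0);
        check_pow::<PackedBinaryField16x8b>(&mut rng);
        check_pow::<PackedAESBinaryField2x64b>(&mut rng);
    }
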
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        r#gen: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = r#gen();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedBinaryField16x8b] = &[
            PackedBinaryField16x8b::random(&mut rng),
            PackedBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut r#gen = || <BinaryField8b as Field>::random(&mut rng);

        let slice: &mut [PackedBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedBinaryField16x8b] = &mut [
            PackedBinaryField16x8b::random(&mut rng),
            PackedBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut r#gen);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut r#gen);
    }
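
    // A round-trip check (added example) for the packed-slice helpers: pack scalars with
    // `pack_slice`, read them back with `get_packed_slice`, then overwrite them with
    // `set_packed_slice`.
    #[test]
    fn test_pack_slice_roundtrip() {
        let mut rng = StdRng::seed_from_u64(0);
        let scalars = (0..3 * PackedBinaryField16x8b::WIDTH)
            .map(|_| <BinaryField8b as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        let mut packed = pack_slice::<PackedBinaryField16x8b>(&scalars);
        assert_eq!(len_packed_slice(&packed), scalars.len());
        for (i, &scalar) in scalars.iter().enumerate() {
            assert_eq!(get_packed_slice(&packed, i), scalar);
        }

        for i in 0..scalars.len() {
            let value = <BinaryField8b as Field>::random(&mut rng);
            set_packed_slice(&mut packed, i, value);
            assert_eq!(get_packed_slice(&packed, i), value);
        }
    }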
}