use std::{
    fmt::Debug,
    iter::{self, Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::{
    iter::IterExtensions,
    random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
};
use bytemuck::Zeroable;
use rand::RngCore;

use super::{
    arithmetic_traits::{Broadcast, MulAlpha, Square},
    binary_field_arithmetic::TowerFieldArithmetic,
    Error,
};
use crate::{
    arithmetic_traits::InvertOrZero, is_packed_field_indexable, underlier::WithUnderlier,
    unpack_if_possible_mut, BinaryField, Field, PackedExtension,
};

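/// A packed field is a fixed-width vector of field elements supporting element-wise
/// arithmetic. Implementations typically map onto SIMD registers, so one operation on a
/// packed value performs `WIDTH` scalar operations at once.
///
/// A minimal usage sketch (illustrative, not part of this module; it assumes this crate
/// is used as `binius_field` and that `PackedBinaryField4x32b` packs four
/// `BinaryField32b` scalars):
///
/// ```ignore
/// use binius_field::{BinaryField32b, PackedBinaryField4x32b, PackedField};
///
/// // Broadcast a scalar to all four lanes, then overwrite lane 2.
/// let mut packed = PackedBinaryField4x32b::broadcast(BinaryField32b::new(3));
/// packed.set(2, BinaryField32b::new(5));
/// assert_eq!(packed.get(2), BinaryField32b::new(5));
///
/// // Arithmetic is element-wise.
/// let squared = packed * packed;
/// assert_eq!(squared.get(2), BinaryField32b::new(5) * BinaryField32b::new(5));
/// ```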
pub trait PackedField:
    Default
    + Debug
    + Clone
    + Copy
    + Eq
    + Sized
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + Add<Self::Scalar, Output = Self>
    + Sub<Self::Scalar, Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + AddAssign<Self::Scalar>
    + SubAssign<Self::Scalar>
    + MulAssign<Self::Scalar>
    + Sum
    + Product
    + Send
    + Sync
    + Zeroable
    + 'static
{
    /// The scalar field type packed into this vector.
    type Scalar: Field;

    /// Base-2 logarithm of the number of scalars packed into one value.
    const LOG_WIDTH: usize;

    /// The number of scalars packed into one value, i.e. `2^LOG_WIDTH`.
    const WIDTH: usize = 1 << Self::LOG_WIDTH;

    /// Returns the scalar at position `i`.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

    /// Sets the scalar at position `i`.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

    #[inline]
    fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
        (i < Self::WIDTH)
            .then_some(unsafe { self.get_unchecked(i) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    #[inline]
    fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
        (i < Self::WIDTH)
            .then(|| unsafe { self.set_unchecked(i, scalar) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Returns the scalar at position `i`; panics if `i >= Self::WIDTH`.
    #[inline]
    fn get(&self, i: usize) -> Self::Scalar {
        self.get_checked(i).expect("index must be less than width")
    }

    /// Sets the scalar at position `i`; panics if `i >= Self::WIDTH`.
    #[inline]
    fn set(&mut self, i: usize, scalar: Self::Scalar) {
        self.set_checked(i, scalar).expect("index must be less than width")
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().flat_map(Self::iter)
    }

    #[inline]
    fn zero() -> Self {
        Self::broadcast(Self::Scalar::ZERO)
    }

    #[inline]
    fn one() -> Self {
        Self::broadcast(Self::Scalar::ONE)
    }

    /// Returns a value with `scalar` in position 0 and the default (zero) value elsewhere.
    #[inline(always)]
    fn set_single(scalar: Self::Scalar) -> Self {
        let mut result = Self::default();
        result.set(0, scalar);

        result
    }

    fn random(rng: impl RngCore) -> Self;

    /// Returns a value with `scalar` in every position.
    fn broadcast(scalar: Self::Scalar) -> Self;

    /// Builds a value whose scalar at position `i` is `f(i)`.
    fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

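    /// Fallible analogue of [`Self::from_fn`]: initializes position `i` with `f(i)` and
    /// short-circuits on the first error. A sketch of the intended use (illustrative;
    /// `parse_scalar` is a hypothetical fallible constructor):
    ///
    /// ```ignore
    /// let packed = P::try_from_fn(|i| parse_scalar(&inputs[i]))?;
    /// ```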
    fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
        let mut result = Self::default();
        for i in 0..Self::WIDTH {
            let scalar = f(i)?;
            unsafe {
                result.set_unchecked(i, scalar);
            }
        }
        Ok(result)
    }

    /// Builds a value from the first `WIDTH` scalars of the iterator; any further values
    /// are ignored, and positions beyond the iterator's length keep the default (zero)
    /// value.
    fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
        let mut result = Self::default();
        for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
            result.set(i, val);
        }
        result
    }

    /// Squares every element.
    fn square(self) -> Self;

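    /// Raises every element to the power `exp` by binary (square-and-multiply)
    /// exponentiation, scanning the exponent bits from most to least significant;
    /// e.g. `x.pow(5)` computes `((x^2)^2) * x = x^5` element-wise.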
    fn pow(self, exp: u64) -> Self {
        let mut res = Self::one();
        for i in (0..64).rev() {
            res = res.square();
            if ((exp >> i) & 1) == 1 {
                res.mul_assign(self)
            }
        }
        res
    }

    /// Inverts every element, mapping zero to zero.
    fn invert_or_zero(self) -> Self;

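    /// Interleaves the blocks of `self` and `other`, where a block is
    /// `2^log_block_len` adjacent elements: the two vectors are viewed as rows of a
    /// matrix of 2x2 block tiles, and each tile is transposed.
    ///
    /// Sketch of the expected lane shuffle (illustrative, with `WIDTH = 4` and
    /// `log_block_len = 1`, writing packed values as lane lists):
    ///
    /// ```text
    /// a = [a0, a1, a2, a3]
    /// b = [b0, b1, b2, b3]
    /// a.interleave(b, 1) == ([a0, a1, b0, b1], [a2, a3, b2, b3])
    /// ```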
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

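    /// Counterpart of [`Self::interleave`]: viewing the concatenation of `self` and
    /// `other` as a sequence of blocks of `2^log_block_len` elements, gathers the
    /// even-indexed blocks into the first result and the odd-indexed blocks into the
    /// second.
    ///
    /// Sketch of the expected behavior (illustrative, with `WIDTH = 4` and
    /// `log_block_len = 0`):
    ///
    /// ```text
    /// a = [a0, a1, a2, a3]
    /// b = [b0, b1, b2, b3]
    /// a.unzip(b, 0) == ([a0, a2, b0, b2], [a1, a3, b1, b3])
    /// ```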
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);

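    /// Selects block `block_idx` of `2^log_block_len` elements and spreads it across the
    /// full width, repeating each element `WIDTH / 2^log_block_len` times (see the
    /// default implementation of [`Self::spread_unchecked`] below). For example, with
    /// `WIDTH = 4`:
    ///
    /// ```text
    /// [a0, a1, a2, a3].spread(1, 1) == [a2, a2, a3, a3]
    /// ```
    ///
    /// Panics if `log_block_len > LOG_WIDTH` or `block_idx` is out of range.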
    #[inline]
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert!(log_block_len <= Self::LOG_WIDTH);
        assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));

        unsafe { self.spread_unchecked(log_block_len, block_idx) }
    }

    /// Unchecked version of [`Self::spread`].
    ///
    /// # Safety
    /// The caller must ensure that `log_block_len <= Self::LOG_WIDTH` and
    /// `block_idx < 2^(Self::LOG_WIDTH - log_block_len)`.
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        let block_len = 1 << log_block_len;
        let repeat = 1 << (Self::LOG_WIDTH - log_block_len);

        Self::from_scalars(
            self.iter()
                .skip(block_idx * block_len)
                .take(block_len)
                .flat_map(|elem| iter::repeat_n(elem, repeat)),
        )
    }
}

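/// Iterates the scalars of a packed slice starting at the *scalar* index `offset`; an
/// offset at or past the end yields an empty iterator.
///
/// A usage sketch (illustrative; it assumes `PackedBinaryField4x32b` packs four scalars,
/// with `x` and `y` arbitrary `BinaryField32b` values):
///
/// ```ignore
/// let slice = [
///     PackedBinaryField4x32b::broadcast(x),
///     PackedBinaryField4x32b::broadcast(y),
/// ];
/// // Skips the four copies of `x` and the first copy of `y`.
/// let tail: Vec<_> = iter_packed_slice_with_offset(&slice, 5).collect();
/// assert_eq!(tail, vec![y, y, y]);
/// ```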
#[inline]
pub fn iter_packed_slice_with_offset<P: PackedField>(
    packed: &[P],
    offset: usize,
) -> impl Iterator<Item = P::Scalar> + '_ + Send {
    let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
        (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
    } else {
        (&[], 0)
    };

    P::iter_slice(packed).skip(offset)
}

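/// Returns the scalar at the flattened index `i` of a packed slice, i.e. lane
/// `i % P::WIDTH` of `packed[i / P::WIDTH]`; panics if `i` is out of bounds.
///
/// A usage sketch (illustrative; `PackedBinaryField16x8b` packs sixteen `BinaryField8b`
/// scalars):
///
/// ```ignore
/// let slice = vec![PackedBinaryField16x8b::zero(); 2];
/// // Scalar index 17 lives in lane 1 of the second packed element.
/// assert_eq!(get_packed_slice(&slice, 17), BinaryField8b::ZERO);
/// ```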
#[inline(always)]
pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");

    unsafe { get_packed_slice_unchecked(packed, i) }
}

/// Returns the scalar at the flattened index `i` without bounds checking.
///
/// # Safety
/// The caller must ensure that `i` is less than `P::WIDTH * packed.len()`.
#[inline(always)]
pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    if is_packed_field_indexable::<P>() {
        // SAFETY: `is_packed_field_indexable` guarantees that the packed values are laid
        // out in memory as a contiguous array of scalars.
        unsafe { *(packed.as_ptr() as *const P::Scalar).add(i) }
    } else {
        unsafe {
            packed
                .get_unchecked(i >> P::LOG_WIDTH)
                .get_unchecked(i % P::WIDTH)
        }
    }
}

#[inline]
pub fn get_packed_slice_checked<P: PackedField>(
    packed: &[P],
    i: usize,
) -> Result<P::Scalar, Error> {
    if i >> P::LOG_WIDTH < packed.len() {
        Ok(unsafe { get_packed_slice_unchecked(packed, i) })
    } else {
        Err(Error::IndexOutOfRange {
            index: i,
            max: len_packed_slice(packed),
        })
    }
}

/// Sets the scalar at the flattened index `i` without bounds checking.
///
/// # Safety
/// The caller must ensure that `i` is less than `P::WIDTH * packed.len()`.
#[inline]
pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) {
    if is_packed_field_indexable::<P>() {
        // SAFETY: `is_packed_field_indexable` guarantees that the packed values are laid
        // out in memory as a contiguous array of scalars.
        unsafe {
            *(packed.as_mut_ptr() as *mut P::Scalar).add(i) = scalar;
        }
    } else {
        unsafe {
            packed
                .get_unchecked_mut(i >> P::LOG_WIDTH)
                .set_unchecked(i % P::WIDTH, scalar)
        }
    }
}

#[inline]
pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
    assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");

    unsafe { set_packed_slice_unchecked(packed, i, scalar) }
}

#[inline]
pub fn set_packed_slice_checked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) -> Result<(), Error> {
    if i >> P::LOG_WIDTH < packed.len() {
        unsafe { set_packed_slice_unchecked(packed, i, scalar) };
        Ok(())
    } else {
        Err(Error::IndexOutOfRange {
            index: i,
            max: len_packed_slice(packed),
        })
    }
}

/// Returns the total number of scalars stored in a packed slice.
#[inline(always)]
pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
    packed.len() << P::LOG_WIDTH
}

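/// Builds the `offset`-th packed chunk of the virtual scalar sequence defined by `f`,
/// i.e. a packed value whose lane `i` is `f(offset * P::WIDTH + i)`.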
#[inline]
pub fn packed_from_fn_with_offset<P: PackedField>(
    offset: usize,
    mut f: impl FnMut(usize) -> P::Scalar,
) -> P {
    P::from_fn(|i| f(i + offset * P::WIDTH))
}

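/// Multiplies every scalar of a packed extension field element by a scalar of the
/// subfield `FS`.
///
/// A usage sketch (illustrative; it assumes the tower relation
/// `PackedBinaryField4x32b: PackedExtension<BinaryField8b>` provided by this crate):
///
/// ```ignore
/// let scaled = mul_by_subfield_scalar(val, BinaryField8b::new(7));
/// assert_eq!(scaled.get(0), val.get(0) * BinaryField8b::new(7));
/// ```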
pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
    use crate::underlier::UnderlierType;

    // Heuristic strategy choice: for small extension fields, broadcasting the scalar and
    // multiplying in the packed subfield is cheaper; otherwise fall back to
    // scalar-by-scalar multiplication.
    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
        P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
    } else {
        P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
    }
}

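/// Packs a slice of scalars into a newly allocated `Vec<P>`, chunking by `P::WIDTH`; if
/// the input length is not a multiple of `P::WIDTH`, the unused trailing lanes of the
/// last packed element keep the default (zero) value.
///
/// ```ignore
/// // Illustrative: six scalars packed four to an element yield two packed values.
/// let packed: Vec<PackedBinaryField4x32b> = pack_slice(&scalars[..6]);
/// assert_eq!(packed.len(), 2);
/// ```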
pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
    scalars
        .chunks(P::WIDTH)
        .map(|chunk| P::from_scalars(chunk.iter().copied()))
        .collect()
}

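/// Copies `src` into the leading scalar lanes of `dst`, using a direct scalar copy when
/// the packed layout permits it and a chunked repack otherwise; `dst` must hold at least
/// `src.len()` scalars.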
pub fn copy_packed_from_scalars_slice<P: PackedField>(src: &[P::Scalar], dst: &mut [P]) {
    unpack_if_possible_mut(
        dst,
        |scalars| {
            scalars[0..src.len()].copy_from_slice(src);
        },
        |packed| {
            let chunks = src.chunks_exact(P::WIDTH);
            let remainder = chunks.remainder();
            for (chunk, packed) in chunks.zip(packed.iter_mut()) {
                *packed = P::from_scalars(chunk.iter().copied());
            }

            if !remainder.is_empty() {
                // The remainder goes into the packed element right after the full
                // chunks, at packed index `src.len() / P::WIDTH`.
                let packed = &mut packed[src.len() >> P::LOG_WIDTH];
                for (i, scalar) in remainder.iter().enumerate() {
                    unsafe { packed.set_unchecked(i, *scalar) };
                }
            }
        },
    );
}

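/// A read-only view of a packed slice as a [`RandomAccessSequence`] of scalars,
/// optionally truncated to a logical length.
///
/// A usage sketch (illustrative):
///
/// ```ignore
/// let view = PackedSlice::new_with_len(&packed, 10);
/// assert_eq!(view.len(), 10);
/// let first = view.get(0);
/// ```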
#[derive(Clone)]
pub struct PackedSlice<'a, P: PackedField> {
    slice: &'a [P],
    len: usize,
}

impl<'a, P: PackedField> PackedSlice<'a, P> {
    #[inline(always)]
    pub fn new(slice: &'a [P]) -> Self {
        Self {
            slice,
            len: len_packed_slice(slice),
        }
    }

    #[inline(always)]
    pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
        assert!(len <= len_packed_slice(slice));

        Self { slice, len }
    }
}

impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
        get_packed_slice_unchecked(self.slice, index)
    }
}

/// Mutable counterpart of [`PackedSlice`].
pub struct PackedSliceMut<'a, P: PackedField> {
    slice: &'a mut [P],
    len: usize,
}

impl<'a, P: PackedField> PackedSliceMut<'a, P> {
    #[inline(always)]
    pub fn new(slice: &'a mut [P]) -> Self {
        let len = len_packed_slice(slice);
        Self { slice, len }
    }

    #[inline(always)]
    pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
        assert!(len <= len_packed_slice(slice));

        Self { slice, len }
    }
}

impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
        get_packed_slice_unchecked(self.slice, index)
    }
}

impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
    #[inline(always)]
    unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
        set_packed_slice_unchecked(self.slice, index, value);
    }
}

impl<F: Field> Broadcast<F> for F {
    fn broadcast(scalar: F) -> Self {
        scalar
    }
}

impl<T: TowerFieldArithmetic> MulAlpha for T {
    #[inline]
    fn mul_alpha(self) -> Self {
        <Self as TowerFieldArithmetic>::multiply_alpha(self)
    }
}

/// Every scalar field is trivially a packed field of width 1.
impl<F: Field> PackedField for F {
    type Scalar = F;

    const LOG_WIDTH: usize = 0;

    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        *self
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        *self = scalar;
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        iter::once(*self)
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        iter::once(self)
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().copied()
    }

    fn random(rng: impl RngCore) -> Self {
        <Self as Field>::random(rng)
    }

    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot unzip when WIDTH = 1");
    }

    fn broadcast(scalar: Self::Scalar) -> Self {
        scalar
    }

    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }

    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        f(0)
    }

    #[inline]
    unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
        self
    }
}

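/// Marker trait for packed fields whose scalars are binary fields; blanket-implemented
/// below for every qualifying [`PackedField`].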
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{
        distributions::{Distribution, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::{
        arch::{
            byte_sliced::*, packed_1::*, packed_128::*, packed_16::*, packed_2::*, packed_256::*,
            packed_32::*, packed_4::*, packed_512::*, packed_64::*, packed_8::*, packed_aes_128::*,
            packed_aes_16::*, packed_aes_256::*, packed_aes_32::*, packed_aes_512::*,
            packed_aes_64::*, packed_aes_8::*, packed_polyval_128::*, packed_polyval_256::*,
            packed_polyval_512::*,
        },
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
        BinaryField128b, BinaryField128bPolyval, BinaryField16b, BinaryField1b, BinaryField2b,
        BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, PackedField,
    };

    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        test.run::<BinaryField1b>();
        test.run::<BinaryField2b>();
        test.run::<BinaryField4b>();
        test.run::<BinaryField8b>();
        test.run::<BinaryField16b>();
        test.run::<BinaryField32b>();
        test.run::<BinaryField64b>();
        test.run::<BinaryField128b>();

        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField1x2b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField2x2b>();
        test.run::<PackedBinaryField1x4b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField4x2b>();
        test.run::<PackedBinaryField2x4b>();
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField8x2b>();
        test.run::<PackedBinaryField4x4b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField16x2b>();
        test.run::<PackedBinaryField8x4b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField32x2b>();
        test.run::<PackedBinaryField16x4b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField64x2b>();
        test.run::<PackedBinaryField32x4b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField128x2b>();
        test.run::<PackedBinaryField64x4b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField512x1b>();
        test.run::<PackedBinaryField256x2b>();
        test.run::<PackedBinaryField128x4b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        test.run::<AESTowerField8b>();
        test.run::<AESTowerField16b>();
        test.run::<AESTowerField32b>();
        test.run::<AESTowerField64b>();
        test.run::<AESTowerField128b>();

        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField1x16b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField2x16b>();
        test.run::<PackedAESBinaryField1x32b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField4x16b>();
        test.run::<PackedAESBinaryField2x32b>();
        test.run::<PackedAESBinaryField1x64b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField8x16b>();
        test.run::<PackedAESBinaryField4x32b>();
        test.run::<PackedAESBinaryField2x64b>();
        test.run::<PackedAESBinaryField1x128b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField16x16b>();
        test.run::<PackedAESBinaryField8x32b>();
        test.run::<PackedAESBinaryField4x64b>();
        test.run::<PackedAESBinaryField2x128b>();
        test.run::<PackedAESBinaryField64x8b>();
        test.run::<PackedAESBinaryField32x16b>();
        test.run::<PackedAESBinaryField16x32b>();
        test.run::<PackedAESBinaryField8x64b>();
        test.run::<PackedAESBinaryField4x128b>();

        test.run::<ByteSlicedAES16x128b>();
        test.run::<ByteSlicedAES16x64b>();
        test.run::<ByteSlicedAES2x16x64b>();
        test.run::<ByteSlicedAES16x32b>();
        test.run::<ByteSlicedAES4x16x32b>();
        test.run::<ByteSlicedAES16x16b>();
        test.run::<ByteSlicedAES8x16x16b>();
        test.run::<ByteSlicedAES16x8b>();
        test.run::<ByteSlicedAES16x16x8b>();

        test.run::<ByteSliced16x128x1b>();
        test.run::<ByteSliced8x128x1b>();
        test.run::<ByteSliced4x128x1b>();
        test.run::<ByteSliced2x128x1b>();
        test.run::<ByteSliced1x128x1b>();

        test.run::<ByteSlicedAES32x128b>();
        test.run::<ByteSlicedAES32x64b>();
        test.run::<ByteSlicedAES2x32x64b>();
        test.run::<ByteSlicedAES32x32b>();
        test.run::<ByteSlicedAES4x32x32b>();
        test.run::<ByteSlicedAES32x16b>();
        test.run::<ByteSlicedAES8x32x16b>();
        test.run::<ByteSlicedAES32x8b>();
        test.run::<ByteSlicedAES16x32x8b>();

        test.run::<ByteSliced16x256x1b>();
        test.run::<ByteSliced8x256x1b>();
        test.run::<ByteSliced4x256x1b>();
        test.run::<ByteSliced2x256x1b>();
        test.run::<ByteSliced1x256x1b>();

        test.run::<ByteSlicedAES64x128b>();
        test.run::<ByteSlicedAES64x64b>();
        test.run::<ByteSlicedAES2x64x64b>();
        test.run::<ByteSlicedAES64x32b>();
        test.run::<ByteSlicedAES4x64x32b>();
        test.run::<ByteSlicedAES64x16b>();
        test.run::<ByteSlicedAES8x64x16b>();
        test.run::<ByteSlicedAES64x8b>();
        test.run::<ByteSlicedAES16x64x8b>();

        test.run::<ByteSliced16x512x1b>();
        test.run::<ByteSliced8x512x1b>();
        test.run::<ByteSliced4x512x1b>();
        test.run::<ByteSliced2x512x1b>();
        test.run::<ByteSliced1x512x1b>();

        test.run::<BinaryField128bPolyval>();

        test.run::<PackedBinaryPolyval1x128b>();
        test.run::<PackedBinaryPolyval2x128b>();
        test.run::<PackedBinaryPolyval4x128b>();
    }

    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                Uniform::new(0, elements_count.max(1)).sample(&mut rng),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    fn check_copy_from_scalars<P: PackedField>(mut rng: impl RngCore) {
        let scalars = (0..100)
            .map(|_| <<P as PackedField>::Scalar as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        let mut packed_copy = vec![P::zero(); 100];

        for len in [0, 2, 4, 8, 12, 16] {
            copy_packed_from_scalars_slice(&scalars[0..len], &mut packed_copy);

            for (i, &scalar) in scalars[0..len].iter().enumerate() {
                assert_eq!(get_packed_slice(&packed_copy, i), scalar);
            }
            for i in len..100 {
                assert_eq!(get_packed_slice(&packed_copy, i), P::Scalar::ZERO);
            }
        }
    }

    #[test]
    fn test_copy_from_scalars() {
        let mut rng = StdRng::seed_from_u64(0);

        check_copy_from_scalars::<PackedBinaryField16x8b>(&mut rng);
        check_copy_from_scalars::<PackedBinaryField32x4b>(&mut rng);
    }

    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        gen: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = gen();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedBinaryField16x8b] = &[
            PackedBinaryField16x8b::random(&mut rng),
            PackedBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut gen = || <BinaryField8b as Field>::random(&mut rng);

        let slice: &mut [PackedBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedBinaryField16x8b] = &mut [
            PackedBinaryField16x8b::random(&mut rng),
            PackedBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut gen);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut gen);
    }
}