1use std::{
8 fmt::Debug,
9 iter::{self, Product, Sum},
10 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
11};
12
13use binius_utils::iter::IterExtensions;
14use bytemuck::Zeroable;
15use rand::RngCore;
16
17use super::{
18 arithmetic_traits::{Broadcast, MulAlpha, Square},
19 binary_field_arithmetic::TowerFieldArithmetic,
20 Error,
21};
22use crate::{
23 arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, Field, PackedExtension,
24};
25
/// A packed field is a fixed-size vector of `WIDTH` scalar field elements that
/// supports elementwise arithmetic, both between two packed values and between
/// a packed value and a single scalar (see the operator supertraits).
pub trait PackedField:
    Default
    + Debug
    + Clone
    + Copy
    + Eq
    + Sized
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + Add<Self::Scalar, Output = Self>
    + Sub<Self::Scalar, Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + AddAssign<Self::Scalar>
    + SubAssign<Self::Scalar>
    + MulAssign<Self::Scalar>
    + Sum
    + Product
    + Send
    + Sync
    + Zeroable
    + 'static
{
    /// The type of the scalar elements packed into this value.
    type Scalar: Field;

    /// Base-2 logarithm of the number of packed scalars.
    const LOG_WIDTH: usize;

    /// The number of packed scalars; always a power of two.
    const WIDTH: usize = 1 << Self::LOG_WIDTH;

    /// Returns the scalar at index `i` without bounds checking.
    ///
    /// # Safety
    /// `i` must be less than `Self::WIDTH`.
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

    /// Sets the scalar at index `i` without bounds checking.
    ///
    /// # Safety
    /// `i` must be less than `Self::WIDTH`.
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

    /// Returns the scalar at index `i`, or an error if `i >= Self::WIDTH`.
    #[inline]
    fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
        (i < Self::WIDTH)
            // SAFETY: the preceding bounds check guarantees `i < Self::WIDTH`.
            .then_some(unsafe { self.get_unchecked(i) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Sets the scalar at index `i`, or returns an error if `i >= Self::WIDTH`.
    #[inline]
    fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
        (i < Self::WIDTH)
            // SAFETY: the preceding bounds check guarantees `i < Self::WIDTH`.
            .then(|| unsafe { self.set_unchecked(i, scalar) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

    /// Returns the scalar at index `i`.
    ///
    /// # Panics
    /// Panics if `i >= Self::WIDTH`.
    #[inline]
    fn get(&self, i: usize) -> Self::Scalar {
        self.get_checked(i).expect("index must be less than width")
    }

    /// Sets the scalar at index `i`.
    ///
    /// # Panics
    /// Panics if `i >= Self::WIDTH`.
    #[inline]
    fn set(&mut self, i: usize, scalar: Self::Scalar) {
        self.set_checked(i, scalar).expect("index must be less than width")
    }

    /// Returns a by-value iterator over the `WIDTH` scalars, in index order.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item=Self::Scalar> + Send + Clone {
        // SAFETY: `i` only ranges over `0..Self::WIDTH`.
        (0..Self::WIDTH).map_skippable(move |i|
            unsafe { self.get_unchecked(i) })
    }

    /// Returns an iterator over the `WIDTH` scalars, in index order.
    #[inline]
    fn iter(&self) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
        // SAFETY: `i` only ranges over `0..Self::WIDTH`.
        (0..Self::WIDTH).map_skippable(move |i|
            unsafe { self.get_unchecked(i) })
    }

    /// Returns an iterator over all scalars of all packed values in `slice`,
    /// element by element in index order.
    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
        slice.iter().flat_map(Self::iter)
    }

    /// Returns a packed value with every scalar equal to zero.
    #[inline]
    fn zero() -> Self {
        Self::broadcast(Self::Scalar::ZERO)
    }

    /// Returns a packed value with every scalar equal to one.
    #[inline]
    fn one() -> Self {
        Self::broadcast(Self::Scalar::ONE)
    }

    /// Returns a packed value with `scalar` at index 0; the remaining
    /// positions hold whatever `Self::default()` provides.
    #[inline(always)]
    fn set_single(scalar: Self::Scalar) -> Self {
        let mut result = Self::default();
        result.set(0, scalar);

        result
    }

    /// Returns a uniformly random packed value drawn from `rng`.
    fn random(rng: impl RngCore) -> Self;
    /// Returns a packed value with every position set to `scalar`.
    fn broadcast(scalar: Self::Scalar) -> Self;

    /// Builds a packed value whose scalar at index `i` is `f(i)`.
    fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

    /// Fallible version of [`Self::from_fn`]; short-circuits on the first
    /// error returned by `f`.
    fn try_from_fn<E>(
        mut f: impl FnMut(usize) -> Result<Self::Scalar, E>,
    ) -> Result<Self, E> {
        let mut result = Self::default();
        for i in 0..Self::WIDTH {
            let scalar = f(i)?;
            // SAFETY: `i` is always less than `Self::WIDTH` inside this loop.
            unsafe {
                result.set_unchecked(i, scalar);
            };
        }
        Ok(result)
    }

    /// Builds a packed value from at most `WIDTH` scalars, in index order.
    /// Extra values are ignored; positions without a value keep whatever
    /// `Self::default()` provides.
    fn from_scalars(values: impl IntoIterator<Item=Self::Scalar>) -> Self {
        let mut result = Self::default();
        for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
            result.set(i, val);
        }
        result
    }

    /// Squares every packed scalar elementwise.
    fn square(self) -> Self;

    /// Raises every packed scalar to the power `exp`, elementwise, using
    /// square-and-multiply over all 64 bits of the exponent.
    fn pow(self, exp: u64) -> Self {
        let mut res = Self::one();
        for i in (0..64).rev() {
            res = res.square();
            if ((exp >> i) & 1) == 1 {
                res.mul_assign(self)
            }
        }
        res
    }

    /// Inverts every packed scalar elementwise; the scalar impl delegates to
    /// `InvertOrZero`, which by its name presumably maps zero to zero —
    /// confirm per implementation.
    fn invert_or_zero(self) -> Self;

    /// Interleaves blocks of `2^log_block_len` scalars between `self` and
    /// `other`. NOTE(review): the exact output ordering is defined by each
    /// implementation; the width-1 scalar impl panics.
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

    /// Transpose-style counterpart of [`Self::interleave`] over blocks of
    /// `2^log_block_len` scalars. NOTE(review): exact semantics are defined by
    /// the implementations; the width-1 scalar impl panics.
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);

    /// Selects the block of `2^log_block_len` scalars at `block_idx` and
    /// repeats each of its scalars `WIDTH / 2^log_block_len` times
    /// (see the default [`Self::spread_unchecked`] body).
    ///
    /// # Panics
    /// Panics if `log_block_len > LOG_WIDTH` or `block_idx` is out of range.
    #[inline]
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert!(log_block_len <= Self::LOG_WIDTH);
        assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));

        // SAFETY: both preconditions were just asserted above.
        unsafe {
            self.spread_unchecked(log_block_len, block_idx)
        }
    }

    /// Unchecked version of [`Self::spread`].
    ///
    /// # Safety
    /// `log_block_len` must be at most `LOG_WIDTH`, and `block_idx` must be
    /// less than `2^(LOG_WIDTH - log_block_len)`.
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        let block_len = 1 << log_block_len;
        let repeat = 1 << (Self::LOG_WIDTH - log_block_len);

        Self::from_scalars(
            self.iter().skip(block_idx * block_len).take(block_len).flat_map(|elem| iter::repeat_n(elem, repeat))
        )
    }
}
288
289pub fn iter_packed_slice_with_offset<P: PackedField>(
293 packed: &[P],
294 offset: usize,
295) -> impl Iterator<Item = P::Scalar> + '_ + Send {
296 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
297 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
298 } else {
299 (&[], 0)
300 };
301
302 P::iter_slice(packed).skip(offset)
303}
304
/// Returns the scalar at overall index `i` in a slice of packed values.
///
/// # Panics
/// Panics if `i / P::WIDTH` is out of bounds of `packed`.
#[inline]
pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    // SAFETY: `i % P::WIDTH` is always less than `P::WIDTH`; the outer slice
    // indexing performs its own bounds check and panics if out of range.
    unsafe { packed[i / P::WIDTH].get_unchecked(i % P::WIDTH) }
}
310
311#[inline]
315pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
316 packed
317 .get_unchecked(i / P::WIDTH)
318 .get_unchecked(i % P::WIDTH)
319}
320
321pub fn get_packed_slice_checked<P: PackedField>(
322 packed: &[P],
323 i: usize,
324) -> Result<P::Scalar, Error> {
325 packed
326 .get(i / P::WIDTH)
327 .map(|el| el.get(i % P::WIDTH))
328 .ok_or(Error::IndexOutOfRange {
329 index: i,
330 max: packed.len() * P::WIDTH,
331 })
332}
333
/// Sets the scalar at overall index `i` in a slice of packed values, without
/// bounds checking.
///
/// # Safety
/// The caller must ensure that `i` is less than `packed.len() * P::WIDTH`.
pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) {
    // SAFETY: by the caller contract, `i / P::WIDTH` is within `packed` and
    // `i % P::WIDTH` is less than `P::WIDTH`.
    unsafe {
        packed
            .get_unchecked_mut(i / P::WIDTH)
            .set_unchecked(i % P::WIDTH, scalar)
    }
}
348
/// Sets the scalar at overall index `i` in a slice of packed values.
///
/// # Panics
/// Panics if `i / P::WIDTH` is out of bounds of `packed`.
pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
    // SAFETY: `i % P::WIDTH` is always less than `P::WIDTH`; the outer slice
    // indexing performs its own bounds check and panics if out of range.
    unsafe { packed[i / P::WIDTH].set_unchecked(i % P::WIDTH, scalar) }
}
353
354pub fn set_packed_slice_checked<P: PackedField>(
355 packed: &mut [P],
356 i: usize,
357 scalar: P::Scalar,
358) -> Result<(), Error> {
359 packed
360 .get_mut(i / P::WIDTH)
361 .map(|el| el.set(i % P::WIDTH, scalar))
362 .ok_or(Error::IndexOutOfRange {
363 index: i,
364 max: packed.len() * P::WIDTH,
365 })
366}
367
368pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
369 packed.len() * P::WIDTH
370}
371
372#[inline]
376pub fn packed_from_fn_with_offset<P: PackedField>(
377 offset: usize,
378 mut f: impl FnMut(usize) -> P::Scalar,
379) -> P {
380 P::from_fn(|i| f(i + offset * P::WIDTH))
381}
382
/// Multiplies a packed extension-field value elementwise by a scalar from the
/// subfield `FS`.
pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
    use crate::underlier::UnderlierType;

    // Bit widths of the subfield and extension-field scalar representations.
    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    // NOTE(review): this branch appears to be a performance heuristic choosing
    // between per-scalar multiplication and a broadcast multiply in the packed
    // subfield; the thresholds (1, 8, 32 bits) are presumably tuned — confirm
    // against benchmarks before changing them. Both branches compute the same
    // product.
    if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
        // SAFETY: `from_fn` only invokes the closure with `i < P::WIDTH`.
        P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
    } else {
        P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
    }
}
398
399pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
400 scalars
401 .chunks(P::WIDTH)
402 .map(|chunk| P::from_scalars(chunk.iter().copied()))
403 .collect()
404}
405
/// A scalar field element broadcasts to itself: a field is a packed field of
/// width 1, so broadcasting is the identity.
impl<F: Field> Broadcast<F> for F {
    fn broadcast(scalar: F) -> Self {
        scalar
    }
}
411
/// Any type with tower-field arithmetic supports multiplication by alpha by
/// delegating to `TowerFieldArithmetic::multiply_alpha`.
impl<T: TowerFieldArithmetic> MulAlpha for T {
    #[inline]
    fn mul_alpha(self) -> Self {
        <Self as TowerFieldArithmetic>::multiply_alpha(self)
    }
}
418
/// Every scalar [`Field`] is trivially a packed field containing exactly one
/// element (`LOG_WIDTH = 0`, `WIDTH = 1`).
impl<F: Field> PackedField for F {
    type Scalar = F;

    // A single scalar: WIDTH = 1 << 0 = 1.
    const LOG_WIDTH: usize = 0;

    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        // The only valid index is 0, so the index can be ignored.
        *self
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        // The only valid index is 0, so the index can be ignored.
        *self = scalar;
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        iter::once(*self)
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        iter::once(self)
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        // Each packed element holds exactly one scalar, so iterate directly.
        slice.iter().copied()
    }

    fn random(rng: impl RngCore) -> Self {
        <Self as Field>::random(rng)
    }

    // Interleaving needs at least two scalars per packed value.
    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    // Transposing needs at least two scalars per packed value.
    fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot transpose when WIDTH = 1");
    }

    fn broadcast(scalar: Self::Scalar) -> Self {
        // Broadcasting to a width-1 vector is the identity.
        scalar
    }

    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }

    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        // Only index 0 exists.
        f(0)
    }

    #[inline]
    unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
        // With a single scalar, the only valid spread is the identity.
        self
    }
}
483
/// A packed field whose scalar type is a binary field.
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

/// Blanket impl: every packed field over a binary-field scalar qualifies.
impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
488
#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Distribution, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::{
        arch::{
            byte_sliced::*, packed_1::*, packed_128::*, packed_16::*, packed_2::*, packed_256::*,
            packed_32::*, packed_4::*, packed_512::*, packed_64::*, packed_8::*, packed_aes_128::*,
            packed_aes_16::*, packed_aes_256::*, packed_aes_32::*, packed_aes_512::*,
            packed_aes_64::*, packed_aes_8::*, packed_polyval_128::*, packed_polyval_256::*,
            packed_polyval_512::*,
        },
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
        BinaryField128b, BinaryField128bPolyval, BinaryField16b, BinaryField1b, BinaryField2b,
        BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, PackedField,
    };

    /// A test parameterized over an arbitrary packed field type.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` against every packed field type defined in the crate.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // Scalar binary fields (width-1 packed fields).
        test.run::<BinaryField1b>();
        test.run::<BinaryField2b>();
        test.run::<BinaryField4b>();
        test.run::<BinaryField8b>();
        test.run::<BinaryField16b>();
        test.run::<BinaryField32b>();
        test.run::<BinaryField64b>();
        test.run::<BinaryField128b>();

        // Packed binary tower fields, by underlier width.
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField1x2b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField2x2b>();
        test.run::<PackedBinaryField1x4b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField4x2b>();
        test.run::<PackedBinaryField2x4b>();
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField8x2b>();
        test.run::<PackedBinaryField4x4b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField16x2b>();
        test.run::<PackedBinaryField8x4b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField32x2b>();
        test.run::<PackedBinaryField16x4b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField64x2b>();
        test.run::<PackedBinaryField32x4b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField128x2b>();
        test.run::<PackedBinaryField64x4b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField512x1b>();
        test.run::<PackedBinaryField256x2b>();
        test.run::<PackedBinaryField128x4b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        // Scalar AES tower fields.
        test.run::<AESTowerField8b>();
        test.run::<AESTowerField16b>();
        test.run::<AESTowerField32b>();
        test.run::<AESTowerField64b>();
        test.run::<AESTowerField128b>();

        // Packed AES tower fields.
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField1x16b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField2x16b>();
        test.run::<PackedAESBinaryField1x32b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField4x16b>();
        test.run::<PackedAESBinaryField2x32b>();
        test.run::<PackedAESBinaryField1x64b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField8x16b>();
        test.run::<PackedAESBinaryField4x32b>();
        test.run::<PackedAESBinaryField2x64b>();
        test.run::<PackedAESBinaryField1x128b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField16x16b>();
        test.run::<PackedAESBinaryField8x32b>();
        test.run::<PackedAESBinaryField4x64b>();
        test.run::<PackedAESBinaryField2x128b>();
        test.run::<PackedAESBinaryField64x8b>();
        test.run::<PackedAESBinaryField32x16b>();
        test.run::<PackedAESBinaryField16x32b>();
        test.run::<PackedAESBinaryField8x64b>();
        test.run::<PackedAESBinaryField4x128b>();

        // Byte-sliced AES fields.
        test.run::<ByteSlicedAES16x128b>();
        test.run::<ByteSlicedAES16x64b>();
        test.run::<ByteSlicedAES2x16x64b>();
        test.run::<ByteSlicedAES16x32b>();
        test.run::<ByteSlicedAES4x16x32b>();
        test.run::<ByteSlicedAES16x16b>();
        test.run::<ByteSlicedAES8x16x16b>();
        test.run::<ByteSlicedAES16x8b>();
        test.run::<ByteSlicedAES16x16x8b>();
        test.run::<ByteSlicedAES32x128b>();
        test.run::<ByteSlicedAES32x64b>();
        test.run::<ByteSlicedAES2x32x64b>();
        test.run::<ByteSlicedAES32x32b>();
        test.run::<ByteSlicedAES4x32x32b>();
        test.run::<ByteSlicedAES32x16b>();
        test.run::<ByteSlicedAES8x32x16b>();
        test.run::<ByteSlicedAES32x8b>();
        test.run::<ByteSlicedAES16x32x8b>();
        test.run::<ByteSlicedAES64x128b>();
        test.run::<ByteSlicedAES64x64b>();
        test.run::<ByteSlicedAES2x64x64b>();
        test.run::<ByteSlicedAES64x32b>();
        test.run::<ByteSlicedAES4x64x32b>();
        test.run::<ByteSlicedAES64x16b>();
        test.run::<ByteSlicedAES8x64x16b>();
        test.run::<ByteSlicedAES64x8b>();
        test.run::<ByteSlicedAES16x64x8b>();

        // POLYVAL fields.
        test.run::<BinaryField128bPolyval>();

        test.run::<PackedBinaryPolyval1x128b>();
        test.run::<PackedBinaryPolyval2x128b>();
        test.run::<PackedBinaryPolyval4x128b>();
    }

    /// Checks that by-value iteration (`into_iter`) yields exactly the `WIDTH`
    /// scalars of a random packed value, in index order, and then ends.
    //
    // NOTE: the function names previously had value/ref swapped relative to the
    // iterator method each one exercised; they are now named after what they
    // actually test.
    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks that by-reference iteration (`iter`) yields exactly the `WIDTH`
    /// scalars of a random packed value, in index order, and then ends.
    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks `iter_packed_slice_with_offset` against indexed access for
    /// slices of several lengths and a mix of fixed, random, and boundary
    /// offsets (including one exactly past the end).
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                Uniform::new(0, elements_count.max(1)).sample(&mut rng),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }
}