use std::{
    fmt::Debug,
    iter::{self, Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::iter::IterExtensions;
use bytemuck::Zeroable;
use rand::RngCore;

use super::{
    arithmetic_traits::{Broadcast, MulAlpha, Square},
    binary_field_arithmetic::TowerFieldArithmetic,
    Error,
};
use crate::{
    arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, Field, PackedExtension,
};

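/// A packed field element is a fixed-width vector of scalar field elements.
///
/// Arithmetic operations are applied element-wise, and individual scalars can be read and written
/// by index. The width is always a power of two, equal to `1 << LOG_WIDTH`.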
pub trait PackedField:
    Default
    + Debug
    + Clone
    + Copy
    + Eq
    + Sized
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + Add<Self::Scalar, Output = Self>
    + Sub<Self::Scalar, Output = Self>
    + Mul<Self::Scalar, Output = Self>
    + AddAssign<Self::Scalar>
    + SubAssign<Self::Scalar>
    + MulAssign<Self::Scalar>
    + Sum
    + Product
    + Send
    + Sync
    + Zeroable
    + 'static
{
    type Scalar: Field;

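    /// Base-2 logarithm of the number of scalars packed into one element.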
    const LOG_WIDTH: usize;

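    /// The number of scalars packed into one element, always equal to `1 << LOG_WIDTH`.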
    const WIDTH: usize = 1 << Self::LOG_WIDTH;

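    /// Returns the scalar at position `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.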
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

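    /// Writes the scalar at position `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must ensure that `i < Self::WIDTH`.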
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

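    /// Returns the scalar at position `i`, or `Error::IndexOutOfRange` if `i >= Self::WIDTH`.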
    #[inline]
    fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
        (i < Self::WIDTH)
            .then_some(unsafe { self.get_unchecked(i) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

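    /// Sets the scalar at position `i`, returning `Error::IndexOutOfRange` if `i >= Self::WIDTH`.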
    #[inline]
    fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
        (i < Self::WIDTH)
            .then(|| unsafe { self.set_unchecked(i, scalar) })
            .ok_or(Error::IndexOutOfRange {
                index: i,
                max: Self::WIDTH,
            })
    }

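    /// Returns the scalar at position `i`, panicking if `i >= Self::WIDTH`.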
    #[inline]
    fn get(&self, i: usize) -> Self::Scalar {
        self.get_checked(i).expect("index must be less than width")
    }

    #[inline]
    fn set(&mut self, i: usize, scalar: Self::Scalar) {
        self.set_checked(i, scalar).expect("index must be less than width")
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        (0..Self::WIDTH).map_skippable(move |i| unsafe { self.get_unchecked(i) })
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().flat_map(Self::iter)
    }

    #[inline]
    fn zero() -> Self {
        Self::broadcast(Self::Scalar::ZERO)
    }

    #[inline]
    fn one() -> Self {
        Self::broadcast(Self::Scalar::ONE)
    }

    #[inline(always)]
    fn set_single(scalar: Self::Scalar) -> Self {
        let mut result = Self::default();
        result.set(0, scalar);

        result
    }

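    /// Returns a packed element whose scalars are drawn from the given random number generator.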
    fn random(rng: impl RngCore) -> Self;
    fn broadcast(scalar: Self::Scalar) -> Self;

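    /// Builds a packed element by evaluating `f` at every index in `0..Self::WIDTH`.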
    fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

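    /// Fallible counterpart of [`Self::from_fn`]: returns the first error produced by `f`.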
    fn try_from_fn<E>(
        mut f: impl FnMut(usize) -> Result<Self::Scalar, E>,
    ) -> Result<Self, E> {
        let mut result = Self::default();
        for i in 0..Self::WIDTH {
            let scalar = f(i)?;
            unsafe {
                result.set_unchecked(i, scalar);
            };
        }
        Ok(result)
    }

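    /// Builds a packed element from an iterator of scalars.
    ///
    /// At most `Self::WIDTH` scalars are consumed; remaining positions keep their default value.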
    fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
        let mut result = Self::default();
        for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
            result.set(i, val);
        }
        result
    }

    fn square(self) -> Self;

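    /// Raises each scalar to the power `exp` using binary (square-and-multiply) exponentiation.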
    fn pow(self, exp: u64) -> Self {
        let mut res = Self::one();
        for i in (0..64).rev() {
            res = res.square();
            if ((exp >> i) & 1) == 1 {
                res.mul_assign(self)
            }
        }
        res
    }

    fn invert_or_zero(self) -> Self;

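    /// Interleaves blocks of `2^log_block_len` adjacent scalars between `self` and `other`,
    /// returning the two resulting packed elements.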
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

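    /// Gathers alternating blocks of `2^log_block_len` scalars from `self` and `other`, collecting
    /// the even-indexed blocks into the first returned element and the odd-indexed blocks into the
    /// second.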
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);

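    /// Spreads the block of `2^log_block_len` scalars selected by `block_idx` across the full
    /// width, repeating each scalar of the block `WIDTH / 2^log_block_len` times in order.
    ///
    /// Panics if `log_block_len > Self::LOG_WIDTH` or `block_idx` is out of range.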
    #[inline]
    fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
        assert!(log_block_len <= Self::LOG_WIDTH);
        assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));

        unsafe {
            self.spread_unchecked(log_block_len, block_idx)
        }
    }

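    /// Unchecked version of [`Self::spread`].
    ///
    /// # Safety
    /// The caller must ensure that `log_block_len <= Self::LOG_WIDTH` and that
    /// `block_idx < 1 << (Self::LOG_WIDTH - log_block_len)`.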
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        let block_len = 1 << log_block_len;
        let repeat = 1 << (Self::LOG_WIDTH - log_block_len);

        Self::from_scalars(
            self.iter()
                .skip(block_idx * block_len)
                .take(block_len)
                .flat_map(|elem| iter::repeat_n(elem, repeat)),
        )
    }
}

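/// Iterates the scalars of a packed slice starting from the given scalar offset.
///
/// Offsets at or past the end of the slice yield an empty iterator.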
pub fn iter_packed_slice_with_offset<P: PackedField>(
    packed: &[P],
    offset: usize,
) -> impl Iterator<Item = P::Scalar> + '_ + Send {
    let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
        (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
    } else {
        (&[], 0)
    };

    P::iter_slice(packed).skip(offset)
}

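/// Returns the scalar at index `i` of a packed slice, panicking if `i` is out of range.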
#[inline]
pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    unsafe { packed[i / P::WIDTH].get_unchecked(i % P::WIDTH) }
}

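/// Returns the scalar at index `i` of a packed slice without bounds checking.
///
/// # Safety
/// The caller must ensure that `i < len_packed_slice(packed)`.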
#[inline]
pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
    packed
        .get_unchecked(i / P::WIDTH)
        .get_unchecked(i % P::WIDTH)
}

pub fn get_packed_slice_checked<P: PackedField>(
    packed: &[P],
    i: usize,
) -> Result<P::Scalar, Error> {
    packed
        .get(i / P::WIDTH)
        .map(|el| el.get(i % P::WIDTH))
        .ok_or(Error::IndexOutOfRange {
            index: i,
            max: packed.len() * P::WIDTH,
        })
}

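/// Sets the scalar at index `i` of a packed slice without bounds checking.
///
/// # Safety
/// The caller must ensure that `i < len_packed_slice(packed)`.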
pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) {
    unsafe {
        packed
            .get_unchecked_mut(i / P::WIDTH)
            .set_unchecked(i % P::WIDTH, scalar)
    }
}

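/// Sets the scalar at index `i` of a packed slice, panicking if `i` is out of range.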
pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
    unsafe { packed[i / P::WIDTH].set_unchecked(i % P::WIDTH, scalar) }
}

pub fn set_packed_slice_checked<P: PackedField>(
    packed: &mut [P],
    i: usize,
    scalar: P::Scalar,
) -> Result<(), Error> {
    packed
        .get_mut(i / P::WIDTH)
        .map(|el| el.set(i % P::WIDTH, scalar))
        .ok_or(Error::IndexOutOfRange {
            index: i,
            max: packed.len() * P::WIDTH,
        })
}

pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
    packed.len() * P::WIDTH
}

#[inline]
pub fn packed_from_fn_with_offset<P: PackedField>(
    offset: usize,
    mut f: impl FnMut(usize) -> P::Scalar,
) -> P {
    P::from_fn(|i| f(i + offset * P::WIDTH))
}

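/// Multiplies a packed extension field element by a subfield scalar.
///
/// Depending on the subfield and extension field sizes, the multiplier is either applied to each
/// scalar individually or broadcast into the packed subfield and multiplied there.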
pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
    use crate::underlier::UnderlierType;

    let subfield_bits = FS::Underlier::BITS;
    let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

    if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
        P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
    } else {
        P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
    }
}

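/// Packs a slice of scalars into a vector of packed field elements.
///
/// When `scalars.len()` is not a multiple of `P::WIDTH`, the trailing packed element keeps
/// `P::default()` scalars in the unfilled positions.
///
/// ```ignore
/// // Illustrative sketch only; any packed type with a matching scalar behaves the same way.
/// let scalars = vec![BinaryField16b::ONE; 10];
/// let packed = pack_slice::<PackedBinaryField8x16b>(&scalars);
/// assert_eq!(get_packed_slice(&packed, 3), BinaryField16b::ONE);
/// ```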
pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
    let mut packed_slice = vec![P::default(); scalars.len().div_ceil(P::WIDTH)];
    for (i, scalar) in scalars.iter().enumerate() {
        set_packed_slice(&mut packed_slice, i, *scalar);
    }
    packed_slice
}

impl<F: Field> Broadcast<F> for F {
    fn broadcast(scalar: F) -> Self {
        scalar
    }
}

impl<T: TowerFieldArithmetic> MulAlpha for T {
    #[inline]
    fn mul_alpha(self) -> Self {
        <Self as TowerFieldArithmetic>::multiply_alpha(self)
    }
}

impl<F: Field> PackedField for F {
    type Scalar = F;

    const LOG_WIDTH: usize = 0;

    #[inline]
    unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
        *self
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
        *self = scalar;
    }

    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        iter::once(*self)
    }

    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        iter::once(self)
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        slice.iter().copied()
    }

    fn random(rng: impl RngCore) -> Self {
        <Self as Field>::random(rng)
    }

    fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot interleave when WIDTH = 1");
    }

    fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
        panic!("cannot transpose when WIDTH = 1");
    }

    fn broadcast(scalar: Self::Scalar) -> Self {
        scalar
    }

    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }

    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        f(0)
    }

    #[inline]
    unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
        self
    }
}

pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}

#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Distribution, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::{
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
        BinaryField128b, BinaryField128bPolyval, BinaryField16b, BinaryField1b, BinaryField2b,
        BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, ByteSlicedAES32x128b,
        ByteSlicedAES32x16b, ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b,
        PackedBinaryField128x1b, PackedBinaryField128x2b, PackedBinaryField128x4b,
        PackedBinaryField16x16b, PackedBinaryField16x1b, PackedBinaryField16x2b,
        PackedBinaryField16x32b, PackedBinaryField16x4b, PackedBinaryField16x8b,
        PackedBinaryField1x128b, PackedBinaryField1x16b, PackedBinaryField1x1b,
        PackedBinaryField1x2b, PackedBinaryField1x32b, PackedBinaryField1x4b,
        PackedBinaryField1x64b, PackedBinaryField1x8b, PackedBinaryField256x1b,
        PackedBinaryField256x2b, PackedBinaryField2x128b, PackedBinaryField2x16b,
        PackedBinaryField2x1b, PackedBinaryField2x2b, PackedBinaryField2x32b,
        PackedBinaryField2x4b, PackedBinaryField2x64b, PackedBinaryField2x8b,
        PackedBinaryField32x16b, PackedBinaryField32x1b, PackedBinaryField32x2b,
        PackedBinaryField32x4b, PackedBinaryField32x8b, PackedBinaryField4x128b,
        PackedBinaryField4x16b, PackedBinaryField4x1b, PackedBinaryField4x2b,
        PackedBinaryField4x32b, PackedBinaryField4x4b, PackedBinaryField4x64b,
        PackedBinaryField4x8b, PackedBinaryField512x1b, PackedBinaryField64x1b,
        PackedBinaryField64x2b, PackedBinaryField64x4b, PackedBinaryField64x8b,
        PackedBinaryField8x16b, PackedBinaryField8x1b, PackedBinaryField8x2b,
        PackedBinaryField8x32b, PackedBinaryField8x4b, PackedBinaryField8x64b,
        PackedBinaryField8x8b, PackedBinaryPolyval1x128b, PackedBinaryPolyval2x128b,
        PackedBinaryPolyval4x128b, PackedField,
    };

    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        test.run::<BinaryField1b>();
        test.run::<BinaryField2b>();
        test.run::<BinaryField4b>();
        test.run::<BinaryField8b>();
        test.run::<BinaryField16b>();
        test.run::<BinaryField32b>();
        test.run::<BinaryField64b>();
        test.run::<BinaryField128b>();

        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField1x2b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField2x2b>();
        test.run::<PackedBinaryField1x4b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField4x2b>();
        test.run::<PackedBinaryField2x4b>();
        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField8x2b>();
        test.run::<PackedBinaryField4x4b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField16x2b>();
        test.run::<PackedBinaryField8x4b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField32x2b>();
        test.run::<PackedBinaryField16x4b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField64x2b>();
        test.run::<PackedBinaryField32x4b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField128x2b>();
        test.run::<PackedBinaryField64x4b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField512x1b>();
        test.run::<PackedBinaryField256x2b>();
        test.run::<PackedBinaryField128x4b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();

        test.run::<AESTowerField8b>();
        test.run::<AESTowerField16b>();
        test.run::<AESTowerField32b>();
        test.run::<AESTowerField64b>();
        test.run::<AESTowerField128b>();

        test.run::<PackedBinaryField1x8b>();
        test.run::<PackedBinaryField2x8b>();
        test.run::<PackedBinaryField1x16b>();
        test.run::<PackedBinaryField4x8b>();
        test.run::<PackedBinaryField2x16b>();
        test.run::<PackedBinaryField1x32b>();
        test.run::<PackedBinaryField8x8b>();
        test.run::<PackedBinaryField4x16b>();
        test.run::<PackedBinaryField2x32b>();
        test.run::<PackedBinaryField1x64b>();
        test.run::<PackedBinaryField16x8b>();
        test.run::<PackedBinaryField8x16b>();
        test.run::<PackedBinaryField4x32b>();
        test.run::<PackedBinaryField2x64b>();
        test.run::<PackedBinaryField1x128b>();
        test.run::<PackedBinaryField32x8b>();
        test.run::<PackedBinaryField16x16b>();
        test.run::<PackedBinaryField8x32b>();
        test.run::<PackedBinaryField4x64b>();
        test.run::<PackedBinaryField2x128b>();
        test.run::<PackedBinaryField64x8b>();
        test.run::<PackedBinaryField32x16b>();
        test.run::<PackedBinaryField16x32b>();
        test.run::<PackedBinaryField8x64b>();
        test.run::<PackedBinaryField4x128b>();
        test.run::<ByteSlicedAES32x8b>();
        test.run::<ByteSlicedAES32x64b>();
        test.run::<ByteSlicedAES32x16b>();
        test.run::<ByteSlicedAES32x32b>();
        test.run::<ByteSlicedAES32x128b>();

        test.run::<BinaryField128bPolyval>();

        test.run::<PackedBinaryPolyval1x128b>();
        test.run::<PackedBinaryPolyval2x128b>();
        test.run::<PackedBinaryPolyval4x128b>();
    }

    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                Uniform::new(0, elements_count.max(1)).sample(&mut rng),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }
}