1use std::{
8 fmt::Debug,
9 iter,
10 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
11};
12
13use binius_utils::{
14 iter::IterExtensions,
15 random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
16};
17use bytemuck::Zeroable;
18
19use super::{PackedExtension, Random, arithmetic_traits::Square};
20use crate::{BinaryField, Field, field::FieldOps};
21
22pub trait PackedField:
28 Default
29 + Debug
30 + Clone
31 + Copy
32 + Eq
33 + Sized
34 + FieldOps
35 + Add<Self::Scalar, Output = Self>
36 + Sub<Self::Scalar, Output = Self>
37 + Mul<Self::Scalar, Output = Self>
38 + AddAssign<Self::Scalar>
39 + SubAssign<Self::Scalar>
40 + MulAssign<Self::Scalar>
41 + Send
42 + Sync
43 + Zeroable
44 + Random
45 + 'static
46{
47 const LOG_WIDTH: usize;
49
50 const WIDTH: usize = 1 << Self::LOG_WIDTH;
54
55 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
59
60 unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
64
65 #[inline]
71 fn get(&self, i: usize) -> Self::Scalar {
72 assert!(i < Self::WIDTH, "index {i} out of range for width {}", Self::WIDTH);
73 unsafe { self.get_unchecked(i) }
75 }
76
77 #[inline]
83 fn set(&mut self, i: usize, scalar: Self::Scalar) {
84 assert!(i < Self::WIDTH, "index {i} out of range for width {}", Self::WIDTH);
85 unsafe { self.set_unchecked(i, scalar) }
87 }
88
89 #[inline]
90 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
91 (0..Self::WIDTH).map_skippable(move |i|
92 unsafe { self.get_unchecked(i) })
94 }
95
96 #[inline]
97 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
98 (0..Self::WIDTH).map_skippable(move |i|
99 unsafe { self.get_unchecked(i) })
101 }
102
103 #[inline]
104 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
105 slice.iter().flat_map(Self::iter)
106 }
107
108 #[inline(always)]
110 fn set_single(scalar: Self::Scalar) -> Self {
111 let mut result = Self::default();
112 result.set(0, scalar);
113
114 result
115 }
116
117 fn broadcast(scalar: Self::Scalar) -> Self;
118
119 fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
121
122 fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
124 let mut result = Self::default();
125 for i in 0..Self::WIDTH {
126 let scalar = f(i)?;
127 unsafe {
128 result.set_unchecked(i, scalar);
129 };
130 }
131 Ok(result)
132 }
133
134 #[inline]
140 fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
141 let mut result = Self::default();
142 for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
143 result.set(i, val);
144 }
145 result
146 }
147
148 fn pow(self, exp: u64) -> Self {
150 let mut res = Self::one();
151 for i in (0..64).rev() {
152 res = Square::square(res);
153 if ((exp >> i) & 1) == 1 {
154 res.mul_assign(self)
155 }
156 }
157 res
158 }
159
160 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
176
177 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
190
191 #[inline]
223 fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
224 assert!(log_block_len <= Self::LOG_WIDTH);
225 assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
226
227 unsafe { self.spread_unchecked(log_block_len, block_idx) }
229 }
230
231 #[inline]
237 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
238 let block_len = 1 << log_block_len;
239 let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
240
241 Self::from_scalars(
242 self.iter()
243 .skip(block_idx * block_len)
244 .take(block_len)
245 .flat_map(|elem| iter::repeat_n(elem, repeat)),
246 )
247 }
248}
249
250#[inline]
255pub fn iter_packed_slice_with_offset<P: PackedField>(
256 packed: &[P],
257 offset: usize,
258) -> impl Iterator<Item = P::Scalar> + '_ + Send {
259 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
260 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
261 } else {
262 (&[], 0)
263 };
264
265 P::iter_slice(packed).skip(offset)
266}
267
268#[inline(always)]
269pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
270 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
271
272 unsafe { get_packed_slice_unchecked(packed, i) }
273}
274
275#[inline(always)]
279pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
280 unsafe {
287 packed
288 .get_unchecked(i >> P::LOG_WIDTH)
289 .get_unchecked(i % P::WIDTH)
290 }
291}
292
293#[inline]
297pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
298 packed: &mut [P],
299 i: usize,
300 scalar: P::Scalar,
301) {
302 unsafe {
308 packed
309 .get_unchecked_mut(i >> P::LOG_WIDTH)
310 .set_unchecked(i % P::WIDTH, scalar)
311 }
312}
313
314#[inline]
315pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
316 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
317
318 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
319}
320
321#[inline(always)]
322pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
323 packed.len() << P::LOG_WIDTH
324}
325
326#[inline]
330pub fn packed_from_fn_with_offset<P: PackedField>(
331 offset: usize,
332 mut f: impl FnMut(usize) -> P::Scalar,
333) -> P {
334 P::from_fn(|i| f(i + offset * P::WIDTH))
335}
336
337pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
339 P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
340}
341
342pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
344 scalars
345 .chunks(P::WIDTH)
346 .map(|chunk| P::from_scalars(chunk.iter().copied()))
347 .collect()
348}
349
350#[derive(Clone)]
352pub struct PackedSlice<'a, P: PackedField> {
353 slice: &'a [P],
354 len: usize,
355}
356
357impl<'a, P: PackedField> PackedSlice<'a, P> {
358 #[inline(always)]
359 pub fn new(slice: &'a [P]) -> Self {
360 Self {
361 slice,
362 len: len_packed_slice(slice),
363 }
364 }
365
366 #[inline(always)]
367 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
368 assert!(len <= len_packed_slice(slice));
369
370 Self { slice, len }
371 }
372}
373
374impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
375 #[inline(always)]
376 fn len(&self) -> usize {
377 self.len
378 }
379
380 #[inline(always)]
381 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
382 unsafe { get_packed_slice_unchecked(self.slice, index) }
383 }
384}
385
386pub struct PackedSliceMut<'a, P: PackedField> {
388 slice: &'a mut [P],
389 len: usize,
390}
391
392impl<'a, P: PackedField> PackedSliceMut<'a, P> {
393 #[inline(always)]
394 pub fn new(slice: &'a mut [P]) -> Self {
395 let len = len_packed_slice(slice);
396 Self { slice, len }
397 }
398
399 #[inline(always)]
400 pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
401 assert!(len <= len_packed_slice(slice));
402
403 Self { slice, len }
404 }
405}
406
407impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
408 #[inline(always)]
409 fn len(&self) -> usize {
410 self.len
411 }
412
413 #[inline(always)]
414 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
415 unsafe { get_packed_slice_unchecked(self.slice, index) }
416 }
417}
418impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
419 #[inline(always)]
420 unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
421 unsafe { set_packed_slice_unchecked(self.slice, index, value) }
422 }
423}
424
425impl<F: Field> PackedField for F {
426 const LOG_WIDTH: usize = 0;
427
428 #[inline]
429 unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
430 *self
431 }
432
433 #[inline]
434 unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
435 *self = scalar;
436 }
437
438 #[inline]
439 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
440 iter::once(*self)
441 }
442
443 #[inline]
444 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
445 iter::once(self)
446 }
447
448 #[inline]
449 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
450 slice.iter().copied()
451 }
452
453 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
454 panic!("cannot interleave when WIDTH = 1");
455 }
456
457 fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
458 panic!("cannot transpose when WIDTH = 1");
459 }
460
461 #[inline]
462 fn broadcast(scalar: Self::Scalar) -> Self {
463 scalar
464 }
465
466 #[inline]
467 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
468 f(0)
469 }
470
471 #[inline]
472 unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
473 self
474 }
475}
476
477pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
479
480impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
481
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedBinaryGhash1x128b,
        PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
        arch::{
            packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*, packed_32::*,
            packed_64::*, packed_128::*, packed_256::*, packed_512::*, packed_aes_8::*,
            packed_aes_16::*, packed_aes_32::*, packed_aes_64::*, packed_aes_128::*,
            packed_aes_256::*, packed_aes_512::*,
        },
    };

    /// A test body that is generic over the packed field type.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` for each scalar and packed field type listed below.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // 1-bit binary field and its packings.
        test.run::<BinaryField1b>();
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField512x1b>();

        // 8-bit AES field and its packings.
        test.run::<AESTowerField8b>();
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField64x8b>();

        // 128-bit GHASH field and its packings.
        test.run::<BinaryField128bGhash>();
        test.run::<PackedBinaryGhash1x128b>();
        test.run::<PackedBinaryGhash2x128b>();
        test.run::<PackedBinaryGhash4x128b>();
    }

    /// Checks that by-value iteration (`into_iter`) yields exactly the `WIDTH` lanes in order.
    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks that by-reference iteration (`iter`) yields exactly the `WIDTH` lanes in order.
    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks `iter_packed_slice_with_offset` against indexed access for several slice lengths
    /// and offsets, including offsets at and past the end.
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_value_iteration::<P>(&mut rng);
            check_ref_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    /// Checks that `collection` has exactly the contents of `expected`, via both checked and
    /// unchecked access.
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    /// Writes a fresh value to every index, checking each write is immediately readable.
    ///
    /// Afterwards it re-reads the whole collection, so a write that clobbers a
    /// previously written neighbour is detected.
    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        let written = (0..collection.len())
            .map(|i| {
                let value = random();
                collection.set(i, value);
                assert_eq!(collection.get(i), value);
                assert_eq!(unsafe { collection.get_unchecked(i) }, value);
                value
            })
            .collect::<Vec<_>>();

        check_collection(&*collection, written.as_slice());
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedAESBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedAESBinaryField16x8b] = &[
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        // Values written through the view come from their own stream (seed 1). This keeps them
        // from being correlated with the initial slice contents (seed 0).
        let mut value_rng = StdRng::seed_from_u64(1);
        let mut random = || AESTowerField8b::random(&mut value_rng);

        let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedAESBinaryField16x8b] = &mut [
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut random);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut random);

        // Writes through a length-limited view must leave scalars past its length untouched.
        let after = PackedField::iter_slice(slice).collect_vec();
        assert_eq!(&after[3..], &values[3..]);
    }
}