1use std::{
9 fmt::Debug,
10 iter,
11 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
12};
13
14use binius_utils::{
15 iter::IterExtensions,
16 random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
17};
18use bytemuck::Zeroable;
19
20use super::{PackedExtension, Random, arithmetic_traits::Square};
21use crate::{BinaryField, Field, WideMul, field::FieldOps};
22
23pub trait PackedField:
29 Default
30 + Debug
31 + Clone
32 + Copy
33 + Eq
34 + Sized
35 + FieldOps
36 + Add<Self::Scalar, Output = Self>
37 + Sub<Self::Scalar, Output = Self>
38 + Mul<Self::Scalar, Output = Self>
39 + AddAssign<Self::Scalar>
40 + SubAssign<Self::Scalar>
41 + MulAssign<Self::Scalar>
42 + Send
43 + Sync
44 + Zeroable
45 + Random
46 + WideMul<Output: Debug + Send + Sync + 'static>
47 + 'static
48{
49 const LOG_WIDTH: usize;
51
52 const WIDTH: usize = 1 << Self::LOG_WIDTH;
56
57 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
61
62 unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
66
67 #[inline]
73 fn get(&self, i: usize) -> Self::Scalar {
74 assert!(i < Self::WIDTH, "index {i} out of range for width {}", Self::WIDTH);
75 unsafe { self.get_unchecked(i) }
77 }
78
79 #[inline]
85 fn set(&mut self, i: usize, scalar: Self::Scalar) {
86 assert!(i < Self::WIDTH, "index {i} out of range for width {}", Self::WIDTH);
87 unsafe { self.set_unchecked(i, scalar) }
89 }
90
91 #[inline]
92 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
93 (0..Self::WIDTH).map_skippable(move |i|
94 unsafe { self.get_unchecked(i) })
96 }
97
98 #[inline]
99 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
100 (0..Self::WIDTH).map_skippable(move |i|
101 unsafe { self.get_unchecked(i) })
103 }
104
105 #[inline]
106 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
107 slice.iter().flat_map(Self::iter)
108 }
109
110 #[inline(always)]
112 fn set_single(scalar: Self::Scalar) -> Self {
113 let mut result = Self::default();
114 result.set(0, scalar);
115
116 result
117 }
118
119 fn broadcast(scalar: Self::Scalar) -> Self;
120
121 fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
123
124 fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
126 let mut result = Self::default();
127 for i in 0..Self::WIDTH {
128 let scalar = f(i)?;
129 unsafe {
130 result.set_unchecked(i, scalar);
131 };
132 }
133 Ok(result)
134 }
135
136 #[inline]
142 fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
143 let mut result = Self::default();
144 for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
145 result.set(i, val);
146 }
147 result
148 }
149
150 fn pow(self, exp: u64) -> Self {
152 let mut res = Self::one();
153 for i in (0..64).rev() {
154 res = Square::square(res);
155 if ((exp >> i) & 1) == 1 {
156 res.mul_assign(self)
157 }
158 }
159 res
160 }
161
162 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
178
179 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
192
193 #[inline]
225 fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
226 assert!(log_block_len <= Self::LOG_WIDTH);
227 assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
228
229 unsafe { self.spread_unchecked(log_block_len, block_idx) }
231 }
232
233 #[inline]
239 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
240 let block_len = 1 << log_block_len;
241 let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
242
243 Self::from_scalars(
244 self.iter()
245 .skip(block_idx * block_len)
246 .take(block_len)
247 .flat_map(|elem| iter::repeat_n(elem, repeat)),
248 )
249 }
250}
251
252#[inline]
257pub fn iter_packed_slice_with_offset<P: PackedField>(
258 packed: &[P],
259 offset: usize,
260) -> impl Iterator<Item = P::Scalar> + '_ + Send {
261 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
262 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
263 } else {
264 (&[], 0)
265 };
266
267 P::iter_slice(packed).skip(offset)
268}
269
270#[inline(always)]
271pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
272 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
273
274 unsafe { get_packed_slice_unchecked(packed, i) }
275}
276
277#[inline(always)]
281pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
282 unsafe {
289 packed
290 .get_unchecked(i >> P::LOG_WIDTH)
291 .get_unchecked(i % P::WIDTH)
292 }
293}
294
295#[inline]
299pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
300 packed: &mut [P],
301 i: usize,
302 scalar: P::Scalar,
303) {
304 unsafe {
310 packed
311 .get_unchecked_mut(i >> P::LOG_WIDTH)
312 .set_unchecked(i % P::WIDTH, scalar)
313 }
314}
315
316#[inline]
317pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
318 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
319
320 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
321}
322
323#[inline(always)]
324pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
325 packed.len() << P::LOG_WIDTH
326}
327
328#[inline]
332pub fn packed_from_fn_with_offset<P: PackedField>(
333 offset: usize,
334 mut f: impl FnMut(usize) -> P::Scalar,
335) -> P {
336 P::from_fn(|i| f(i + offset * P::WIDTH))
337}
338
339pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
341 P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
342}
343
344pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
346 scalars
347 .chunks(P::WIDTH)
348 .map(|chunk| P::from_scalars(chunk.iter().copied()))
349 .collect()
350}
351
/// Read-only view of the first `len` scalars of a slice of packed values.
///
/// Scalars are addressed by their index in the flattened scalar view of `slice`.
#[derive(Clone)]
pub struct PackedSlice<'a, P: PackedField> {
    // Underlying packed storage.
    slice: &'a [P],
    // Number of scalars exposed; the constructors ensure it never exceeds
    // `len_packed_slice(slice)`.
    len: usize,
}
358
359impl<'a, P: PackedField> PackedSlice<'a, P> {
360 #[inline(always)]
361 pub fn new(slice: &'a [P]) -> Self {
362 Self {
363 slice,
364 len: len_packed_slice(slice),
365 }
366 }
367
368 #[inline(always)]
369 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
370 assert!(len <= len_packed_slice(slice));
371
372 Self { slice, len }
373 }
374}
375
376impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
377 #[inline(always)]
378 fn len(&self) -> usize {
379 self.len
380 }
381
382 #[inline(always)]
383 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
384 unsafe { get_packed_slice_unchecked(self.slice, index) }
385 }
386}
387
/// Mutable view of the first `len` scalars of a slice of packed values.
///
/// Scalars are addressed by their index in the flattened scalar view of `slice`.
pub struct PackedSliceMut<'a, P: PackedField> {
    // Underlying packed storage.
    slice: &'a mut [P],
    // Number of scalars exposed; the constructors ensure it never exceeds
    // `len_packed_slice(slice)`.
    len: usize,
}
393
394impl<'a, P: PackedField> PackedSliceMut<'a, P> {
395 #[inline(always)]
396 pub fn new(slice: &'a mut [P]) -> Self {
397 let len = len_packed_slice(slice);
398 Self { slice, len }
399 }
400
401 #[inline(always)]
402 pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
403 assert!(len <= len_packed_slice(slice));
404
405 Self { slice, len }
406 }
407}
408
409impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
410 #[inline(always)]
411 fn len(&self) -> usize {
412 self.len
413 }
414
415 #[inline(always)]
416 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
417 unsafe { get_packed_slice_unchecked(self.slice, index) }
418 }
419}
420impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
421 #[inline(always)]
422 unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
423 unsafe { set_packed_slice_unchecked(self.slice, index, value) }
424 }
425}
426
427impl<F: Field> PackedField for F {
428 const LOG_WIDTH: usize = 0;
429
430 #[inline]
431 unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
432 *self
433 }
434
435 #[inline]
436 unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
437 *self = scalar;
438 }
439
440 #[inline]
441 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
442 iter::once(*self)
443 }
444
445 #[inline]
446 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
447 iter::once(self)
448 }
449
450 #[inline]
451 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
452 slice.iter().copied()
453 }
454
455 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
456 panic!("cannot interleave when WIDTH = 1");
457 }
458
459 fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
460 panic!("cannot transpose when WIDTH = 1");
461 }
462
463 #[inline]
464 fn broadcast(scalar: Self::Scalar) -> Self {
465 scalar
466 }
467
468 #[inline]
469 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
470 f(0)
471 }
472
473 #[inline]
474 unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
475 self
476 }
477}
478
479pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
481
482impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
483
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::prelude::*;

    use super::*;
    use crate::{
        AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedBinaryGhash1x128b,
        PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
        arch::{
            packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*, packed_32::*,
            packed_64::*, packed_128::*, packed_256::*, packed_512::*, packed_aes_8::*,
            packed_aes_16::*, packed_aes_32::*, packed_aes_64::*, packed_aes_128::*,
            packed_aes_256::*, packed_aes_512::*,
        },
    };

    /// A test body that can be instantiated for any `PackedField` type.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` for each scalar field below and for its packed variants.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // 1-bit binary field and its packings.
        test.run::<BinaryField1b>();
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField512x1b>();

        // AES 8-bit field and its packings.
        test.run::<AESTowerField8b>();
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField64x8b>();

        // 128-bit GHASH field and its packings.
        test.run::<BinaryField128bGhash>();
        test.run::<PackedBinaryGhash1x128b>();
        test.run::<PackedBinaryGhash2x128b>();
        test.run::<PackedBinaryGhash4x128b>();
    }

    /// Checks that by-reference iteration (`iter`) yields exactly the lanes in index order.
    fn check_ref_iteration<P: PackedField>(mut rng: impl Rng) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks that by-value iteration (`into_iter`) yields exactly the lanes in index order.
    fn check_value_iteration<P: PackedField>(mut rng: impl Rng) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks `iter_packed_slice_with_offset` against element-wise `get_packed_slice`
    /// for empty and non-empty slices and for offsets at the start, middle and end
    /// (including one past the last scalar).
    fn check_slice_iteration<P: PackedField>(mut rng: impl Rng) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            check_ref_iteration::<P>(&mut rng);
            check_value_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    /// Asserts that `collection` has exactly the elements of `expected`, through both
    /// the checked and unchecked accessors.
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    /// Writes a fresh random value at every index and checks that it reads back through
    /// both the checked and unchecked accessors.
    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = random();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedAESBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedAESBinaryField16x8b] = &[
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        // Source of the scalar values written through the view; it stays mutably
        // borrowed by `random` for the rest of the test.
        let mut value_rng = StdRng::seed_from_u64(0);
        let mut random = || AESTowerField8b::random(&mut value_rng);

        let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        // Separate generator for the initial packed contents.
        let mut slice_rng = StdRng::seed_from_u64(0);
        let slice: &mut [PackedAESBinaryField16x8b] = &mut [
            PackedAESBinaryField16x8b::random(&mut slice_rng),
            PackedAESBinaryField16x8b::random(&mut slice_rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut random);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut random);
    }
}