1use std::{
8 fmt::Debug,
9 iter,
10 ops::{Add, AddAssign, Sub, SubAssign},
11};
12
13use binius_utils::{
14 iter::IterExtensions,
15 random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
16};
17use bytemuck::Zeroable;
18
19use super::{PackedExtension, Random, arithmetic_traits::Square};
20use crate::{BinaryField, Field, field::FieldOps};
21
22pub trait PackedField:
28 Default
29 + Debug
30 + Clone
31 + Copy
32 + Eq
33 + Sized
34 + FieldOps<Self::Scalar>
35 + Add<Self::Scalar, Output = Self>
36 + Sub<Self::Scalar, Output = Self>
37 + AddAssign<Self::Scalar>
38 + SubAssign<Self::Scalar>
39 + Send
40 + Sync
41 + Zeroable
42 + Random
43 + 'static
44{
45 type Scalar: Field;
46
47 const LOG_WIDTH: usize;
49
50 const WIDTH: usize = 1 << Self::LOG_WIDTH;
54
55 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
59
60 unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
64
65 #[inline]
71 fn get(&self, i: usize) -> Self::Scalar {
72 assert!(i < Self::WIDTH, "index {i} out of range for width {}", Self::WIDTH);
73 unsafe { self.get_unchecked(i) }
75 }
76
77 #[inline]
83 fn set(&mut self, i: usize, scalar: Self::Scalar) {
84 assert!(i < Self::WIDTH, "index {i} out of range for width {}", Self::WIDTH);
85 unsafe { self.set_unchecked(i, scalar) }
87 }
88
89 #[inline]
90 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
91 (0..Self::WIDTH).map_skippable(move |i|
92 unsafe { self.get_unchecked(i) })
94 }
95
96 #[inline]
97 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
98 (0..Self::WIDTH).map_skippable(move |i|
99 unsafe { self.get_unchecked(i) })
101 }
102
103 #[inline]
104 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
105 slice.iter().flat_map(Self::iter)
106 }
107
108 #[inline(always)]
110 fn set_single(scalar: Self::Scalar) -> Self {
111 let mut result = Self::default();
112 result.set(0, scalar);
113
114 result
115 }
116
117 fn broadcast(scalar: Self::Scalar) -> Self;
118
119 fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
121
122 fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
124 let mut result = Self::default();
125 for i in 0..Self::WIDTH {
126 let scalar = f(i)?;
127 unsafe {
128 result.set_unchecked(i, scalar);
129 };
130 }
131 Ok(result)
132 }
133
134 #[inline]
140 fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
141 let mut result = Self::default();
142 for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
143 result.set(i, val);
144 }
145 result
146 }
147
148 fn pow(self, exp: u64) -> Self {
150 let mut res = Self::one();
151 for i in (0..64).rev() {
152 res = Square::square(res);
153 if ((exp >> i) & 1) == 1 {
154 res.mul_assign(self)
155 }
156 }
157 res
158 }
159
160 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
176
177 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
190
191 #[inline]
223 fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
224 assert!(log_block_len <= Self::LOG_WIDTH);
225 assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
226
227 unsafe { self.spread_unchecked(log_block_len, block_idx) }
229 }
230
231 #[inline]
237 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
238 let block_len = 1 << log_block_len;
239 let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
240
241 Self::from_scalars(
242 self.iter()
243 .skip(block_idx * block_len)
244 .take(block_len)
245 .flat_map(|elem| iter::repeat_n(elem, repeat)),
246 )
247 }
248}
249
250#[inline]
255pub fn iter_packed_slice_with_offset<P: PackedField>(
256 packed: &[P],
257 offset: usize,
258) -> impl Iterator<Item = P::Scalar> + '_ + Send {
259 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
260 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
261 } else {
262 (&[], 0)
263 };
264
265 P::iter_slice(packed).skip(offset)
266}
267
268#[inline(always)]
269pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
270 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
271
272 unsafe { get_packed_slice_unchecked(packed, i) }
273}
274
275#[inline(always)]
279pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
280 unsafe {
287 packed
288 .get_unchecked(i >> P::LOG_WIDTH)
289 .get_unchecked(i % P::WIDTH)
290 }
291}
292
293#[inline]
297pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
298 packed: &mut [P],
299 i: usize,
300 scalar: P::Scalar,
301) {
302 unsafe {
308 packed
309 .get_unchecked_mut(i >> P::LOG_WIDTH)
310 .set_unchecked(i % P::WIDTH, scalar)
311 }
312}
313
314#[inline]
315pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
316 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
317
318 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
319}
320
321#[inline(always)]
322pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
323 packed.len() << P::LOG_WIDTH
324}
325
326#[inline]
330pub fn packed_from_fn_with_offset<P: PackedField>(
331 offset: usize,
332 mut f: impl FnMut(usize) -> P::Scalar,
333) -> P {
334 P::from_fn(|i| f(i + offset * P::WIDTH))
335}
336
337pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
339 P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
340}
341
342pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
344 scalars
345 .chunks(P::WIDTH)
346 .map(|chunk| P::from_scalars(chunk.iter().copied()))
347 .collect()
348}
349
350#[derive(Clone)]
352pub struct PackedSlice<'a, P: PackedField> {
353 slice: &'a [P],
354 len: usize,
355}
356
357impl<'a, P: PackedField> PackedSlice<'a, P> {
358 #[inline(always)]
359 pub fn new(slice: &'a [P]) -> Self {
360 Self {
361 slice,
362 len: len_packed_slice(slice),
363 }
364 }
365
366 #[inline(always)]
367 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
368 assert!(len <= len_packed_slice(slice));
369
370 Self { slice, len }
371 }
372}
373
374impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
375 #[inline(always)]
376 fn len(&self) -> usize {
377 self.len
378 }
379
380 #[inline(always)]
381 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
382 unsafe { get_packed_slice_unchecked(self.slice, index) }
383 }
384}
385
386pub struct PackedSliceMut<'a, P: PackedField> {
388 slice: &'a mut [P],
389 len: usize,
390}
391
392impl<'a, P: PackedField> PackedSliceMut<'a, P> {
393 #[inline(always)]
394 pub fn new(slice: &'a mut [P]) -> Self {
395 let len = len_packed_slice(slice);
396 Self { slice, len }
397 }
398
399 #[inline(always)]
400 pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
401 assert!(len <= len_packed_slice(slice));
402
403 Self { slice, len }
404 }
405}
406
407impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
408 #[inline(always)]
409 fn len(&self) -> usize {
410 self.len
411 }
412
413 #[inline(always)]
414 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
415 unsafe { get_packed_slice_unchecked(self.slice, index) }
416 }
417}
418impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
419 #[inline(always)]
420 unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
421 unsafe { set_packed_slice_unchecked(self.slice, index, value) }
422 }
423}
424
425impl<F: Field> PackedField for F {
426 type Scalar = F;
427
428 const LOG_WIDTH: usize = 0;
429
430 #[inline]
431 unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
432 *self
433 }
434
435 #[inline]
436 unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
437 *self = scalar;
438 }
439
440 #[inline]
441 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
442 iter::once(*self)
443 }
444
445 #[inline]
446 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
447 iter::once(self)
448 }
449
450 #[inline]
451 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
452 slice.iter().copied()
453 }
454
455 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
456 panic!("cannot interleave when WIDTH = 1");
457 }
458
459 fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
460 panic!("cannot transpose when WIDTH = 1");
461 }
462
463 #[inline]
464 fn broadcast(scalar: Self::Scalar) -> Self {
465 scalar
466 }
467
468 #[inline]
469 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
470 f(0)
471 }
472
473 #[inline]
474 unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
475 self
476 }
477}
478
479pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
481
482impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
483
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedBinaryGhash1x128b,
        PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
        arch::{
            packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*, packed_32::*,
            packed_64::*, packed_128::*, packed_256::*, packed_512::*, packed_aes_8::*,
            packed_aes_16::*, packed_aes_32::*, packed_aes_64::*, packed_aes_128::*,
            packed_aes_256::*, packed_aes_512::*,
        },
    };

    /// A test body that can be instantiated for any `PackedField` type.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` against each scalar field and a range of packed widths over it.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // 1-bit binary field, unpacked and packed.
        test.run::<BinaryField1b>();
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField512x1b>();

        // AES 8-bit field, unpacked and packed.
        test.run::<AESTowerField8b>();
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField2x8b>();
        test.run::<PackedAESBinaryField4x8b>();
        test.run::<PackedAESBinaryField8x8b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField64x8b>();

        // GHASH 128-bit field, unpacked and packed.
        test.run::<BinaryField128bGhash>();
        test.run::<PackedBinaryGhash1x128b>();
        test.run::<PackedBinaryGhash2x128b>();
        test.run::<PackedBinaryGhash4x128b>();
    }

    /// Checks that borrowing iteration (`iter`) yields exactly the lanes returned by `get`.
    fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks that consuming iteration (`into_iter`) yields exactly the lanes returned by `get`.
    fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for i in 0..P::WIDTH {
            assert_eq!(packed.get(i), iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// Checks `iter_packed_slice_with_offset` against per-index `get_packed_slice` for several
    /// slice lengths and offsets, including the end-of-slice boundary.
    fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
        for len in [0, 1, 5] {
            let packed = std::iter::repeat_with(|| P::random(&mut rng))
                .take(len)
                .collect::<Vec<_>>();

            let elements_count = len * P::WIDTH;
            for offset in [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ] {
                let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
                let expected = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect::<Vec<_>>();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            let mut rng = StdRng::seed_from_u64(0);

            // Same call order as before the rename, so the RNG stream seen by each check is
            // unchanged.
            check_ref_iteration::<P>(&mut rng);
            check_value_iteration::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    /// Asserts that `collection` has the same length and elements as `expected`, through both
    /// the checked and unchecked accessors.
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (i, v) in expected.iter().enumerate() {
            assert_eq!(&collection.get(i), v);
            assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
        }
    }

    /// Overwrites every element of `collection` with a fresh random value and checks that both
    /// accessors observe the write.
    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        for i in 0..collection.len() {
            let value = random();
            collection.set(i, value);
            assert_eq!(collection.get(i), value);
            assert_eq!(unsafe { collection.get_unchecked(i) }, value);
        }
    }

    #[test]
    fn check_packed_slice() {
        let slice: &[PackedAESBinaryField16x8b] = &[];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSlice::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let slice: &[PackedAESBinaryField16x8b] = &[
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let packed_slice = PackedSlice::new(slice);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

        let packed_slice = PackedSlice::new_with_len(slice, 3);
        check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
    }

    #[test]
    fn check_packed_slice_mut() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut random = || AESTowerField8b::random(&mut rng);

        let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
        let packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &[]);
        let packed_slice = PackedSliceMut::new_with_len(slice, 0);
        check_collection(&packed_slice, &[]);

        // Use a seed distinct from the one feeding `random`, so the initial contents and the
        // values written by `check_collection_get_set` come from independent streams.
        let mut rng = StdRng::seed_from_u64(1);
        let slice: &mut [PackedAESBinaryField16x8b] = &mut [
            PackedAESBinaryField16x8b::random(&mut rng),
            PackedAESBinaryField16x8b::random(&mut rng),
        ];
        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new(slice);
        check_collection(&packed_slice, &values);
        check_collection_get_set(&mut packed_slice, &mut random);

        let values = PackedField::iter_slice(slice).collect_vec();
        let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
        check_collection(&packed_slice, &values[..3]);
        check_collection_get_set(&mut packed_slice, &mut random);
    }
}