1use std::{
9 fmt::Debug,
10 iter,
11 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
12};
13
14use binius_utils::{
15 iter::IterExtensions,
16 random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
17};
18use bytemuck::Zeroable;
19
20use super::{PackedExtension, Random, arithmetic_traits::Square};
21use crate::{BinaryField, Divisible, Field, Maskable, WideMul, field::FieldOps};
22
23pub trait PackedField:
29 Default
30 + Debug
31 + Clone
32 + Copy
33 + Eq
34 + Sized
35 + FieldOps
36 + Add<Self::Scalar, Output = Self>
37 + Sub<Self::Scalar, Output = Self>
38 + Mul<Self::Scalar, Output = Self>
39 + AddAssign<Self::Scalar>
40 + SubAssign<Self::Scalar>
41 + MulAssign<Self::Scalar>
42 + Send
43 + Sync
44 + Zeroable
45 + Random
46 + WideMul<Output: Debug + Send + Sync + 'static>
47 + 'static
48 + Divisible<Self::Scalar>
52 + Maskable<Self::Scalar>
54{
55 const LOG_WIDTH: usize = <Self as Divisible<Self::Scalar>>::LOG_N;
59
60 const WIDTH: usize = 1 << Self::LOG_WIDTH;
64
65 #[inline]
66 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
67 (0..Self::WIDTH).map_skippable(move |i|
68 unsafe { self.get_unchecked(i) })
70 }
71
72 #[inline]
73 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
74 (0..Self::WIDTH).map_skippable(move |i|
75 unsafe { self.get_unchecked(i) })
77 }
78
79 #[inline]
80 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
81 slice.iter().flat_map(Self::iter)
82 }
83
84 #[inline(always)]
86 fn set_single(scalar: Self::Scalar) -> Self {
87 let mut result = Self::default();
88 result.set(0, scalar);
89 result
90 }
91
92 fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
94
95 fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
97 let mut result = Self::default();
98 for i in 0..Self::WIDTH {
99 let scalar = f(i)?;
100 unsafe {
101 result.set_unchecked(i, scalar);
102 };
103 }
104 Ok(result)
105 }
106
107 #[inline]
113 fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
114 let mut result = Self::default();
115 for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
116 result.set(i, val);
117 }
118 result
119 }
120
121 fn pow(self, exp: u64) -> Self {
123 let mut res = Self::one();
124 for i in (0..64).rev() {
125 res = Square::square(res);
126 if ((exp >> i) & 1) == 1 {
127 res.mul_assign(self)
128 }
129 }
130 res
131 }
132
133 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
149
150 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
163
164 #[inline]
196 fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
197 assert!(log_block_len <= Self::LOG_WIDTH);
198 assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
199
200 unsafe { self.spread_unchecked(log_block_len, block_idx) }
202 }
203
204 #[inline]
210 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
211 let block_len = 1 << log_block_len;
212 let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
213
214 Self::from_scalars(
215 self.iter()
216 .skip(block_idx * block_len)
217 .take(block_len)
218 .flat_map(|elem| iter::repeat_n(elem, repeat)),
219 )
220 }
221}
222
223#[inline]
228pub fn iter_packed_slice_with_offset<P: PackedField>(
229 packed: &[P],
230 offset: usize,
231) -> impl Iterator<Item = P::Scalar> + '_ + Send {
232 let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
233 (&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
234 } else {
235 (&[], 0)
236 };
237
238 P::iter_slice(packed).skip(offset)
239}
240
241#[inline(always)]
242pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
243 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
244
245 unsafe { get_packed_slice_unchecked(packed, i) }
246}
247
248#[inline(always)]
252pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
253 unsafe {
260 packed
261 .get_unchecked(i >> P::LOG_WIDTH)
262 .get_unchecked(i % P::WIDTH)
263 }
264}
265
266#[inline]
270pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
271 packed: &mut [P],
272 i: usize,
273 scalar: P::Scalar,
274) {
275 unsafe {
281 packed
282 .get_unchecked_mut(i >> P::LOG_WIDTH)
283 .set_unchecked(i % P::WIDTH, scalar)
284 }
285}
286
287#[inline]
288pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
289 assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
290
291 unsafe { set_packed_slice_unchecked(packed, i, scalar) }
292}
293
294#[inline(always)]
295pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
296 packed.len() << P::LOG_WIDTH
297}
298
299#[inline]
303pub fn packed_from_fn_with_offset<P: PackedField>(
304 offset: usize,
305 mut f: impl FnMut(usize) -> P::Scalar,
306) -> P {
307 P::from_fn(|i| f(i + offset * P::WIDTH))
308}
309
310pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
312 P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
313}
314
315pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
317 scalars
318 .chunks(P::WIDTH)
319 .map(|chunk| P::from_scalars(chunk.iter().copied()))
320 .collect()
321}
322
323#[derive(Clone)]
325pub struct PackedSlice<'a, P: PackedField> {
326 slice: &'a [P],
327 len: usize,
328}
329
330impl<'a, P: PackedField> PackedSlice<'a, P> {
331 #[inline(always)]
332 pub const fn new(slice: &'a [P]) -> Self {
333 Self {
334 slice,
335 len: len_packed_slice(slice),
336 }
337 }
338
339 #[inline(always)]
340 pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
341 assert!(len <= len_packed_slice(slice));
342
343 Self { slice, len }
344 }
345}
346
347impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
348 #[inline(always)]
349 fn len(&self) -> usize {
350 self.len
351 }
352
353 #[inline(always)]
354 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
355 unsafe { get_packed_slice_unchecked(self.slice, index) }
356 }
357}
358
359pub struct PackedSliceMut<'a, P: PackedField> {
361 slice: &'a mut [P],
362 len: usize,
363}
364
365impl<'a, P: PackedField> PackedSliceMut<'a, P> {
366 #[inline(always)]
367 pub const fn new(slice: &'a mut [P]) -> Self {
368 let len = len_packed_slice(slice);
369 Self { slice, len }
370 }
371
372 #[inline(always)]
373 pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
374 assert!(len <= len_packed_slice(slice));
375
376 Self { slice, len }
377 }
378}
379
380impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
381 #[inline(always)]
382 fn len(&self) -> usize {
383 self.len
384 }
385
386 #[inline(always)]
387 unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
388 unsafe { get_packed_slice_unchecked(self.slice, index) }
389 }
390}
391impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
392 #[inline(always)]
393 unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
394 unsafe { set_packed_slice_unchecked(self.slice, index, value) }
395 }
396}
397
398impl<F: Field> PackedField for F {
399 #[inline]
404 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
405 iter::once(*self)
406 }
407
408 #[inline]
409 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
410 iter::once(self)
411 }
412
413 #[inline]
414 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
415 slice.iter().copied()
416 }
417
418 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
419 panic!("cannot interleave when WIDTH = 1");
420 }
421
422 fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
423 panic!("cannot transpose when WIDTH = 1");
424 }
425
426 #[inline]
427 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
428 f(0)
429 }
430
431 #[inline]
432 unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
433 self
434 }
435}
436
437pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
439
440impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
441
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::prelude::*;

    use super::*;
    use crate::{
        AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedAESBinaryField1x8b,
        PackedAESBinaryField16x8b, PackedAESBinaryField32x8b, PackedAESBinaryField64x8b,
        PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b, PackedBinaryField8x1b,
        PackedBinaryField16x1b, PackedBinaryField32x1b, PackedBinaryField64x1b,
        PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b,
        PackedBinaryGhash1x128b, PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
    };

    /// A check that is generic over the packed field type it exercises.
    trait PackedFieldTest {
        fn run<P: PackedField>(&self);
    }

    /// Runs `test` once for every scalar and packed type enumerated below.
    fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
        // `BinaryField1b` and its packed types
        test.run::<BinaryField1b>();
        test.run::<PackedBinaryField1x1b>();
        test.run::<PackedBinaryField2x1b>();
        test.run::<PackedBinaryField4x1b>();
        test.run::<PackedBinaryField8x1b>();
        test.run::<PackedBinaryField16x1b>();
        test.run::<PackedBinaryField32x1b>();
        test.run::<PackedBinaryField64x1b>();
        test.run::<PackedBinaryField128x1b>();
        test.run::<PackedBinaryField256x1b>();
        test.run::<PackedBinaryField512x1b>();

        // `AESTowerField8b` and its packed types
        test.run::<AESTowerField8b>();
        test.run::<PackedAESBinaryField1x8b>();
        test.run::<PackedAESBinaryField16x8b>();
        test.run::<PackedAESBinaryField32x8b>();
        test.run::<PackedAESBinaryField64x8b>();

        // `BinaryField128bGhash` and its packed types
        test.run::<BinaryField128bGhash>();
        test.run::<PackedBinaryGhash1x128b>();
        test.run::<PackedBinaryGhash2x128b>();
        test.run::<PackedBinaryGhash4x128b>();
    }

    /// `iter()` (borrowing) must yield exactly `get(0), .., get(WIDTH - 1)`.
    fn check_iter_by_ref<P: PackedField>(mut rng: impl Rng) {
        let packed = P::random(&mut rng);
        let mut iter = packed.iter();
        for idx in 0..P::WIDTH {
            let expected = packed.get(idx);
            assert_eq!(expected, iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// `into_iter()` (consuming a copy) must yield exactly `get(0), .., get(WIDTH - 1)`.
    fn check_iter_by_value<P: PackedField>(mut rng: impl Rng) {
        let packed = P::random(&mut rng);
        let mut iter = packed.into_iter();
        for idx in 0..P::WIDTH {
            let expected = packed.get(idx);
            assert_eq!(expected, iter.next().unwrap());
        }
        assert!(iter.next().is_none());
    }

    /// `iter_packed_slice_with_offset` must agree with indexed access for several slice
    /// lengths and offsets, including the boundaries.
    fn check_slice_iteration<P: PackedField>(mut rng: impl Rng) {
        for len in [0, 1, 5] {
            let packed: Vec<_> = (0..len).map(|_| P::random(&mut rng)).collect();
            let elements_count = len * P::WIDTH;

            let offsets = [
                0,
                1,
                rng.random_range(0..elements_count.max(1)),
                elements_count.saturating_sub(1),
                elements_count,
            ];
            for offset in offsets {
                let actual: Vec<_> = iter_packed_slice_with_offset(&packed, offset).collect();
                let expected: Vec<_> = (offset..elements_count)
                    .map(|i| get_packed_slice(&packed, i))
                    .collect();

                assert_eq!(actual, expected);
            }
        }
    }

    struct PackedFieldIterationTest;

    impl PackedFieldTest for PackedFieldIterationTest {
        fn run<P: PackedField>(&self) {
            // One seeded generator is threaded through all three checks, in this order.
            let mut rng = StdRng::seed_from_u64(0);

            check_iter_by_ref::<P>(&mut rng);
            check_iter_by_value::<P>(&mut rng);
            check_slice_iteration::<P>(&mut rng);
        }
    }

    #[test]
    fn test_iteration() {
        run_for_all_packed_fields(&PackedFieldIterationTest);
    }

    /// Asserts that `collection` has the length and contents of `expected`, through both the
    /// checked and the unchecked getter.
    fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
        assert_eq!(collection.len(), expected.len());

        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(collection.get(idx), want);
            assert_eq!(unsafe { collection.get_unchecked(idx) }, want);
        }
    }

    /// Overwrites every position with a fresh value from `random` and reads it back through
    /// both getters.
    fn check_collection_get_set<F: Field>(
        collection: &mut impl RandomAccessSequenceMut<F>,
        random: &mut impl FnMut() -> F,
    ) {
        let count = collection.len();
        for idx in 0..count {
            let written = random();
            collection.set(idx, written);
            assert_eq!(collection.get(idx), written);
            assert_eq!(unsafe { collection.get_unchecked(idx) }, written);
        }
    }

    #[test]
    fn check_packed_slice() {
        type Packed = PackedAESBinaryField16x8b;

        // Empty storage, with and without an explicit length.
        let empty: &[Packed] = &[];
        check_collection(&PackedSlice::new(empty), &[]);
        check_collection(&PackedSlice::new_with_len(empty, 0), &[]);

        let mut rng = StdRng::seed_from_u64(0);
        let data: &[Packed] = &[Packed::random(&mut rng), Packed::random(&mut rng)];

        let all_scalars = PackedField::iter_slice(data).collect_vec();
        check_collection(&PackedSlice::new(data), &all_scalars);

        // A truncated view exposes only a prefix of the scalars.
        let first_three = PackedField::iter_slice(data).take(3).collect_vec();
        check_collection(&PackedSlice::new_with_len(data, 3), &first_three);
    }

    #[test]
    fn check_packed_slice_mut() {
        type Packed = PackedAESBinaryField16x8b;

        // Source of the scalars written by `check_collection_get_set`.
        let mut scalar_rng = StdRng::seed_from_u64(0);
        let mut random = || AESTowerField8b::random(&mut scalar_rng);

        // Empty storage, with and without an explicit length.
        let empty: &mut [Packed] = &mut [];
        check_collection(&PackedSliceMut::new(empty), &[]);
        check_collection(&PackedSliceMut::new_with_len(empty, 0), &[]);

        // Independent generator (same seed) for the initial packed contents.
        let mut packed_rng = StdRng::seed_from_u64(0);
        let data: &mut [Packed] = &mut [
            Packed::random(&mut packed_rng),
            Packed::random(&mut packed_rng),
        ];

        let initial = PackedField::iter_slice(data).collect_vec();
        let mut full_view = PackedSliceMut::new(data);
        check_collection(&full_view, &initial);
        check_collection_get_set(&mut full_view, &mut random);

        // Re-read the contents, which the previous step overwrote.
        let rewritten = PackedField::iter_slice(data).collect_vec();
        let mut prefix_view = PackedSliceMut::new_with_len(data, 3);
        check_collection(&prefix_view, &rewritten[..3]);
        check_collection_get_set(&mut prefix_view, &mut random);
    }
}