1use std::{
4 array,
5 fmt::Debug,
6 iter::{zip, Product, Sum},
7 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_utils::checked_arithmetics::checked_log_2;
11use bytemuck::{Pod, Zeroable};
12
13use super::{invert::invert_or_zero, multiply::mul, square::square};
14use crate::{
15 as_packed_field::{PackScalar, PackedType},
16 binary_field::BinaryField,
17 linear_transformation::{
18 FieldLinearTransformation, IDTransformation, PackedTransformationFactory, Transformation,
19 },
20 packed_aes_field::PackedAESBinaryField32x8b,
21 tower_levels::{TowerLevel, TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8},
22 underlier::{UnderlierWithBitOps, WithUnderlier},
23 AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
24 BinaryField1b, ExtensionField, PackedAESBinaryField16x8b, PackedAESBinaryField64x8b,
25 PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedExtension,
26 PackedField,
27};
28
/// An N×N matrix of inner transformations applied between the byte slices of a
/// byte-sliced packed field value.
///
/// NOTE(review): the generated `Transformation` impl below reads the matrix as
/// `self.0[col][row]`, i.e. transposed relative to the `[row][col]` naming used
/// at the construction site in `make_packed_transformation`.
pub struct TransformationWrapperNxN<Inner, const N: usize>([[Inner; N]; N]);
33
/// Defines a byte-sliced packed field type.
///
/// The generated struct stores scalars transposed ("byte-sliced"): byte `b` of
/// many scalars is gathered into a single `$packed_storage` row of 8-bit
/// elements, so `data[row][byte]` holds byte `byte` of the scalars whose index
/// is congruent to `row` modulo `HEIGHT`.
///
/// * `$scalar_type` — the scalar field element type of the packed field.
/// * `$packed_storage` — the packed 8-bit row type backing the storage.
/// * `$scalar_tower_level` — tower level whose `WIDTH` equals the scalar size in bytes.
/// * `$storage_tower_level` — tower level whose `WIDTH` equals the byte height of the matrix.
macro_rules! define_byte_sliced_3d {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $scalar_tower_level: ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			/// Byte-sliced storage, indexed as `data[scalar_row][byte_within_scalar]`.
			pub(super) data: [[$packed_storage; <$scalar_tower_level as TowerLevel>::WIDTH]; <$storage_tower_level as TowerLevel>::WIDTH / <$scalar_tower_level as TowerLevel>::WIDTH],
		}

		impl $name {
			/// Total number of bytes of scalar data held by a single value.
			pub const BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			/// Size of a single scalar element, in bytes.
			const SCALAR_BYTES: usize = <$scalar_type>::N_BITS / 8;
			/// Height of the byte matrix, in bytes.
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			/// Height of the byte matrix, in scalar elements.
			const HEIGHT: usize = Self::HEIGHT_BYTES / Self::SCALAR_BYTES;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT);

			/// Returns the byte at `byte_index`, viewing the value as a flat
			/// column-major byte matrix of `HEIGHT_BYTES` rows.
			///
			/// # Safety
			/// `byte_index` must be less than `Self::BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				// Row within the current byte column; split into scalar row and
				// byte-within-scalar to index the 2-D storage.
				let row = byte_index % Self::HEIGHT_BYTES;
				self.data
					.get_unchecked(row / Self::SCALAR_BYTES)
					.get_unchecked(row % Self::SCALAR_BYTES)
					.get_unchecked(byte_index / Self::HEIGHT_BYTES)
					.to_underlier()
			}

			/// Writes this value into `out` in ordinary (non-byte-sliced) scalar
			/// order by transposing the byte matrix.
			#[inline]
			pub fn transpose_to(&self, out: &mut [<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES]) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = bytemuck::must_cast(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);
			}

			/// Builds a byte-sliced value from packed scalars given in ordinary
			/// order, transposing them into the byte-sliced layout.
			#[inline]
			pub fn transpose_from(
				underliers: &[<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES],
			) -> Self {
				let mut result = Self {
					data: bytemuck::must_cast(*underliers),
				};

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<$storage_tower_level>(bytemuck::must_cast_mut(&mut result.data));

				result
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				// Render scalars in logical (transposed) order, comma-separated.
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x{}([{}])", Self::WIDTH, <$scalar_type>::N_BITS, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = $scalar_type;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			// Reassembles scalar `i` from its `SCALAR_BYTES` byte slices.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				let element_rows = self.data.get_unchecked(i % Self::HEIGHT);
				Self::Scalar::from_bases((0..Self::SCALAR_BYTES).map(|byte_index| {
					element_rows
						.get_unchecked(byte_index)
						.get_unchecked(i / Self::HEIGHT)
				}))
				.expect("byte index is within bounds")
			}

			// Scatters the bytes of `scalar` across the byte-sliced rows.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				let element_rows = self.data.get_unchecked_mut(i % Self::HEIGHT);
				for byte_index in 0..Self::SCALAR_BYTES {
					element_rows
						.get_unchecked_mut(byte_index)
						.set_unchecked(
							i / Self::HEIGHT,
							scalar.get_base_unchecked(byte_index),
						);
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| array::from_fn(|_| <$packed_storage>::random(&mut rng)));
				Self { data }
			}

			// `unreachable_patterns` is allowed because for some instantiations
			// `SCALAR_BYTES == 1` or `SCALAR_BYTES == HEIGHT_BYTES`, making the
			// later arms overlap with earlier ones.
			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let data: [[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT] = match Self::SCALAR_BYTES {
					// Single-byte scalar: one broadcast fills every row.
					1 => {
						let packed_broadcast =
							<$packed_storage>::broadcast(unsafe { scalar.get_base_unchecked(0) });
						array::from_fn(|_| array::from_fn(|_| packed_broadcast))
					}
					// Scalar spans the full height: each inner row gets its own byte.
					Self::HEIGHT_BYTES => array::from_fn(|_| array::from_fn(|byte_index| {
						<$packed_storage>::broadcast(unsafe {
							scalar.get_base_unchecked(byte_index)
						})
					})),
					// General case: broadcast each byte once, then copy it to
					// every scalar row.
					_ => {
						let mut data = <[[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT]>::zeroed();
						for byte_index in 0..Self::SCALAR_BYTES {
							let broadcast = <$packed_storage>::broadcast(unsafe {
								scalar.get_base_unchecked(byte_index)
							});

							for i in 0..Self::HEIGHT {
								data[i][byte_index] = broadcast;
							}
						}

						data
					}
				};

				Self { data }
			}

			#[inline]
			fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
				let mut result = Self::default();

				for i in 0..Self::WIDTH {
					// Safety: `i < Self::WIDTH` by the loop bound.
					unsafe { result.set_unchecked(i, f(i)) };
				}

				result
			}

			// Squaring / inversion / multiplication are delegated to byte-sliced
			// tower arithmetic, one scalar row at a time.
			#[inline]
			fn square(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					square::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					invert_or_zero::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				// View the 2-D storage as a flat column of byte rows; the block
				// length, given in scalars, is scaled to bytes.
				let self_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&self.data);
				let other_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&other.data);

				let (data_1, data_2) = interleave_byte_sliced(self_data, other_data, log_block_len + checked_log_2(Self::SCALAR_BYTES));
				(
					Self {
						data: bytemuck::must_cast(data_1),
					},
					Self {
						data: bytemuck::must_cast(data_2),
					},
				)
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				// Same flat view as `interleave`, scaling the block length to bytes.
				let (result1, result2) = unzip_byte_sliced::<$packed_storage, {Self::HEIGHT_BYTES}, {Self::SCALAR_BYTES}>(bytemuck::must_cast_ref(&self.data), bytemuck::must_cast_ref(&other.data), log_block_len + checked_log_2(Self::SCALAR_BYTES));

				(
					Self {
						data: bytemuck::must_cast(result1),
					},
					Self {
						data: bytemuck::must_cast(result2),
					},
				)
			}
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					mul::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&rhs.data[i],
						&mut result.data[i],
					);
				}

				result
			}
		}

		impl Add for $name {
			type Output = Self;

			// Addition is byte-wise XOR of the underlying 8-bit rows.
			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						array::from_fn(|column|
							self.data[byte_number][column] + rhs.data[byte_number][column]
						)
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					for (data, rhs) in zip(data, rhs) {
						*data += *rhs
					}
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, $scalar_type);

		// Applies an N×N matrix of 8-bit transformations between byte slices:
		// output byte-row `row` accumulates the transform of input byte-row `col`.
		impl<Inner: Transformation<$packed_storage, $packed_storage>> Transformation<$name, $name> for TransformationWrapperNxN<Inner, {<$scalar_tower_level as TowerLevel>::WIDTH}> {
			fn transform(&self, data: &$name) -> $name {
				let mut result = <$name>::default();

				for row in 0..<$name>::SCALAR_BYTES {
					for col in 0..<$name>::SCALAR_BYTES {
						// Note: indices are transposed relative to the `[row][col]`
						// naming used when the matrix is built below.
						let transformation = &self.0[col][row];

						for i in 0..<$name>::HEIGHT {
							result.data[i][row] += transformation.transform(&data.data[i][col]);
						}
					}
				}

				result
			}
		}

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> = TransformationWrapperNxN<<$packed_storage as PackedTransformationFactory<$packed_storage>>::PackedTransformation::<[AESTowerField8b; 8]>, {<$scalar_tower_level as TowerLevel>::WIDTH}>;

			/// Decomposes a scalar-level linear transformation into an N×N matrix
			/// of 8-bit packed transformations, one per (output byte, input byte)
			/// pair; each 8-bit block is built from 8 consecutive basis rows.
			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				let transformations_8b = array::from_fn(|row| {
					array::from_fn(|col| {
						let row = row * 8;
						// Safety relies on `col < SCALAR_BYTES`, which `from_fn`
						// guarantees by the array dimension.
						let linear_transformation_8b = array::from_fn::<_, 8, _>(|row_8b| unsafe {
							<<$name as PackedField>::Scalar as ExtensionField<AESTowerField8b>>::get_base_unchecked(&transformation.bases()[row + row_8b], col)
						});

						<$packed_storage as PackedTransformationFactory<$packed_storage
						>>::make_packed_transformation(FieldLinearTransformation::new(linear_transformation_8b))
					})
				});

				TransformationWrapperNxN(transformations_8b)
			}
		}
	};
}
349
/// Implements the operator and trait boilerplate shared by all byte-sliced
/// types: scalar-operand arithmetic (via `broadcast`), subtraction expressed as
/// addition (these are characteristic-2 binary fields, where `x - y == x + y`),
/// `Sum`/`Product`, and the trivial self-extension `PackedExtension<$scalar_type>`.
macro_rules! byte_sliced_common {
	($name:ident, $packed_storage:ty, $scalar_type:ty) => {
		impl Add<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: $scalar_type) -> $name {
				self + Self::broadcast(rhs)
			}
		}

		impl AddAssign<$scalar_type> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: $scalar_type) {
				*self += Self::broadcast(rhs)
			}
		}

		impl Sub<$scalar_type> for $name {
			type Output = Self;

			// In characteristic 2, subtraction coincides with addition.
			#[inline]
			fn sub(self, rhs: $scalar_type) -> $name {
				self.add(rhs)
			}
		}

		impl SubAssign<$scalar_type> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: $scalar_type) {
				self.add_assign(rhs)
			}
		}

		impl Mul<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: $scalar_type) -> $name {
				self * Self::broadcast(rhs)
			}
		}

		impl MulAssign<$scalar_type> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: $scalar_type) {
				*self *= Self::broadcast(rhs);
			}
		}

		impl Sub for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: Self) -> Self {
				self.add(rhs)
			}
		}

		impl SubAssign for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: Self) {
				self.add_assign(rhs);
			}
		}

		impl MulAssign for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: Self) {
				*self = *self * rhs;
			}
		}

		impl Product for $name {
			fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::one();

				// Start from the first item instead of `one()` to save one
				// multiplication; an empty iterator still yields `one()`.
				let mut is_first_item = true;
				for item in iter {
					if is_first_item {
						result = item;
					} else {
						result *= item;
					}

					is_first_item = false;
				}

				result
			}
		}

		impl Sum for $name {
			fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::zero();

				for item in iter {
					result += item;
				}

				result
			}
		}

		// Degree-1 extension of the scalar over itself: every cast is the identity.
		impl PackedExtension<$scalar_type> for $name {
			type PackedSubfield = Self;

			#[inline(always)]
			fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield] {
				packed
			}

			#[inline(always)]
			fn cast_bases_mut(packed: &mut [Self]) -> &mut [Self::PackedSubfield] {
				packed
			}

			#[inline(always)]
			fn cast_exts(packed: &[Self::PackedSubfield]) -> &[Self] {
				packed
			}

			#[inline(always)]
			fn cast_exts_mut(packed: &mut [Self::PackedSubfield]) -> &mut [Self] {
				packed
			}

			#[inline(always)]
			fn cast_base(self) -> Self::PackedSubfield {
				self
			}

			#[inline(always)]
			fn cast_base_ref(&self) -> &Self::PackedSubfield {
				self
			}

			#[inline(always)]
			fn cast_base_mut(&mut self) -> &mut Self::PackedSubfield {
				self
			}

			#[inline(always)]
			fn cast_ext(base: Self::PackedSubfield) -> Self {
				base
			}

			#[inline(always)]
			fn cast_ext_ref(base: &Self::PackedSubfield) -> &Self {
				base
			}

			#[inline(always)]
			fn cast_ext_mut(base: &mut Self::PackedSubfield) -> &mut Self {
				base
			}
		}
	};
}
509
/// Defines a byte-sliced packed field type over the 1-bit scalar `BinaryField1b`.
///
/// Unlike `define_byte_sliced_3d`, a scalar occupies a single bit, so the
/// storage is a flat array of 1-bit packed rows: `data[byte_row]` holds bit
/// groups of 8 consecutive scalars, and byte-level operations reinterpret the
/// rows as `AESTowerField8b` via `PackedType`/`cast_ext`.
macro_rules! define_byte_sliced_3d_1b {
	($name:ident, $packed_storage:ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			/// One 1-bit packed row per byte of height.
			pub(super) data: [$packed_storage; <$storage_tower_level>::WIDTH],
		}

		impl $name {
			// NOTE(review): `<$packed_storage>::WIDTH` counts 1-bit scalars here,
			// so this value is 8× the actual byte count — unlike the same-named
			// constant in `define_byte_sliced_3d`. Confirm intended units.
			pub const BYTES: usize =
				<$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			/// Height of the byte matrix, in bytes.
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT_BYTES);

			/// Returns the byte at `byte_index`, viewing each 1-bit row as a row
			/// of 8-bit elements.
			///
			/// # Safety
			/// `byte_index` must be within the value's byte bounds.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				Packed8b::cast_ext_ref(self.data.get_unchecked(byte_index % Self::HEIGHT_BYTES))
					.get_unchecked(byte_index / Self::HEIGHT_BYTES)
					.to_underlier()
			}

			/// Writes this value into `out` in ordinary (non-byte-sliced) order.
			#[inline]
			pub fn transpose_to(
				&self,
				out: &mut [PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					Self::HEIGHT_BYTES],
			) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = WithUnderlier::to_underliers_arr(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			/// Builds a byte-sliced value from packed rows in ordinary order.
			#[inline]
			pub fn transpose_from(
				underliers: &[PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					Self::HEIGHT_BYTES],
			) -> Self {
				let mut underliers = WithUnderlier::to_underliers_arr(*underliers);

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<
					$storage_tower_level,
				>(&mut underliers);

				Self {
					data: WithUnderlier::from_underliers_arr(underliers),
				}
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x1b([{}])", Self::WIDTH, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = BinaryField1b;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			// Bit `i` lives in byte-row `(i / 8) % HEIGHT_BYTES`, at bit position
			// `i % 8` of byte column `i / (HEIGHT_BYTES * 8)` within that row.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				self.data
					.get_unchecked((i / 8) % Self::HEIGHT_BYTES)
					.get_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8)
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				self.data
					.get_unchecked_mut((i / 8) % Self::HEIGHT_BYTES)
					.set_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8, scalar);
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| <$packed_storage>::random(&mut rng));
				Self { data }
			}

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				// A 1-bit scalar broadcasts to an all-zeros or all-ones underlier.
				let underlier = <$packed_storage as WithUnderlier>::Underlier::fill_with_bit(
					scalar.to_underlier().into(),
				);
				Self {
					data: array::from_fn(|_| WithUnderlier::from_underlier(underlier)),
				}
			}

			#[inline]
			fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
				// Inverse of the index mapping used by `get_unchecked`.
				let data = array::from_fn(|row| {
					<$packed_storage>::from_fn(|col| {
						f(row * 8 + 8 * Self::HEIGHT_BYTES * (col / 8) + col % 8)
					})
				});

				Self { data }
			}

			#[inline]
			fn square(self) -> Self {
				let data = array::from_fn(|i| self.data[i].clone().square());
				Self { data }
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let data = array::from_fn(|i| self.data[i].clone().invert_or_zero());
				Self { data }
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				if log_block_len < 3 {
					// Blocks smaller than a byte: interleave within each row.
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].interleave(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					// Byte-sized blocks or larger: reinterpret rows as 8-bit
					// elements and delegate, converting bit blocks to byte blocks.
					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) =
						interleave_byte_sliced(self_data, other_data, log_block_len - 3);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				if log_block_len < 3 {
					// Blocks smaller than a byte: unzip within each row.
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].unzip(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					type Packed8b =
						PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) = unzip_byte_sliced::<Packed8b, { Self::HEIGHT_BYTES }, 1>(
						self_data,
						other_data,
						log_block_len - 3,
					);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|i| self.data[i].clone() * rhs.data[i].clone()),
				}
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						self.data[byte_number] + rhs.data[byte_number]
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					*data += *rhs
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, BinaryField1b);

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> =
				IDTransformation;

			// The supplied transformation is ignored and the identity is returned.
			// NOTE(review): presumably the only 1-bit linear transformation of
			// interest here is the identity — confirm against callers.
			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				_transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				IDTransformation
			}
		}
	};
}
776
777#[inline(always)]
778fn interleave_internal_block<P: PackedField, const N: usize, const LOG_BLOCK_LEN: usize>(
779 lhs: &[P; N],
780 rhs: &[P; N],
781) -> ([P; N], [P; N]) {
782 debug_assert!(LOG_BLOCK_LEN < checked_log_2(N));
783
784 let result_1 = array::from_fn(|i| {
785 let block_index = i >> (LOG_BLOCK_LEN + 1);
786 let block_start = block_index << (LOG_BLOCK_LEN + 1);
787 let block_offset = i - block_start;
788
789 if block_offset < (1 << LOG_BLOCK_LEN) {
790 lhs[i]
791 } else {
792 rhs[i - (1 << LOG_BLOCK_LEN)]
793 }
794 });
795 let result_2 = array::from_fn(|i| {
796 let block_index = i >> (LOG_BLOCK_LEN + 1);
797 let block_start = block_index << (LOG_BLOCK_LEN + 1);
798 let block_offset = i - block_start;
799
800 if block_offset < (1 << LOG_BLOCK_LEN) {
801 lhs[i + (1 << LOG_BLOCK_LEN)]
802 } else {
803 rhs[i]
804 }
805 });
806
807 (result_1, result_2)
808}
809
810#[inline(always)]
811fn interleave_byte_sliced<P: PackedField, const N: usize>(
812 lhs: &[P; N],
813 rhs: &[P; N],
814 log_block_len: usize,
815) -> ([P; N], [P; N]) {
816 debug_assert!(checked_log_2(N) <= 4);
817
818 match log_block_len {
819 x if x >= checked_log_2(N) => interleave_big_block::<P, N>(lhs, rhs, log_block_len),
820 0 => interleave_internal_block::<P, N, 0>(lhs, rhs),
821 1 => interleave_internal_block::<P, N, 1>(lhs, rhs),
822 2 => interleave_internal_block::<P, N, 2>(lhs, rhs),
823 3 => interleave_internal_block::<P, N, 3>(lhs, rhs),
824 _ => unreachable!(),
825 }
826}
827
828#[inline(always)]
829fn unzip_byte_sliced<P: PackedField, const N: usize, const SCALAR_BYTES: usize>(
830 lhs: &[P; N],
831 rhs: &[P; N],
832 log_block_len: usize,
833) -> ([P; N], [P; N]) {
834 let mut result1: [P; N] = bytemuck::Zeroable::zeroed();
835 let mut result2: [P; N] = bytemuck::Zeroable::zeroed();
836
837 let log_height = checked_log_2(N);
838 if log_block_len < log_height {
839 let block_size = 1 << log_block_len;
840 let half = N / 2;
841 for block_offset in (0..half).step_by(block_size) {
842 let target_offset = block_offset * 2;
843
844 result1[block_offset..block_offset + block_size]
845 .copy_from_slice(&lhs[target_offset..target_offset + block_size]);
846 result1[half + target_offset..half + target_offset + block_size]
847 .copy_from_slice(&rhs[target_offset..target_offset + block_size]);
848
849 result2[block_offset..block_offset + block_size]
850 .copy_from_slice(&lhs[target_offset + block_size..target_offset + 2 * block_size]);
851 result2[half + target_offset..half + target_offset + block_size]
852 .copy_from_slice(&rhs[target_offset + block_size..target_offset + 2 * block_size]);
853 }
854 } else {
855 for i in 0..N {
856 (result1[i], result2[i]) = lhs[i].unzip(rhs[i], log_block_len - log_height);
857 }
858 }
859
860 (result1, result2)
861}
862
863#[inline(always)]
864fn interleave_big_block<P: PackedField, const N: usize>(
865 lhs: &[P; N],
866 rhs: &[P; N],
867 log_block_len: usize,
868) -> ([P; N], [P; N]) {
869 let mut result_1 = <[P; N]>::zeroed();
870 let mut result_2 = <[P; N]>::zeroed();
871
872 for i in 0..N {
873 (result_1[i], result_2[i]) = lhs[i].interleave(rhs[i], log_block_len - checked_log_2(N));
874 }
875
876 (result_1, result_2)
877}
878
// Byte-sliced types with 128-bit-wide rows (`PackedAESBinaryField16x8b` /
// `PackedBinaryField128x1b` storage). The last macro argument — the storage
// tower level — sets the byte height of each type.
define_byte_sliced_3d!(
	ByteSlicedAES16x128b,
	AESTowerField128b,
	PackedAESBinaryField16x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel16
);

// 1-bit scalar variants over `PackedBinaryField128x1b`.
define_byte_sliced_3d_1b!(ByteSliced16x128x1b, PackedBinaryField128x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x128x1b, PackedBinaryField128x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x128x1b, PackedBinaryField128x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x128x1b, PackedBinaryField128x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x128x1b, PackedBinaryField128x1b, TowerLevel1);
949
// Byte-sliced types with 256-bit-wide rows (`PackedAESBinaryField32x8b` /
// `PackedBinaryField256x1b` storage).
define_byte_sliced_3d!(
	ByteSlicedAES32x128b,
	AESTowerField128b,
	PackedAESBinaryField32x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel16
);

// 1-bit scalar variants over `PackedBinaryField256x1b`.
define_byte_sliced_3d_1b!(ByteSliced16x256x1b, PackedBinaryField256x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x256x1b, PackedBinaryField256x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x256x1b, PackedBinaryField256x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x256x1b, PackedBinaryField256x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x256x1b, PackedBinaryField256x1b, TowerLevel1);
1020
// Byte-sliced types with 512-bit-wide rows (`PackedAESBinaryField64x8b` /
// `PackedBinaryField512x1b` storage).
define_byte_sliced_3d!(
	ByteSlicedAES64x128b,
	AESTowerField128b,
	PackedAESBinaryField64x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel16
);

// 1-bit scalar variants over `PackedBinaryField512x1b`.
define_byte_sliced_3d_1b!(ByteSliced16x512x1b, PackedBinaryField512x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x512x1b, PackedBinaryField512x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x512x1b, PackedBinaryField512x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x512x1b, PackedBinaryField512x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x512x1b, PackedBinaryField512x1b, TowerLevel1);
1091
/// Implements `PackedExtension` between byte-sliced types of the same total
/// width via `bytemuck` reinterpretation (the byte-sliced layouts coincide).
///
/// Invoked with a list of types ordered from largest to smallest scalar, the
/// recursive rules generate an impl for every ordered pair in the list:
/// the `@pairs` rules pair the head with each later element, and the final
/// rule recurses with the next element as the new head.
macro_rules! impl_packed_extension{
	($packed_ext:ty, $packed_base:ty,) => {
		impl PackedExtension<<$packed_base as PackedField>::Scalar> for $packed_ext {
			type PackedSubfield = $packed_base;

			fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield] {
				bytemuck::must_cast_slice(packed)
			}

			fn cast_bases_mut(packed: &mut [Self]) -> &mut [Self::PackedSubfield] {
				bytemuck::must_cast_slice_mut(packed)
			}

			fn cast_exts(packed: &[Self::PackedSubfield]) -> &[Self] {
				bytemuck::must_cast_slice(packed)
			}

			fn cast_exts_mut(packed: &mut [Self::PackedSubfield]) -> &mut [Self] {
				bytemuck::must_cast_slice_mut(packed)
			}

			fn cast_base(self) -> Self::PackedSubfield {
				bytemuck::must_cast(self)
			}

			fn cast_base_ref(&self) -> &Self::PackedSubfield {
				bytemuck::must_cast_ref(self)
			}

			fn cast_base_mut(&mut self) -> &mut Self::PackedSubfield {
				bytemuck::must_cast_mut(self)
			}

			fn cast_ext(base: Self::PackedSubfield) -> Self {
				bytemuck::must_cast(base)
			}

			fn cast_ext_ref(base: &Self::PackedSubfield) -> &Self {
				bytemuck::must_cast_ref(base)
			}

			fn cast_ext_mut(base: &mut Self::PackedSubfield) -> &mut Self {
				bytemuck::must_cast_mut(base)
			}
		}
	};
	// Base case of the head-pairing recursion: one remaining partner.
	(@pairs $head:ty, $next:ty,) => {
		impl_packed_extension!($head, $next,);
	};
	// Pair `$head` with `$next`, then with each element of the tail.
	(@pairs $head:ty, $next:ty, $($tail:ty,)*) => {
		impl_packed_extension!($head, $next,);
		impl_packed_extension!(@pairs $head, $($tail,)*);
	};
	// Pair the head with everything after it, then recurse on the tail.
	($head:ty, $next:ty, $($tail:ty,)*) => {
		impl_packed_extension!(@pairs $head, $next, $($tail,)*);
		impl_packed_extension!($next, $($tail,)*);
	};
}
1150
// Generate `PackedExtension` impls between every ordered pair of types in each
// same-width tower (128b over 64b over 32b over 16b over 8b over 1b).
impl_packed_extension!(
	ByteSlicedAES16x128b,
	ByteSlicedAES2x16x64b,
	ByteSlicedAES4x16x32b,
	ByteSlicedAES8x16x16b,
	ByteSlicedAES16x16x8b,
	ByteSliced16x128x1b,
);

impl_packed_extension!(
	ByteSlicedAES32x128b,
	ByteSlicedAES2x32x64b,
	ByteSlicedAES4x32x32b,
	ByteSlicedAES8x32x16b,
	ByteSlicedAES16x32x8b,
	ByteSliced16x256x1b,
);

impl_packed_extension!(
	ByteSlicedAES64x128b,
	ByteSlicedAES2x64x64b,
	ByteSlicedAES4x64x32b,
	ByteSlicedAES8x64x16b,
	ByteSlicedAES16x64x8b,
	ByteSliced16x512x1b,
);