1use std::{
4 array,
5 fmt::Debug,
6 iter::{Product, Sum, zip},
7 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_utils::checked_arithmetics::checked_log_2;
11use bytemuck::{Pod, Zeroable};
12
13use super::{invert::invert_or_zero, multiply::mul, square::square};
14use crate::{
15 AESTowerField8b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField128b,
16 BinaryField1b, ExtensionField, PackedAESBinaryField16x8b, PackedAESBinaryField64x8b,
17 PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedExtension,
18 PackedField,
19 arch::{
20 byte_sliced::underlier::ByteSlicedUnderlier, portable::packed_scaled::ScaledPackedField,
21 },
22 as_packed_field::{PackScalar, PackedType},
23 binary_field::BinaryField,
24 linear_transformation::{
25 FieldLinearTransformation, IDTransformation, PackedTransformationFactory, Transformation,
26 },
27 packed_aes_field::PackedAESBinaryField32x8b,
28 tower_levels::{TowerLevel, TowerLevel1, TowerLevel2, TowerLevel4, TowerLevel8, TowerLevel16},
29 underlier::{UnderlierWithBitOps, WithUnderlier},
30};
31
/// An `N`x`N` matrix of inner transformations.
///
/// Used by the byte-sliced types as a block linear transformation over the scalar's
/// 8-bit components: entry `[col][row]` maps byte `col` of the input to byte `row`
/// of the output (see the `Transformation` impl generated by `define_byte_sliced_3d!`).
pub struct TransformationWrapperNxN<Inner, const N: usize>([[Inner; N]; N]);
36
/// Defines a "3D" byte-sliced packed field type.
///
/// Layout: `data[group][byte][column]`. Scalars are split into their 8-bit components and
/// each `$packed_storage` register holds a single byte position of many scalars, so the
/// byte-sliced tower arithmetic helpers (`mul`, `square`, `invert_or_zero`) can operate on
/// whole byte planes at once.
macro_rules! define_byte_sliced_3d {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $scalar_tower_level: ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			// `data[group]` is one group of `SCALAR_BYTES` byte planes; `data[group][b]`
			// holds byte `b` of every scalar in that group.
			pub(super) data: [[$packed_storage; <$scalar_tower_level as TowerLevel>::WIDTH]; <$storage_tower_level as TowerLevel>::WIDTH / <$scalar_tower_level as TowerLevel>::WIDTH],
		}

		impl $name {
			/// Total size of this packed value in bytes.
			pub const BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			/// Size of a single scalar in bytes.
			const SCALAR_BYTES: usize = <$scalar_type>::N_BITS / 8;
			/// Number of byte planes (rows of 8-bit packed registers).
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			/// Number of scalar groups (byte planes divided by bytes per scalar).
			const HEIGHT: usize = Self::HEIGHT_BYTES / Self::SCALAR_BYTES;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT);

			/// Returns the byte at `byte_index` of the byte-sliced value.
			///
			/// # Safety
			/// `byte_index` must be less than `Self::BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				// Bytes are addressed column-major: the remainder selects the byte plane,
				// the quotient selects the column within that plane.
				let row = byte_index % Self::HEIGHT_BYTES;
				unsafe {
					self.data
						.get_unchecked(row / Self::SCALAR_BYTES)
						.get_unchecked(row % Self::SCALAR_BYTES)
						.get_unchecked(byte_index / Self::HEIGHT_BYTES)
						.to_underlier()
				}
			}

			/// Transposes this value into an array of ordinary (sequential-layout) packed
			/// field elements.
			#[inline]
			pub fn transpose_to(&self, out: &mut [<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES]) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = bytemuck::must_cast(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);
			}

			/// Builds a byte-sliced value from ordinary (sequential-layout) packed field
			/// elements by transposing their bytes.
			#[inline]
			pub fn transpose_from(
				underliers: &[<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES],
			) -> Self {
				let mut result = Self {
					data: bytemuck::must_cast(*underliers),
				};

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<$storage_tower_level>(bytemuck::must_cast_mut(&mut result.data));

				result
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x{}([{}])", Self::WIDTH, <$scalar_type>::N_BITS, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = $scalar_type;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				// Scalar `i` lives in group `i % HEIGHT` at column `i / HEIGHT`; it is
				// reassembled from its `SCALAR_BYTES` byte planes.
				let element_rows = unsafe { self.data.get_unchecked(i % Self::HEIGHT) };
				Self::Scalar::from_bases((0..Self::SCALAR_BYTES).map(|byte_index| unsafe {
					element_rows
						.get_unchecked(byte_index)
						.get_unchecked(i / Self::HEIGHT)
				}))
				.expect("byte index is within bounds")
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				// Scatter the scalar's bytes across the byte planes of its group.
				let element_rows = unsafe { self.data.get_unchecked_mut(i % Self::HEIGHT) };
				for byte_index in 0..Self::SCALAR_BYTES {
					unsafe {
						element_rows
							.get_unchecked_mut(byte_index)
							.set_unchecked(
								i / Self::HEIGHT,
								scalar.get_base_unchecked(byte_index),
							);
					}
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| array::from_fn(|_| <$packed_storage>::random(&mut rng)));
				Self { data }
			}

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let data: [[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT] = match Self::SCALAR_BYTES {
					// Single-byte scalar: one broadcast fills every plane.
					1 => {
						let packed_broadcast =
							<$packed_storage>::broadcast(unsafe { scalar.get_base_unchecked(0) });
						array::from_fn(|_| array::from_fn(|_| packed_broadcast))
					}
					// Scalar spans the full height: each plane gets its own byte.
					Self::HEIGHT_BYTES => array::from_fn(|_| array::from_fn(|byte_index| {
						<$packed_storage>::broadcast(unsafe {
							scalar.get_base_unchecked(byte_index)
						})
					})),
					// General case: broadcast each byte once and copy it to every group.
					_ => {
						let mut data = <[[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT]>::zeroed();
						for byte_index in 0..Self::SCALAR_BYTES {
							let broadcast = <$packed_storage>::broadcast(unsafe {
								scalar.get_base_unchecked(byte_index)
							});

							for i in 0..Self::HEIGHT {
								data[i][byte_index] = broadcast;
							}
						}

						data
					}
				};

				Self { data }
			}

			impl_init_with_transpose!($packed_storage, $scalar_type, $storage_tower_level);

			// Iteration transposes to the sequential layout first, then yields the
			// scalars in order.
			#[inline]
			fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Clone + Send + '_ {
				type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

				let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
				let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(&mut underliers);

				let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
				data.into_iter()
			}

			#[inline]
			fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Clone + Send {
				type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

				let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
				let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(&mut underliers);

				let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
				data.into_iter()
			}

			#[inline]
			fn square(self) -> Self {
				let mut result = Self::default();

				// Square each scalar group plane-wise with the byte-sliced tower helper.
				for i in 0..Self::HEIGHT {
					square::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					invert_or_zero::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				// Interleaving whole scalars equals interleaving byte rows with the block
				// length scaled up by the scalar's byte width.
				let self_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&self.data);
				let other_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&other.data);

				let (data_1, data_2) = interleave_byte_sliced(self_data, other_data, log_block_len + checked_log_2(Self::SCALAR_BYTES));
				(
					Self {
						data: bytemuck::must_cast(data_1),
					},
					Self {
						data: bytemuck::must_cast(data_2),
					},
				)
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				let (result1, result2) = unzip_byte_sliced::<$packed_storage, {Self::HEIGHT_BYTES}, {Self::SCALAR_BYTES}>(bytemuck::must_cast_ref(&self.data), bytemuck::must_cast_ref(&other.data), log_block_len + checked_log_2(Self::SCALAR_BYTES));

				(
					Self {
						data: bytemuck::must_cast(result1),
					},
					Self {
						data: bytemuck::must_cast(result2),
					},
				)
			}

			impl_spread!($name, $scalar_type, $packed_storage, $storage_tower_level);
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				let mut result = Self::default();

				// Multiply group by group with the byte-sliced tower multiplier.
				for i in 0..Self::HEIGHT {
					mul::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&rhs.data[i],
						&mut result.data[i],
					);
				}

				result
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				// Addition is plane-wise addition of the underlying 8-bit registers.
				Self {
					data: array::from_fn(|byte_number| {
						array::from_fn(|column|
							self.data[byte_number][column] + rhs.data[byte_number][column]
						)
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					for (data, rhs) in zip(data, rhs) {
						*data += *rhs
					}
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, $scalar_type, $storage_tower_level);

		// Applies an NxN matrix of 8-bit transformations as a block linear map over the
		// scalar's byte components: `result[row] += T[col][row](data[col])` for every group.
		impl<Inner: Transformation<$packed_storage, $packed_storage>> Transformation<$name, $name> for TransformationWrapperNxN<Inner, {<$scalar_tower_level as TowerLevel>::WIDTH}> {
			fn transform(&self, data: &$name) -> $name {
				let mut result = <$name>::default();

				for row in 0..<$name>::SCALAR_BYTES {
					for col in 0..<$name>::SCALAR_BYTES {
						let transformation = &self.0[col][row];

						for i in 0..<$name>::HEIGHT {
							result.data[i][row] += transformation.transform(&data.data[i][col]);
						}
					}
				}

				result
			}
		}

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> = TransformationWrapperNxN<<$packed_storage as PackedTransformationFactory<$packed_storage>>::PackedTransformation::<[AESTowerField8b; 8]>, {<$scalar_tower_level as TowerLevel>::WIDTH}>;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				// Decompose the scalar-level linear transformation into an NxN matrix of
				// 8b-to-8b transformations, one per (row, col) byte pair.
				let transformations_8b = array::from_fn(|row| {
					array::from_fn(|col| {
						let row = row * 8;
						let linear_transformation_8b = array::from_fn::<_, 8, _>(|row_8b| unsafe {
							<<$name as PackedField>::Scalar as ExtensionField<AESTowerField8b>>::get_base_unchecked(&transformation.bases()[row + row_8b], col)
						});

						<$packed_storage as PackedTransformationFactory<$packed_storage
						>>::make_packed_transformation(FieldLinearTransformation::new(linear_transformation_8b))
					})
				});

				TransformationWrapperNxN(transformations_8b)
			}
		}
	};
}
376
/// Implements `PackedField::from_fn` and `PackedField::from_scalars` for byte-sliced types
/// by first filling ordinary (sequential-layout) packed elements and then transposing
/// their bytes into the byte-sliced layout.
macro_rules! impl_init_with_transpose {
	($packed_storage:ty, $scalar_type:ty, $storage_tower_level:ty) => {
		#[inline]
		fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
			type TransposePacked =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

			// Fill sequential packed elements from `f`, then transpose bytes in place.
			let mut data: [TransposePacked; Self::HEIGHT_BYTES] = array::from_fn(|i| {
				PackedField::from_fn(|j| f((i << TransposePacked::LOG_WIDTH) + j))
			});
			let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
				&mut underliers,
			);

			Self {
				data: bytemuck::must_cast(data),
			}
		}

		#[inline]
		fn from_scalars(scalars: impl IntoIterator<Item = Self::Scalar>) -> Self {
			type TransposePacked =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

			// Consume the scalar stream into sequential packed elements, then transpose.
			let mut data = [TransposePacked::default(); Self::HEIGHT_BYTES];
			let mut iter = scalars.into_iter();
			for v in data.iter_mut() {
				*v = TransposePacked::from_scalars(&mut iter);
			}

			let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
				&mut underliers,
			);

			Self {
				data: bytemuck::must_cast(data),
			}
		}
	};
}
421
/// Implements `PackedField::spread_unchecked` by round-tripping through the sequential
/// layout: transpose out of byte-sliced form, spread via `ScaledPackedField`, transpose back.
macro_rules! impl_spread {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $storage_tower_level:ty) => {
		#[inline]
		unsafe fn spread_unchecked(mut self, log_block_len: usize, block_idx: usize) -> Self {
			// Convert `self` to sequential byte order in place.
			{
				let data: &mut [$packed_storage; Self::HEIGHT_BYTES] =
					bytemuck::must_cast_mut(&mut self.data);
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			type PackedSequentialField =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;
			let scaled_data: ScaledPackedField<PackedSequentialField, { Self::HEIGHT_BYTES }> =
				bytemuck::must_cast(self);

			// SAFETY: the caller's contract on `spread_unchecked` is forwarded unchanged.
			let mut result = unsafe { scaled_data.spread_unchecked(log_block_len, block_idx) };
			// Transpose the spread result back into the byte-sliced layout.
			let data: &mut [$packed_storage; Self::HEIGHT_BYTES] =
				bytemuck::must_cast_mut(&mut result);
			let underliers = WithUnderlier::to_underliers_arr_ref_mut(data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(underliers);
			bytemuck::must_cast(result)
		}
	};
}
449
/// Implements the scalar-operand arithmetic operators, `Sum`/`Product` folds, and the
/// `WithUnderlier`/`PackScalar` plumbing shared by all byte-sliced types.
macro_rules! byte_sliced_common {
	($name:ident, $packed_storage:ty, $scalar_type:ty, $storage_level:ty) => {
		impl Add<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: $scalar_type) -> $name {
				self + Self::broadcast(rhs)
			}
		}

		impl AddAssign<$scalar_type> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: $scalar_type) {
				*self += Self::broadcast(rhs)
			}
		}

		// Subtraction is implemented as addition (characteristic-2 fields).
		impl Sub<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: $scalar_type) -> $name {
				self.add(rhs)
			}
		}

		impl SubAssign<$scalar_type> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: $scalar_type) {
				self.add_assign(rhs)
			}
		}

		impl Mul<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: $scalar_type) -> $name {
				self * Self::broadcast(rhs)
			}
		}

		impl MulAssign<$scalar_type> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: $scalar_type) {
				*self *= Self::broadcast(rhs);
			}
		}

		impl Sub for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: Self) -> Self {
				self.add(rhs)
			}
		}

		impl SubAssign for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: Self) {
				self.add_assign(rhs);
			}
		}

		impl MulAssign for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: Self) {
				*self = *self * rhs;
			}
		}

		impl Product for $name {
			fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::one();

				// Seed the accumulator with the first item rather than multiplying it
				// into `one()`, saving one full packed multiplication.
				let mut is_first_item = true;
				for item in iter {
					if is_first_item {
						result = item;
					} else {
						result *= item;
					}

					is_first_item = false;
				}

				result
			}
		}

		impl Sum for $name {
			fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::zero();

				for item in iter {
					result += item;
				}

				result
			}
		}

		// All conversions below are size-checked `bytemuck` casts between `Pod` types;
		// the struct is `#[repr(transparent)]` over its storage array.
		unsafe impl WithUnderlier for $name {
			type Underlier = ByteSlicedUnderlier<
				<$packed_storage as WithUnderlier>::Underlier,
				{ <$storage_level as TowerLevel>::WIDTH },
			>;

			#[inline(always)]
			fn to_underlier(self) -> Self::Underlier {
				bytemuck::must_cast(self)
			}

			#[inline(always)]
			fn to_underlier_ref(&self) -> &Self::Underlier {
				bytemuck::must_cast_ref(self)
			}

			#[inline(always)]
			fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
				bytemuck::must_cast_mut(self)
			}

			#[inline(always)]
			fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
				bytemuck::must_cast_slice(val)
			}

			#[inline(always)]
			fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
				bytemuck::must_cast_slice_mut(val)
			}

			#[inline(always)]
			fn from_underlier(val: Self::Underlier) -> Self {
				bytemuck::must_cast(val)
			}

			#[inline(always)]
			fn from_underlier_ref(val: &Self::Underlier) -> &Self {
				bytemuck::must_cast_ref(val)
			}

			#[inline(always)]
			fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
				bytemuck::must_cast_mut(val)
			}

			#[inline(always)]
			fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
				bytemuck::must_cast_slice(val)
			}

			#[inline(always)]
			fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
				bytemuck::must_cast_slice_mut(val)
			}
		}

		impl PackScalar<$scalar_type>
			for ByteSlicedUnderlier<
				<$packed_storage as WithUnderlier>::Underlier,
				{ <$storage_level as TowerLevel>::WIDTH },
			>
		{
			type Packed = $name;
		}
	};
}
621
/// Defines a byte-sliced packed field over 1-bit scalars (`BinaryField1b`).
///
/// Rows hold 1-bit packed registers; bytes are sliced across rows exactly like the 8-bit
/// variants, so field arithmetic reduces to plain row-wise operations on the storage.
macro_rules! define_byte_sliced_3d_1b {
	($name:ident, $packed_storage:ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			pub(super) data: [$packed_storage; <$storage_tower_level>::WIDTH],
		}

		impl $name {
			/// Total size of this packed value in bytes.
			pub const BYTES: usize =
				<$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			/// Number of byte rows.
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT_BYTES);

			/// Returns the byte at `byte_index` of the byte-sliced value.
			///
			/// # Safety
			/// `byte_index` must be less than `Self::BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				// View the 1-bit row as 8-bit lanes to extract a whole byte at once.
				unsafe {
					Packed8b::cast_ext_ref(self.data.get_unchecked(byte_index % Self::HEIGHT_BYTES))
						.get_unchecked(byte_index / Self::HEIGHT_BYTES)
						.to_underlier()
				}
			}

			/// Transposes this value into an array of ordinary (sequential-layout) 1-bit
			/// packed elements.
			#[inline]
			pub fn transpose_to(
				&self,
				out: &mut [PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					Self::HEIGHT_BYTES],
			) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = WithUnderlier::to_underliers_arr(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			/// Builds a byte-sliced value from ordinary (sequential-layout) 1-bit packed
			/// elements by transposing their bytes.
			#[inline]
			pub fn transpose_from(
				underliers: &[PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					Self::HEIGHT_BYTES],
			) -> Self {
				let mut underliers = WithUnderlier::to_underliers_arr(*underliers);

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<
					$storage_tower_level,
				>(&mut underliers);

				Self {
					data: WithUnderlier::from_underliers_arr(underliers),
				}
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x1b([{}])", Self::WIDTH, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = BinaryField1b;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				// Bit `i` lives in byte `i / 8`; that byte's row is `(i / 8) % HEIGHT_BYTES`
				// and within the row it sits at byte column `i / (HEIGHT_BYTES * 8)`,
				// bit offset `i % 8`.
				unsafe {
					self.data
						.get_unchecked((i / 8) % Self::HEIGHT_BYTES)
						.get_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8)
				}
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				// Mirrors the index math in `get_unchecked`.
				unsafe {
					self.data
						.get_unchecked_mut((i / 8) % Self::HEIGHT_BYTES)
						.set_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8, scalar);
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| <$packed_storage>::random(&mut rng));
				Self { data }
			}

			impl_init_with_transpose!($packed_storage, BinaryField1b, $storage_tower_level);

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				// A broadcast 1-bit scalar is every bit set or every bit clear.
				let underlier = <$packed_storage as WithUnderlier>::Underlier::fill_with_bit(
					scalar.to_underlier().into(),
				);
				Self {
					data: array::from_fn(|_| WithUnderlier::from_underlier(underlier)),
				}
			}

			#[inline]
			fn square(self) -> Self {
				// Delegated row-wise to the 1-bit packed storage.
				let data = array::from_fn(|i| self.data[i].clone().square());
				Self { data }
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let data = array::from_fn(|i| self.data[i].clone().invert_or_zero());
				Self { data }
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				// Sub-byte blocks interleave within each row; byte-and-larger blocks
				// (log_block_len >= 3) are handled at byte-row granularity.
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].interleave(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) =
						interleave_byte_sliced(self_data, other_data, log_block_len - 3);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				// Same split as `interleave`: sub-byte blocks per row, byte-and-larger
				// blocks at byte-row granularity.
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].unzip(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					type Packed8b =
						PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) = unzip_byte_sliced::<Packed8b, { Self::HEIGHT_BYTES }, 1>(
						self_data,
						other_data,
						log_block_len - 3,
					);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			impl_spread!($name, BinaryField1b, $packed_storage, $storage_tower_level);
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				// 1-bit multiplication is row-wise multiplication of the packed storage.
				Self {
					data: array::from_fn(|i| self.data[i].clone() * rhs.data[i].clone()),
				}
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						self.data[byte_number] + rhs.data[byte_number]
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					*data += *rhs
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, BinaryField1b, $storage_tower_level);

		impl PackedTransformationFactory<$name> for $name {
			// The packed transformation for the 1-bit case is implemented as the identity;
			// the supplied scalar transformation is ignored.
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> =
				IDTransformation;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				_transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				IDTransformation
			}
		}
	};
}
892
893#[inline(always)]
894fn interleave_internal_block<P: PackedField, const N: usize, const LOG_BLOCK_LEN: usize>(
895 lhs: &[P; N],
896 rhs: &[P; N],
897) -> ([P; N], [P; N]) {
898 debug_assert!(LOG_BLOCK_LEN < checked_log_2(N));
899
900 let result_1 = array::from_fn(|i| {
901 let block_index = i >> (LOG_BLOCK_LEN + 1);
902 let block_start = block_index << (LOG_BLOCK_LEN + 1);
903 let block_offset = i - block_start;
904
905 if block_offset < (1 << LOG_BLOCK_LEN) {
906 lhs[i]
907 } else {
908 rhs[i - (1 << LOG_BLOCK_LEN)]
909 }
910 });
911 let result_2 = array::from_fn(|i| {
912 let block_index = i >> (LOG_BLOCK_LEN + 1);
913 let block_start = block_index << (LOG_BLOCK_LEN + 1);
914 let block_offset = i - block_start;
915
916 if block_offset < (1 << LOG_BLOCK_LEN) {
917 lhs[i + (1 << LOG_BLOCK_LEN)]
918 } else {
919 rhs[i]
920 }
921 });
922
923 (result_1, result_2)
924}
925
926#[inline(always)]
927fn interleave_byte_sliced<P: PackedField, const N: usize>(
928 lhs: &[P; N],
929 rhs: &[P; N],
930 log_block_len: usize,
931) -> ([P; N], [P; N]) {
932 debug_assert!(checked_log_2(N) <= 4);
933
934 match log_block_len {
935 x if x >= checked_log_2(N) => interleave_big_block::<P, N>(lhs, rhs, log_block_len),
936 0 => interleave_internal_block::<P, N, 0>(lhs, rhs),
937 1 => interleave_internal_block::<P, N, 1>(lhs, rhs),
938 2 => interleave_internal_block::<P, N, 2>(lhs, rhs),
939 3 => interleave_internal_block::<P, N, 3>(lhs, rhs),
940 _ => unreachable!(),
941 }
942}
943
944#[inline(always)]
945fn unzip_byte_sliced<P: PackedField, const N: usize, const SCALAR_BYTES: usize>(
946 lhs: &[P; N],
947 rhs: &[P; N],
948 log_block_len: usize,
949) -> ([P; N], [P; N]) {
950 let mut result1: [P; N] = bytemuck::Zeroable::zeroed();
951 let mut result2: [P; N] = bytemuck::Zeroable::zeroed();
952
953 let log_height = checked_log_2(N);
954 if log_block_len < log_height {
955 let block_size = 1 << log_block_len;
956 let half = N / 2;
957 for block_offset in (0..half).step_by(block_size) {
958 let target_offset = block_offset * 2;
959
960 result1[block_offset..block_offset + block_size]
961 .copy_from_slice(&lhs[target_offset..target_offset + block_size]);
962 result1[half + target_offset..half + target_offset + block_size]
963 .copy_from_slice(&rhs[target_offset..target_offset + block_size]);
964
965 result2[block_offset..block_offset + block_size]
966 .copy_from_slice(&lhs[target_offset + block_size..target_offset + 2 * block_size]);
967 result2[half + target_offset..half + target_offset + block_size]
968 .copy_from_slice(&rhs[target_offset + block_size..target_offset + 2 * block_size]);
969 }
970 } else {
971 for i in 0..N {
972 (result1[i], result2[i]) = lhs[i].unzip(rhs[i], log_block_len - log_height);
973 }
974 }
975
976 (result1, result2)
977}
978
979#[inline(always)]
980fn interleave_big_block<P: PackedField, const N: usize>(
981 lhs: &[P; N],
982 rhs: &[P; N],
983 log_block_len: usize,
984) -> ([P; N], [P; N]) {
985 let mut result_1 = <[P; N]>::zeroed();
986 let mut result_2 = <[P; N]>::zeroed();
987
988 for i in 0..N {
989 (result_1[i], result_2[i]) = lhs[i].interleave(rhs[i], log_block_len - checked_log_2(N));
990 }
991
992 (result_1, result_2)
993}
994
// 128-bit storage: rows are `PackedAESBinaryField16x8b` / `PackedBinaryField128x1b`.
define_byte_sliced_3d!(
	ByteSlicedAES16x128b,
	AESTowerField128b,
	PackedAESBinaryField16x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel16
);

// 1-bit scalar variants over 128-bit rows.
define_byte_sliced_3d_1b!(ByteSliced16x128x1b, PackedBinaryField128x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x128x1b, PackedBinaryField128x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x128x1b, PackedBinaryField128x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x128x1b, PackedBinaryField128x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x128x1b, PackedBinaryField128x1b, TowerLevel1);
1065
// 256-bit storage: rows are `PackedAESBinaryField32x8b` / `PackedBinaryField256x1b`.
define_byte_sliced_3d!(
	ByteSlicedAES32x128b,
	AESTowerField128b,
	PackedAESBinaryField32x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel16
);

// 1-bit scalar variants over 256-bit rows.
define_byte_sliced_3d_1b!(ByteSliced16x256x1b, PackedBinaryField256x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x256x1b, PackedBinaryField256x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x256x1b, PackedBinaryField256x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x256x1b, PackedBinaryField256x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x256x1b, PackedBinaryField256x1b, TowerLevel1);
1136
// 512-bit storage: rows are `PackedAESBinaryField64x8b` / `PackedBinaryField512x1b`.
define_byte_sliced_3d!(
	ByteSlicedAES64x128b,
	AESTowerField128b,
	PackedAESBinaryField64x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel16
);

// 1-bit scalar variants over 512-bit rows.
define_byte_sliced_3d_1b!(ByteSliced16x512x1b, PackedBinaryField512x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x512x1b, PackedBinaryField512x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x512x1b, PackedBinaryField512x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x512x1b, PackedBinaryField512x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x512x1b, PackedBinaryField512x1b, TowerLevel1);