use std::{
    array,
    fmt::Debug,
    iter::{zip, Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{Pod, Zeroable};

use super::{invert::invert_or_zero, multiply::mul, square::square};
use crate::{
    arch::byte_sliced::underlier::ByteSlicedUnderlier,
    as_packed_field::{PackScalar, PackedType},
    binary_field::BinaryField,
    linear_transformation::{
        FieldLinearTransformation, IDTransformation, PackedTransformationFactory, Transformation,
    },
    packed_aes_field::PackedAESBinaryField32x8b,
    tower_levels::{TowerLevel, TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8},
    underlier::{UnderlierWithBitOps, WithUnderlier},
    AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
    BinaryField1b, ExtensionField, PackedAESBinaryField16x8b, PackedAESBinaryField64x8b,
    PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedExtension,
    PackedField,
};

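/// An N×N matrix of packed linear transformations. Entry `(col, row)` maps the
/// `col`-th input byte row to its contribution on the `row`-th output byte row
/// (see the `Transformation` impl generated by `define_byte_sliced_3d!`).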
pub struct TransformationWrapperNxN<Inner, const N: usize>([[Inner; N]; N]);

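/// Defines a byte-sliced packed field type.
///
/// The scalars are stored as a 3D array `data[row][byte][column]`:
/// - `row` selects a group of `HEIGHT` consecutive-index scalars,
/// - `byte` selects one of the `SCALAR_BYTES` byte positions within a scalar,
/// - `column` indexes the packed 8-bit elements holding that byte for each
///   scalar of the group.
///
/// Scalar `i` thus has its bytes spread over `data[i % HEIGHT][..][i / HEIGHT]`,
/// which lets multiplication, squaring and inversion run byte-wise over packed
/// 8-bit AES tower field elements. For example, `ByteSlicedAES32x128b` stores 32
/// `AESTowerField128b` scalars as a single row of 16 `PackedAESBinaryField32x8b`
/// values.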
macro_rules! define_byte_sliced_3d {
    ($name:ident, $scalar_type:ty, $packed_storage:ty, $scalar_tower_level: ty, $storage_tower_level: ty) => {
        #[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
        #[repr(transparent)]
        pub struct $name {
            pub(super) data: [[$packed_storage; <$scalar_tower_level as TowerLevel>::WIDTH]; <$storage_tower_level as TowerLevel>::WIDTH / <$scalar_tower_level as TowerLevel>::WIDTH],
        }

        impl $name {
            pub const BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

            const SCALAR_BYTES: usize = <$scalar_type>::N_BITS / 8;
            pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
            const HEIGHT: usize = Self::HEIGHT_BYTES / Self::SCALAR_BYTES;
            const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT);

            #[allow(clippy::modulo_one)]
            #[inline(always)]
            pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
                let row = byte_index % Self::HEIGHT_BYTES;
                self.data
                    .get_unchecked(row / Self::SCALAR_BYTES)
                    .get_unchecked(row % Self::SCALAR_BYTES)
                    .get_unchecked(byte_index / Self::HEIGHT_BYTES)
                    .to_underlier()
            }

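            /// Transposes this byte-sliced value into an array of ordinary packed
            /// field elements, one per byte row.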
            #[inline]
            pub fn transpose_to(&self, out: &mut [<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES]) {
                let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
                *underliers = bytemuck::must_cast(self.data);

                UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);
            }

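            /// Builds a byte-sliced value from an array of ordinary packed field
            /// elements; the inverse of [`Self::transpose_to`].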
83
84 #[inline]
86 pub fn transpose_from(
87 underliers: &[<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES],
88 ) -> Self {
89 let mut result = Self {
90 data: bytemuck::must_cast(*underliers),
91 };
92
93 <$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<$storage_tower_level>(bytemuck::must_cast_mut(&mut result.data));
94
95 result
96 }
97 }
98
99 impl Default for $name {
100 fn default() -> Self {
101 Self {
102 data: bytemuck::Zeroable::zeroed(),
103 }
104 }
105 }
106
107 impl Debug for $name {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 let values_str = self
110 .iter()
111 .map(|value| format!("{}", value))
112 .collect::<Vec<_>>()
113 .join(",");
114
115 write!(f, "ByteSlicedAES{}x{}([{}])", Self::WIDTH, <$scalar_type>::N_BITS, values_str)
116 }
117 }
118
        impl PackedField for $name {
            type Scalar = $scalar_type;

            const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

            #[allow(clippy::modulo_one)]
            #[inline(always)]
            unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
                let element_rows = self.data.get_unchecked(i % Self::HEIGHT);
                Self::Scalar::from_bases((0..Self::SCALAR_BYTES).map(|byte_index| {
                    element_rows
                        .get_unchecked(byte_index)
                        .get_unchecked(i / Self::HEIGHT)
                }))
                .expect("byte index is within bounds")
            }

            #[allow(clippy::modulo_one)]
            #[inline(always)]
            unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
                let element_rows = self.data.get_unchecked_mut(i % Self::HEIGHT);
                for byte_index in 0..Self::SCALAR_BYTES {
                    element_rows
                        .get_unchecked_mut(byte_index)
                        .set_unchecked(
                            i / Self::HEIGHT,
                            scalar.get_base_unchecked(byte_index),
                        );
                }
            }

            fn random(mut rng: impl rand::RngCore) -> Self {
                let data = array::from_fn(|_| array::from_fn(|_| <$packed_storage>::random(&mut rng)));
                Self { data }
            }

            #[allow(unreachable_patterns)]
            #[inline]
            fn broadcast(scalar: Self::Scalar) -> Self {
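                // Specialize on the scalar width: a 1-byte scalar broadcasts a
                // single packed value everywhere; a scalar spanning the full column
                // height broadcasts one byte per byte row; the generic case fills
                // byte positions row by row.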
                let data: [[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT] = match Self::SCALAR_BYTES {
                    1 => {
                        let packed_broadcast =
                            <$packed_storage>::broadcast(unsafe { scalar.get_base_unchecked(0) });
                        array::from_fn(|_| array::from_fn(|_| packed_broadcast))
                    }
                    Self::HEIGHT_BYTES => array::from_fn(|_| array::from_fn(|byte_index| {
                        <$packed_storage>::broadcast(unsafe {
                            scalar.get_base_unchecked(byte_index)
                        })
                    })),
                    _ => {
                        let mut data = <[[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT]>::zeroed();
                        for byte_index in 0..Self::SCALAR_BYTES {
                            let broadcast = <$packed_storage>::broadcast(unsafe {
                                scalar.get_base_unchecked(byte_index)
                            });

                            for i in 0..Self::HEIGHT {
                                data[i][byte_index] = broadcast;
                            }
                        }

                        data
                    }
                };

                Self { data }
            }

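            // Generates `from_fn` and `from_scalars`: both fill the ordinary packed
            // layout first and then transpose into byte-sliced form.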
            impl_init_with_transpose!($packed_storage, $scalar_type, $storage_tower_level);

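            // Iteration transposes into the ordinary packed layout once and then
            // yields scalars from that array; extracting scalars directly from the
            // byte-sliced layout would touch `SCALAR_BYTES` packed values each.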
            #[inline]
            fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Clone + Send + '_ {
                type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

                let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
                let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
                UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);

                let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
                data.into_iter()
            }

            #[inline]
            fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Clone + Send {
                type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

                let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
                let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
                UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);

                let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
                data.into_iter()
            }

            #[inline]
            fn square(self) -> Self {
                let mut result = Self::default();

                for i in 0..Self::HEIGHT {
                    square::<$packed_storage, $scalar_tower_level>(
                        &self.data[i],
                        &mut result.data[i],
                    );
                }

                result
            }

            #[inline]
            fn invert_or_zero(self) -> Self {
                let mut result = Self::default();

                for i in 0..Self::HEIGHT {
                    invert_or_zero::<$packed_storage, $scalar_tower_level>(
                        &self.data[i],
                        &mut result.data[i],
                    );
                }

                result
            }

            #[inline(always)]
            fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
                let self_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&self.data);
                let other_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&other.data);

                let (data_1, data_2) = interleave_byte_sliced(self_data, other_data, log_block_len + checked_log_2(Self::SCALAR_BYTES));

                (
                    Self {
                        data: bytemuck::must_cast(data_1),
                    },
                    Self {
                        data: bytemuck::must_cast(data_2),
                    },
                )
            }

            #[inline]
            fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
                let (result1, result2) = unzip_byte_sliced::<$packed_storage, {Self::HEIGHT_BYTES}, {Self::SCALAR_BYTES}>(bytemuck::must_cast_ref(&self.data), bytemuck::must_cast_ref(&other.data), log_block_len + checked_log_2(Self::SCALAR_BYTES));

                (
                    Self {
                        data: bytemuck::must_cast(result1),
                    },
                    Self {
                        data: bytemuck::must_cast(result2),
                    },
                )
            }
        }

        impl Mul for $name {
            type Output = Self;

            #[inline]
            fn mul(self, rhs: Self) -> Self {
                let mut result = Self::default();

                for i in 0..Self::HEIGHT {
                    mul::<$packed_storage, $scalar_tower_level>(
                        &self.data[i],
                        &rhs.data[i],
                        &mut result.data[i],
                    );
                }

                result
            }
        }

        impl Add for $name {
            type Output = Self;

            #[inline]
            fn add(self, rhs: Self) -> Self {
                Self {
                    data: array::from_fn(|byte_number| {
                        array::from_fn(|column|
                            self.data[byte_number][column] + rhs.data[byte_number][column]
                        )
                    }),
                }
            }
        }

        impl AddAssign for $name {
            #[inline]
            fn add_assign(&mut self, rhs: Self) {
                for (data, rhs) in zip(&mut self.data, &rhs.data) {
                    for (data, rhs) in zip(data, rhs) {
                        *data += *rhs
                    }
                }
            }
        }

        byte_sliced_common!($name, $packed_storage, $scalar_type, $storage_tower_level);

        impl<Inner: Transformation<$packed_storage, $packed_storage>> Transformation<$name, $name> for TransformationWrapperNxN<Inner, {<$scalar_tower_level as TowerLevel>::WIDTH}> {
            fn transform(&self, data: &$name) -> $name {
                let mut result = <$name>::default();

                for row in 0..<$name>::SCALAR_BYTES {
                    for col in 0..<$name>::SCALAR_BYTES {
                        let transformation = &self.0[col][row];

                        for i in 0..<$name>::HEIGHT {
                            result.data[i][row] += transformation.transform(&data.data[i][col]);
                        }
                    }
                }

                result
            }
        }

        impl PackedTransformationFactory<$name> for $name {
            type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> = TransformationWrapperNxN<<$packed_storage as PackedTransformationFactory<$packed_storage>>::PackedTransformation::<[AESTowerField8b; 8]>, {<$scalar_tower_level as TowerLevel>::WIDTH}>;

            fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
                transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
            ) -> Self::PackedTransformation<Data> {
                let transformations_8b = array::from_fn(|row| {
                    array::from_fn(|col| {
                        let row = row * 8;
                        let linear_transformation_8b = array::from_fn::<_, 8, _>(|row_8b| unsafe {
                            <<$name as PackedField>::Scalar as ExtensionField<AESTowerField8b>>::get_base_unchecked(&transformation.bases()[row + row_8b], col)
                        });

                        <$packed_storage as PackedTransformationFactory<$packed_storage>>::make_packed_transformation(FieldLinearTransformation::new(linear_transformation_8b))
                    })
                });

                TransformationWrapperNxN(transformations_8b)
            }
        }
    };
}

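/// Generates `PackedField::from_fn` and `PackedField::from_scalars` for a
/// byte-sliced type: both fill an array in the ordinary packed layout and then
/// transpose it into byte-sliced form in one pass.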
macro_rules! impl_init_with_transpose {
    ($packed_storage:ty, $scalar_type:ty, $storage_tower_level:ty) => {
        #[inline]
        fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
            type TransposePacked =
                <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

            let mut data: [TransposePacked; Self::HEIGHT_BYTES] = array::from_fn(|i| {
                PackedField::from_fn(|j| f((i << TransposePacked::LOG_WIDTH) + j))
            });
            let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
            UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
                underliers,
            );

            Self {
                data: bytemuck::must_cast(data),
            }
        }

        #[inline]
        fn from_scalars(scalars: impl IntoIterator<Item = Self::Scalar>) -> Self {
            type TransposePacked =
                <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

            let mut data = [TransposePacked::default(); Self::HEIGHT_BYTES];
            let mut iter = scalars.into_iter();
            for v in data.iter_mut() {
                *v = TransposePacked::from_scalars(&mut iter);
            }

            let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
            UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
                underliers,
            );

            Self {
                data: bytemuck::must_cast(data),
            }
        }
    };
}

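/// Implements the trait boilerplate shared by all byte-sliced types: scalar
/// operands via `broadcast`, subtraction as addition (binary fields have
/// characteristic 2), `Sum`/`Product`, `WithUnderlier` casts to the matching
/// `ByteSlicedUnderlier`, and the corresponding `PackScalar` impl.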
macro_rules! byte_sliced_common {
    ($name:ident, $packed_storage:ty, $scalar_type:ty, $storage_level:ty) => {
        impl Add<$scalar_type> for $name {
            type Output = Self;

            #[inline]
            fn add(self, rhs: $scalar_type) -> $name {
                self + Self::broadcast(rhs)
            }
        }

        impl AddAssign<$scalar_type> for $name {
            #[inline]
            fn add_assign(&mut self, rhs: $scalar_type) {
                *self += Self::broadcast(rhs)
            }
        }

        impl Sub<$scalar_type> for $name {
            type Output = Self;

            #[inline]
            fn sub(self, rhs: $scalar_type) -> $name {
                self.add(rhs)
            }
        }

        impl SubAssign<$scalar_type> for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: $scalar_type) {
                self.add_assign(rhs)
            }
        }

        impl Mul<$scalar_type> for $name {
            type Output = Self;

            #[inline]
            fn mul(self, rhs: $scalar_type) -> $name {
                self * Self::broadcast(rhs)
            }
        }

        impl MulAssign<$scalar_type> for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: $scalar_type) {
                *self *= Self::broadcast(rhs);
            }
        }

        impl Sub for $name {
            type Output = Self;

            #[inline]
            fn sub(self, rhs: Self) -> Self {
                self.add(rhs)
            }
        }

        impl SubAssign for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: Self) {
                self.add_assign(rhs);
            }
        }

        impl MulAssign for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: Self) {
                *self = *self * rhs;
            }
        }

        impl Product for $name {
            fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
                let mut result = Self::one();

                let mut is_first_item = true;
                for item in iter {
                    if is_first_item {
                        result = item;
                    } else {
                        result *= item;
                    }

                    is_first_item = false;
                }

                result
            }
        }

        impl Sum for $name {
            fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
                let mut result = Self::zero();

                for item in iter {
                    result += item;
                }

                result
            }
        }

        unsafe impl WithUnderlier for $name {
            type Underlier = ByteSlicedUnderlier<
                <$packed_storage as WithUnderlier>::Underlier,
                { <$storage_level as TowerLevel>::WIDTH },
            >;

            #[inline(always)]
            fn to_underlier(self) -> Self::Underlier {
                bytemuck::must_cast(self)
            }

            #[inline(always)]
            fn to_underlier_ref(&self) -> &Self::Underlier {
                bytemuck::must_cast_ref(self)
            }

            #[inline(always)]
            fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
                bytemuck::must_cast_mut(self)
            }

            #[inline(always)]
            fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
                bytemuck::must_cast_slice(val)
            }

            #[inline(always)]
            fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
                bytemuck::must_cast_slice_mut(val)
            }

            #[inline(always)]
            fn from_underlier(val: Self::Underlier) -> Self {
                bytemuck::must_cast(val)
            }

            #[inline(always)]
            fn from_underlier_ref(val: &Self::Underlier) -> &Self {
                bytemuck::must_cast_ref(val)
            }

            #[inline(always)]
            fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
                bytemuck::must_cast_mut(val)
            }

            #[inline(always)]
            fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
                bytemuck::must_cast_slice(val)
            }

            #[inline(always)]
            fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
                bytemuck::must_cast_slice_mut(val)
            }
        }

        impl PackScalar<$scalar_type>
            for ByteSlicedUnderlier<
                <$packed_storage as WithUnderlier>::Underlier,
                { <$storage_level as TowerLevel>::WIDTH },
            >
        {
            type Packed = $name;
        }
    };
}

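/// Defines a byte-sliced packed field over `BinaryField1b`.
///
/// A 1-bit scalar is narrower than a byte, so the data is a flat array of
/// packed 1-bit rows; byte-level operations reinterpret the rows as packed
/// 8-bit AES tower field elements where needed.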
macro_rules! define_byte_sliced_3d_1b {
    ($name:ident, $packed_storage:ty, $storage_tower_level: ty) => {
        #[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
        #[repr(transparent)]
        pub struct $name {
            pub(super) data: [$packed_storage; <$storage_tower_level as TowerLevel>::WIDTH],
        }

        impl $name {
            pub const BYTES: usize =
                <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH / 8;

            pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
            const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT_BYTES);

            #[allow(clippy::modulo_one)]
            #[inline(always)]
            pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
                type Packed8b =
                    PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

                Packed8b::cast_ext_ref(self.data.get_unchecked(byte_index % Self::HEIGHT_BYTES))
                    .get_unchecked(byte_index / Self::HEIGHT_BYTES)
                    .to_underlier()
            }

            #[inline]
            pub fn transpose_to(
                &self,
                out: &mut [PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
                    Self::HEIGHT_BYTES],
            ) {
                let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
                *underliers = WithUnderlier::to_underliers_arr(self.data);

                UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
                    underliers,
                );
            }

            #[inline]
            pub fn transpose_from(
                underliers: &[PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
                    Self::HEIGHT_BYTES],
            ) -> Self {
                let mut underliers = WithUnderlier::to_underliers_arr(*underliers);

                <$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<
                    $storage_tower_level,
                >(&mut underliers);

                Self {
                    data: WithUnderlier::from_underliers_arr(underliers),
                }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self {
                    data: bytemuck::Zeroable::zeroed(),
                }
            }
        }

        impl Debug for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let values_str = self
                    .iter()
                    .map(|value| format!("{}", value))
                    .collect::<Vec<_>>()
                    .join(",");

                write!(f, "ByteSlicedAES{}x1b([{}])", Self::WIDTH, values_str)
            }
        }

        impl PackedField for $name {
            type Scalar = BinaryField1b;

            const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

            #[allow(clippy::modulo_one)]
            #[inline(always)]
            unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
                self.data
                    .get_unchecked((i / 8) % Self::HEIGHT_BYTES)
                    .get_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8)
            }

            #[allow(clippy::modulo_one)]
            #[inline(always)]
            unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
                self.data
                    .get_unchecked_mut((i / 8) % Self::HEIGHT_BYTES)
                    .set_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8, scalar);
            }

            fn random(mut rng: impl rand::RngCore) -> Self {
                let data = array::from_fn(|_| <$packed_storage>::random(&mut rng));
                Self { data }
            }

            impl_init_with_transpose!($packed_storage, BinaryField1b, $storage_tower_level);

            #[allow(unreachable_patterns)]
            #[inline]
            fn broadcast(scalar: Self::Scalar) -> Self {
                let underlier = <$packed_storage as WithUnderlier>::Underlier::fill_with_bit(
                    scalar.to_underlier().into(),
                );
                Self {
                    data: array::from_fn(|_| WithUnderlier::from_underlier(underlier)),
                }
            }

            #[inline]
            fn square(self) -> Self {
                let data = array::from_fn(|i| self.data[i].clone().square());
                Self { data }
            }

            #[inline]
            fn invert_or_zero(self) -> Self {
                let data = array::from_fn(|i| self.data[i].clone().invert_or_zero());
                Self { data }
            }

            #[inline(always)]
            fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
                type Packed8b =
                    PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

                if log_block_len < 3 {
                    let mut result1 = Self::default();
                    let mut result2 = Self::default();

                    for i in 0..Self::HEIGHT_BYTES {
                        (result1.data[i], result2.data[i]) =
                            self.data[i].interleave(other.data[i], log_block_len);
                    }

                    (result1, result2)
                } else {
                    let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
                        Packed8b::cast_ext_arr_ref(&self.data);
                    let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
                        Packed8b::cast_ext_arr_ref(&other.data);

                    let (result1, result2) =
                        interleave_byte_sliced(self_data, other_data, log_block_len - 3);

                    (
                        Self {
                            data: Packed8b::cast_base_arr(result1),
                        },
                        Self {
                            data: Packed8b::cast_base_arr(result2),
                        },
                    )
                }
            }

            #[inline]
            fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
                if log_block_len < 3 {
                    let mut result1 = Self::default();
                    let mut result2 = Self::default();

                    for i in 0..Self::HEIGHT_BYTES {
                        (result1.data[i], result2.data[i]) =
                            self.data[i].unzip(other.data[i], log_block_len);
                    }

                    (result1, result2)
                } else {
                    type Packed8b =
                        PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

                    let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
                        Packed8b::cast_ext_arr_ref(&self.data);
                    let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
                        Packed8b::cast_ext_arr_ref(&other.data);

                    let (result1, result2) = unzip_byte_sliced::<Packed8b, { Self::HEIGHT_BYTES }, 1>(
                        self_data,
                        other_data,
                        log_block_len - 3,
                    );

                    (
                        Self {
                            data: Packed8b::cast_base_arr(result1),
                        },
                        Self {
                            data: Packed8b::cast_base_arr(result2),
                        },
                    )
                }
            }
        }

        impl Mul for $name {
            type Output = Self;

            #[inline]
            fn mul(self, rhs: Self) -> Self {
                Self {
                    data: array::from_fn(|i| self.data[i].clone() * rhs.data[i].clone()),
                }
            }
        }

        impl Add for $name {
            type Output = Self;

            #[inline]
            fn add(self, rhs: Self) -> Self {
                Self {
                    data: array::from_fn(|byte_number| {
                        self.data[byte_number] + rhs.data[byte_number]
                    }),
                }
            }
        }

        impl AddAssign for $name {
            #[inline]
            fn add_assign(&mut self, rhs: Self) {
                for (data, rhs) in zip(&mut self.data, &rhs.data) {
                    *data += *rhs
                }
            }
        }

        byte_sliced_common!($name, $packed_storage, BinaryField1b, $storage_tower_level);

        impl PackedTransformationFactory<$name> for $name {
            type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> =
                IDTransformation;

            fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
                _transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
            ) -> Self::PackedTransformation<Data> {
                IDTransformation
            }
        }
    };
}

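/// Interleaves rows of `lhs` and `rhs` in blocks of `2^LOG_BLOCK_LEN` rows;
/// only valid while the block is smaller than the row count `N`.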
#[inline(always)]
fn interleave_internal_block<P: PackedField, const N: usize, const LOG_BLOCK_LEN: usize>(
    lhs: &[P; N],
    rhs: &[P; N],
) -> ([P; N], [P; N]) {
    debug_assert!(LOG_BLOCK_LEN < checked_log_2(N));

    let result_1 = array::from_fn(|i| {
        let block_index = i >> (LOG_BLOCK_LEN + 1);
        let block_start = block_index << (LOG_BLOCK_LEN + 1);
        let block_offset = i - block_start;

        if block_offset < (1 << LOG_BLOCK_LEN) {
            lhs[i]
        } else {
            rhs[i - (1 << LOG_BLOCK_LEN)]
        }
    });
    let result_2 = array::from_fn(|i| {
        let block_index = i >> (LOG_BLOCK_LEN + 1);
        let block_start = block_index << (LOG_BLOCK_LEN + 1);
        let block_offset = i - block_start;

        if block_offset < (1 << LOG_BLOCK_LEN) {
            lhs[i + (1 << LOG_BLOCK_LEN)]
        } else {
            rhs[i]
        }
    });

    (result_1, result_2)
}

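/// Interleaves two byte-sliced values at row granularity, dispatching to a
/// const-generic specialization for small blocks and to element-wise
/// `PackedField::interleave` once a block spans entire rows.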
#[inline(always)]
fn interleave_byte_sliced<P: PackedField, const N: usize>(
    lhs: &[P; N],
    rhs: &[P; N],
    log_block_len: usize,
) -> ([P; N], [P; N]) {
    debug_assert!(checked_log_2(N) <= 4);

    match log_block_len {
        x if x >= checked_log_2(N) => interleave_big_block::<P, N>(lhs, rhs, log_block_len),
        0 => interleave_internal_block::<P, N, 0>(lhs, rhs),
        1 => interleave_internal_block::<P, N, 1>(lhs, rhs),
        2 => interleave_internal_block::<P, N, 2>(lhs, rhs),
        3 => interleave_internal_block::<P, N, 3>(lhs, rhs),
        _ => unreachable!(),
    }
}

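/// Splits the row blocks of `lhs` followed by `rhs` into even-indexed blocks
/// (first result) and odd-indexed blocks (second result), where a block is
/// `2^log_block_len` rows; blocks spanning whole rows delegate to
/// `PackedField::unzip`.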
#[inline(always)]
fn unzip_byte_sliced<P: PackedField, const N: usize, const SCALAR_BYTES: usize>(
    lhs: &[P; N],
    rhs: &[P; N],
    log_block_len: usize,
) -> ([P; N], [P; N]) {
    let mut result1: [P; N] = bytemuck::Zeroable::zeroed();
    let mut result2: [P; N] = bytemuck::Zeroable::zeroed();

    let log_height = checked_log_2(N);
    if log_block_len < log_height {
        let block_size = 1 << log_block_len;
        let half = N / 2;
        for block_offset in (0..half).step_by(block_size) {
            let target_offset = block_offset * 2;

            result1[block_offset..block_offset + block_size]
                .copy_from_slice(&lhs[target_offset..target_offset + block_size]);
            result1[half + block_offset..half + block_offset + block_size]
                .copy_from_slice(&rhs[target_offset..target_offset + block_size]);

            result2[block_offset..block_offset + block_size]
                .copy_from_slice(&lhs[target_offset + block_size..target_offset + 2 * block_size]);
            result2[half + block_offset..half + block_offset + block_size]
                .copy_from_slice(&rhs[target_offset + block_size..target_offset + 2 * block_size]);
        }
    } else {
        for i in 0..N {
            (result1[i], result2[i]) = lhs[i].unzip(rhs[i], log_block_len - log_height);
        }
    }

    (result1, result2)
}

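/// Interleaves when the block length is at least the row count `N`: each row
/// pair is interleaved element-wise within the packed values.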
#[inline(always)]
fn interleave_big_block<P: PackedField, const N: usize>(
    lhs: &[P; N],
    rhs: &[P; N],
    log_block_len: usize,
) -> ([P; N], [P; N]) {
    let mut result_1 = <[P; N]>::zeroed();
    let mut result_2 = <[P; N]>::zeroed();

    for i in 0..N {
        (result_1[i], result_2[i]) = lhs[i].interleave(rhs[i], log_block_len - checked_log_2(N));
    }

    (result_1, result_2)
}

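// 128-bit underlier: byte-sliced types built on PackedAESBinaryField16x8b.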
define_byte_sliced_3d!(
    ByteSlicedAES16x128b,
    AESTowerField128b,
    PackedAESBinaryField16x8b,
    TowerLevel16,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES16x64b,
    AESTowerField64b,
    PackedAESBinaryField16x8b,
    TowerLevel8,
    TowerLevel8
);
define_byte_sliced_3d!(
    ByteSlicedAES2x16x64b,
    AESTowerField64b,
    PackedAESBinaryField16x8b,
    TowerLevel8,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES16x32b,
    AESTowerField32b,
    PackedAESBinaryField16x8b,
    TowerLevel4,
    TowerLevel4
);
define_byte_sliced_3d!(
    ByteSlicedAES4x16x32b,
    AESTowerField32b,
    PackedAESBinaryField16x8b,
    TowerLevel4,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES16x16b,
    AESTowerField16b,
    PackedAESBinaryField16x8b,
    TowerLevel2,
    TowerLevel2
);
define_byte_sliced_3d!(
    ByteSlicedAES8x16x16b,
    AESTowerField16b,
    PackedAESBinaryField16x8b,
    TowerLevel2,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES16x8b,
    AESTowerField8b,
    PackedAESBinaryField16x8b,
    TowerLevel1,
    TowerLevel1
);
define_byte_sliced_3d!(
    ByteSlicedAES16x16x8b,
    AESTowerField8b,
    PackedAESBinaryField16x8b,
    TowerLevel1,
    TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x128x1b, PackedBinaryField128x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x128x1b, PackedBinaryField128x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x128x1b, PackedBinaryField128x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x128x1b, PackedBinaryField128x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x128x1b, PackedBinaryField128x1b, TowerLevel1);

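// 256-bit underlier: byte-sliced types built on PackedAESBinaryField32x8b.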
define_byte_sliced_3d!(
    ByteSlicedAES32x128b,
    AESTowerField128b,
    PackedAESBinaryField32x8b,
    TowerLevel16,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES32x64b,
    AESTowerField64b,
    PackedAESBinaryField32x8b,
    TowerLevel8,
    TowerLevel8
);
define_byte_sliced_3d!(
    ByteSlicedAES2x32x64b,
    AESTowerField64b,
    PackedAESBinaryField32x8b,
    TowerLevel8,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES32x32b,
    AESTowerField32b,
    PackedAESBinaryField32x8b,
    TowerLevel4,
    TowerLevel4
);
define_byte_sliced_3d!(
    ByteSlicedAES4x32x32b,
    AESTowerField32b,
    PackedAESBinaryField32x8b,
    TowerLevel4,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES32x16b,
    AESTowerField16b,
    PackedAESBinaryField32x8b,
    TowerLevel2,
    TowerLevel2
);
define_byte_sliced_3d!(
    ByteSlicedAES8x32x16b,
    AESTowerField16b,
    PackedAESBinaryField32x8b,
    TowerLevel2,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES32x8b,
    AESTowerField8b,
    PackedAESBinaryField32x8b,
    TowerLevel1,
    TowerLevel1
);
define_byte_sliced_3d!(
    ByteSlicedAES16x32x8b,
    AESTowerField8b,
    PackedAESBinaryField32x8b,
    TowerLevel1,
    TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x256x1b, PackedBinaryField256x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x256x1b, PackedBinaryField256x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x256x1b, PackedBinaryField256x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x256x1b, PackedBinaryField256x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x256x1b, PackedBinaryField256x1b, TowerLevel1);

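// 512-bit underlier: byte-sliced types built on PackedAESBinaryField64x8b.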
define_byte_sliced_3d!(
    ByteSlicedAES64x128b,
    AESTowerField128b,
    PackedAESBinaryField64x8b,
    TowerLevel16,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES64x64b,
    AESTowerField64b,
    PackedAESBinaryField64x8b,
    TowerLevel8,
    TowerLevel8
);
define_byte_sliced_3d!(
    ByteSlicedAES2x64x64b,
    AESTowerField64b,
    PackedAESBinaryField64x8b,
    TowerLevel8,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES64x32b,
    AESTowerField32b,
    PackedAESBinaryField64x8b,
    TowerLevel4,
    TowerLevel4
);
define_byte_sliced_3d!(
    ByteSlicedAES4x64x32b,
    AESTowerField32b,
    PackedAESBinaryField64x8b,
    TowerLevel4,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES64x16b,
    AESTowerField16b,
    PackedAESBinaryField64x8b,
    TowerLevel2,
    TowerLevel2
);
define_byte_sliced_3d!(
    ByteSlicedAES8x64x16b,
    AESTowerField16b,
    PackedAESBinaryField64x8b,
    TowerLevel2,
    TowerLevel16
);
define_byte_sliced_3d!(
    ByteSlicedAES64x8b,
    AESTowerField8b,
    PackedAESBinaryField64x8b,
    TowerLevel1,
    TowerLevel1
);
define_byte_sliced_3d!(
    ByteSlicedAES16x64x8b,
    AESTowerField8b,
    PackedAESBinaryField64x8b,
    TowerLevel1,
    TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x512x1b, PackedBinaryField512x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x512x1b, PackedBinaryField512x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x512x1b, PackedBinaryField512x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x512x1b, PackedBinaryField512x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x512x1b, PackedBinaryField512x1b, TowerLevel1);