binius_field/arch/portable/byte_sliced/packed_byte_sliced.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	array,
	fmt::Debug,
	iter::{zip, Product, Sum},
	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{Pod, Zeroable};

use super::{invert::invert_or_zero, multiply::mul, square::square};
use crate::{
	arch::byte_sliced::underlier::ByteSlicedUnderlier,
	as_packed_field::{PackScalar, PackedType},
	binary_field::BinaryField,
	linear_transformation::{
		FieldLinearTransformation, IDTransformation, PackedTransformationFactory, Transformation,
	},
	packed_aes_field::PackedAESBinaryField32x8b,
	tower_levels::{TowerLevel, TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8},
	underlier::{UnderlierWithBitOps, WithUnderlier},
	AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
	BinaryField1b, ExtensionField, PackedAESBinaryField16x8b, PackedAESBinaryField64x8b,
	PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedExtension,
	PackedField,
};
/// Packed transformation for byte-sliced fields with a scalar bigger than 8b.
///
/// `N` is the number of bytes in the scalar; the wrapper stores an `N`x`N` matrix of
/// 8b-to-8b transformations.
pub struct TransformationWrapperNxN<Inner, const N: usize>([[Inner; N]; N]);

/// Byte-sliced packed field with a fixed size (`HEIGHT_BYTES` rows of `$packed_storage`).
/// For example, for a 32-bit scalar and a height of 16 bytes the data layout is the following:
/// [ element_0[0], element_4[0], ... ]
/// [ element_0[1], element_4[1], ... ]
/// [ element_0[2], element_4[2], ... ]
/// [ element_0[3], element_4[3], ... ]
/// [ element_1[0], element_5[0], ... ]
/// [ element_1[1], element_5[1], ... ]
///  ...
/// (The layout sketch test after this macro definition illustrates this mapping.)
macro_rules! define_byte_sliced_3d {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $scalar_tower_level: ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			pub(super) data: [[$packed_storage; <$scalar_tower_level as TowerLevel>::WIDTH]; <$storage_tower_level as TowerLevel>::WIDTH / <$scalar_tower_level as TowerLevel>::WIDTH],
		}

		impl $name {
			pub const BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			const SCALAR_BYTES: usize = <$scalar_type>::N_BITS / 8;
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const HEIGHT: usize = Self::HEIGHT_BYTES / Self::SCALAR_BYTES;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT);

			/// Get the byte at the given index.
			///
			/// # Safety
			/// The caller must ensure that `byte_index` is less than `BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				let row = byte_index % Self::HEIGHT_BYTES;
				self.data
					.get_unchecked(row / Self::SCALAR_BYTES)
					.get_unchecked(row % Self::SCALAR_BYTES)
					.get_unchecked(byte_index / Self::HEIGHT_BYTES)
					.to_underlier()
			}

			/// Convert the byte-sliced field to an array of "ordinary" packed fields preserving the order of scalars.
			#[inline]
			pub fn transpose_to(&self, out: &mut [<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES]) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = bytemuck::must_cast(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);
			}

			/// Convert an array of "ordinary" packed fields to a byte-sliced field preserving the order of scalars.
			#[inline]
			pub fn transpose_from(
				underliers: &[<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES],
			) -> Self {
				let mut result = Self {
					data: bytemuck::must_cast(*underliers),
				};

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<$storage_tower_level>(bytemuck::must_cast_mut(&mut result.data));

				result
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x{}([{}])", Self::WIDTH, <$scalar_type>::N_BITS, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = $scalar_type;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				let element_rows = self.data.get_unchecked(i % Self::HEIGHT);
				Self::Scalar::from_bases((0..Self::SCALAR_BYTES).map(|byte_index| {
					element_rows
						.get_unchecked(byte_index)
						.get_unchecked(i / Self::HEIGHT)
				}))
				.expect("byte index is within bounds")
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				let element_rows = self.data.get_unchecked_mut(i % Self::HEIGHT);
				for byte_index in 0..Self::SCALAR_BYTES {
					element_rows
						.get_unchecked_mut(byte_index)
						.set_unchecked(
							i / Self::HEIGHT,
							scalar.get_base_unchecked(byte_index),
						);
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| array::from_fn(|_| <$packed_storage>::random(&mut rng)));
				Self { data }
			}
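
			// `broadcast` specializes on the scalar width: a 1-byte scalar is broadcast
			// into every row; a scalar as wide as the whole height broadcasts one byte
			// per row; otherwise each scalar byte is broadcast into its row slot within
			// every row group. Matching on constants is why `unreachable_patterns` is
			// allowed.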
			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let data: [[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT] = match Self::SCALAR_BYTES {
					1 => {
						let packed_broadcast =
							<$packed_storage>::broadcast(unsafe { scalar.get_base_unchecked(0) });
						array::from_fn(|_| array::from_fn(|_| packed_broadcast))
					}
					Self::HEIGHT_BYTES => array::from_fn(|_| array::from_fn(|byte_index| {
						<$packed_storage>::broadcast(unsafe {
							scalar.get_base_unchecked(byte_index)
						})
					})),
					_ => {
						let mut data = <[[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT]>::zeroed();
						for byte_index in 0..Self::SCALAR_BYTES {
							let broadcast = <$packed_storage>::broadcast(unsafe {
								scalar.get_base_unchecked(byte_index)
							});

							for i in 0..Self::HEIGHT {
								data[i][byte_index] = broadcast;
							}
						}

						data
					}
				};

				Self { data }
			}

			impl_init_with_transpose!($packed_storage, $scalar_type, $storage_tower_level);

			// Transposing the values first is faster. Since the scalar is at least 8b,
			// the transposed array can be cast directly to an array of scalars.
			#[inline]
			fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Clone + Send + '_ {
				type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

				let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);

				let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
				data.into_iter()
			}

			#[inline]
			fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Clone + Send {
				type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

				let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);

				let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
				data.into_iter()
			}

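			// `square`, `invert_or_zero`, and `mul` below delegate to the row-wise
			// byte-sliced tower arithmetic in the sibling `square`, `invert`, and
			// `multiply` modules, one `SCALAR_BYTES`-row group at a time.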
			#[inline]
			fn square(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					square::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					invert_or_zero::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				let self_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&self.data);
				let other_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&other.data);

				// This implementation is faster than using a loop with `copy_from_slice` for the first 4 cases
				let (data_1, data_2) = interleave_byte_sliced(
					self_data,
					other_data,
					log_block_len + checked_log_2(Self::SCALAR_BYTES),
				);
				(
					Self {
						data: bytemuck::must_cast(data_1),
					},
					Self {
						data: bytemuck::must_cast(data_2),
					},
				)
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				let (result1, result2) = unzip_byte_sliced::<$packed_storage, {Self::HEIGHT_BYTES}, {Self::SCALAR_BYTES}>(
					bytemuck::must_cast_ref(&self.data),
					bytemuck::must_cast_ref(&other.data),
					log_block_len + checked_log_2(Self::SCALAR_BYTES),
				);

				(
					Self {
						data: bytemuck::must_cast(result1),
					},
					Self {
						data: bytemuck::must_cast(result2),
					},
				)
			}
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					mul::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&rhs.data[i],
						&mut result.data[i],
					);
				}

				result
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						array::from_fn(|column|
							self.data[byte_number][column] + rhs.data[byte_number][column]
						)
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					for (data, rhs) in zip(data, rhs) {
						*data += *rhs
					}
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, $scalar_type, $storage_tower_level);

		impl<Inner: Transformation<$packed_storage, $packed_storage>> Transformation<$name, $name> for TransformationWrapperNxN<Inner, {<$scalar_tower_level as TowerLevel>::WIDTH}> {
			fn transform(&self, data: &$name) -> $name {
				let mut result = <$name>::default();

				for row in 0..<$name>::SCALAR_BYTES {
					for col in 0..<$name>::SCALAR_BYTES {
						let transformation = &self.0[col][row];

						for i in 0..<$name>::HEIGHT {
							result.data[i][row] += transformation.transform(&data.data[i][col]);
						}
					}
				}

				result
			}
		}

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> = TransformationWrapperNxN<<$packed_storage as PackedTransformationFactory<$packed_storage>>::PackedTransformation::<[AESTowerField8b; 8]>, {<$scalar_tower_level as TowerLevel>::WIDTH}>;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				let transformations_8b = array::from_fn(|row| {
					array::from_fn(|col| {
						let row = row * 8;
						let linear_transformation_8b = array::from_fn::<_, 8, _>(|row_8b| unsafe {
							<<$name as PackedField>::Scalar as ExtensionField<AESTowerField8b>>::get_base_unchecked(&transformation.bases()[row + row_8b], col)
						});

						<$packed_storage as PackedTransformationFactory<$packed_storage>>::make_packed_transformation(FieldLinearTransformation::new(linear_transformation_8b))
					})
				});

				TransformationWrapperNxN(transformations_8b)
			}
		}
	};
}
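
// A minimal sketch test (not part of the original file) illustrating the layout
// documented above, using `ByteSlicedAES4x16x32b` (SCALAR_BYTES = 4, HEIGHT = 4,
// HEIGHT_BYTES = 16): byte `b` of scalar `i` is expected at byte index
// `(i % HEIGHT) * SCALAR_BYTES + b + (i / HEIGHT) * HEIGHT_BYTES`. It assumes only
// APIs already used in this file, plus `PackedAESBinaryField4x32b` (the ordinary
// packed field over the same underlier) and a seeded RNG from `rand`.
#[cfg(test)]
mod byte_sliced_layout_sketch {
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::PackedAESBinaryField4x32b;

	#[test]
	fn bytes_follow_documented_layout() {
		let value = ByteSlicedAES4x16x32b::random(StdRng::seed_from_u64(0));

		for i in 0..ByteSlicedAES4x16x32b::WIDTH {
			let scalar = value.get(i);
			for b in 0..ByteSlicedAES4x16x32b::SCALAR_BYTES {
				let byte_index = (i % ByteSlicedAES4x16x32b::HEIGHT)
					* ByteSlicedAES4x16x32b::SCALAR_BYTES
					+ b + (i / ByteSlicedAES4x16x32b::HEIGHT)
					* ByteSlicedAES4x16x32b::HEIGHT_BYTES;
				// The byte of the scalar, extracted through the 8b extension field view.
				let expected = unsafe {
					<AESTowerField32b as ExtensionField<AESTowerField8b>>::get_base_unchecked(
						&scalar, b,
					)
				};
				assert_eq!(
					unsafe { value.get_byte_unchecked(byte_index) },
					expected.to_underlier()
				);
			}
		}
	}

	// The transpose pair is expected to be a pair of inverses.
	#[test]
	fn transpose_round_trip() {
		let value = ByteSlicedAES4x16x32b::random(StdRng::seed_from_u64(1));
		let mut ordinary =
			[PackedAESBinaryField4x32b::default(); ByteSlicedAES4x16x32b::HEIGHT_BYTES];
		value.transpose_to(&mut ordinary);
		assert_eq!(ByteSlicedAES4x16x32b::transpose_from(&ordinary), value);
	}
}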

/// This macro implements the `from_fn` and `from_scalars` methods for byte-sliced packed fields
/// using transpose operations. This is faster for both 1b and non-1b scalars.
macro_rules! impl_init_with_transpose {
	($packed_storage:ty, $scalar_type:ty, $storage_tower_level:ty) => {
		#[inline]
		fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
			type TransposePacked =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

			let mut data: [TransposePacked; Self::HEIGHT_BYTES] = array::from_fn(|i| {
				PackedField::from_fn(|j| f((i << TransposePacked::LOG_WIDTH) + j))
			});
			let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
				underliers,
			);

			Self {
				data: bytemuck::must_cast(data),
			}
		}

		#[inline]
		fn from_scalars(scalars: impl IntoIterator<Item = Self::Scalar>) -> Self {
			type TransposePacked =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

			let mut data = [TransposePacked::default(); Self::HEIGHT_BYTES];
			let mut iter = scalars.into_iter();
			for v in data.iter_mut() {
				*v = TransposePacked::from_scalars(&mut iter);
			}

			let underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
				underliers,
			);

			Self {
				data: bytemuck::must_cast(data),
			}
		}
	};
}
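
// A short sketch test (not in the original file) for the transpose-based
// constructors above: `from_scalars` applied to `iter()` should reproduce the
// value, and `from_fn` should agree with `get`. Uses `ByteSlicedAES16x16b`,
// defined at the bottom of this file, and a seeded RNG from `rand`.
#[cfg(test)]
mod init_with_transpose_sketch {
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	#[test]
	fn constructors_round_trip() {
		let value = ByteSlicedAES16x16b::random(StdRng::seed_from_u64(2));
		assert_eq!(ByteSlicedAES16x16b::from_scalars(value.iter()), value);
		assert_eq!(ByteSlicedAES16x16b::from_fn(|i| value.get(i)), value);
	}
}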

macro_rules! byte_sliced_common {
	($name:ident, $packed_storage:ty, $scalar_type:ty, $storage_level:ty) => {
		impl Add<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: $scalar_type) -> $name {
				self + Self::broadcast(rhs)
			}
		}

		impl AddAssign<$scalar_type> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: $scalar_type) {
				*self += Self::broadcast(rhs)
			}
		}

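		// Subtraction in characteristic-2 fields coincides with addition, so the
		// `Sub`/`SubAssign` impls below simply delegate to `add`/`add_assign`.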
		impl Sub<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: $scalar_type) -> $name {
				self.add(rhs)
			}
		}

		impl SubAssign<$scalar_type> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: $scalar_type) {
				self.add_assign(rhs)
			}
		}

		impl Mul<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: $scalar_type) -> $name {
				self * Self::broadcast(rhs)
			}
		}

		impl MulAssign<$scalar_type> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: $scalar_type) {
				*self *= Self::broadcast(rhs);
			}
		}

		impl Sub for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: Self) -> Self {
				self.add(rhs)
			}
		}

		impl SubAssign for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: Self) {
				self.add_assign(rhs);
			}
		}

		impl MulAssign for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: Self) {
				*self = *self * rhs;
			}
		}

		impl Product for $name {
			fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::one();

				let mut is_first_item = true;
				for item in iter {
					if is_first_item {
						result = item;
					} else {
						result *= item;
					}

					is_first_item = false;
				}

				result
			}
		}

		impl Sum for $name {
			fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::zero();

				for item in iter {
					result += item;
				}

				result
			}
		}

		unsafe impl WithUnderlier for $name {
			type Underlier = ByteSlicedUnderlier<
				<$packed_storage as WithUnderlier>::Underlier,
				{ <$storage_level as TowerLevel>::WIDTH },
			>;

			#[inline(always)]
			fn to_underlier(self) -> Self::Underlier {
				bytemuck::must_cast(self)
			}

			#[inline(always)]
			fn to_underlier_ref(&self) -> &Self::Underlier {
				bytemuck::must_cast_ref(self)
			}

			#[inline(always)]
			fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
				bytemuck::must_cast_mut(self)
			}

			#[inline(always)]
			fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
				bytemuck::must_cast_slice(val)
			}

			#[inline(always)]
			fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
				bytemuck::must_cast_slice_mut(val)
			}

			#[inline(always)]
			fn from_underlier(val: Self::Underlier) -> Self {
				bytemuck::must_cast(val)
			}

			#[inline(always)]
			fn from_underlier_ref(val: &Self::Underlier) -> &Self {
				bytemuck::must_cast_ref(val)
			}

			#[inline(always)]
			fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
				bytemuck::must_cast_mut(val)
			}

			#[inline(always)]
			fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
				bytemuck::must_cast_slice(val)
			}

			#[inline(always)]
			fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
				bytemuck::must_cast_slice_mut(val)
			}
		}

		impl PackScalar<$scalar_type>
			for ByteSlicedUnderlier<
				<$packed_storage as WithUnderlier>::Underlier,
				{ <$storage_level as TowerLevel>::WIDTH },
			>
		{
			type Packed = $name;
		}
	};
}

/// Special case: byte-sliced packed 1b fields. The order of bytes in the layout matches that of
/// a byte-sliced AES field; each byte contains 1b scalar elements in the natural order.
macro_rules! define_byte_sliced_3d_1b {
	($name:ident, $packed_storage:ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			pub(super) data: [$packed_storage; <$storage_tower_level as TowerLevel>::WIDTH],
		}

		impl $name {
			pub const BYTES: usize =
				<$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT_BYTES);

			/// Get the byte at the given index.
			///
			/// # Safety
			/// The caller must ensure that `byte_index` is less than `BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				Packed8b::cast_ext_ref(self.data.get_unchecked(byte_index % Self::HEIGHT_BYTES))
					.get_unchecked(byte_index / Self::HEIGHT_BYTES)
					.to_underlier()
			}

			/// Convert the byte-sliced field to an array of "ordinary" packed fields preserving the order of scalars.
			#[inline]
			pub fn transpose_to(
				&self,
				out: &mut [PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					Self::HEIGHT_BYTES],
			) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = WithUnderlier::to_underliers_arr(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			/// Convert an array of "ordinary" packed fields to a byte-sliced field preserving the order of scalars.
			#[inline]
			pub fn transpose_from(
				underliers: &[PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					Self::HEIGHT_BYTES],
			) -> Self {
				let mut underliers = WithUnderlier::to_underliers_arr(*underliers);

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<
					$storage_tower_level,
				>(&mut underliers);

				Self {
					data: WithUnderlier::from_underliers_arr(underliers),
				}
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x1b([{}])", Self::WIDTH, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = BinaryField1b;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				self.data
					.get_unchecked((i / 8) % Self::HEIGHT_BYTES)
					.get_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8)
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				self.data
					.get_unchecked_mut((i / 8) % Self::HEIGHT_BYTES)
					.set_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8, scalar);
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| <$packed_storage>::random(&mut rng));
				Self { data }
			}

			impl_init_with_transpose!($packed_storage, BinaryField1b, $storage_tower_level);

			// Benchmarks show that transposing before the iteration makes it slower for the 1b
			// case, so do not override the default implementations of `iter` and `into_iter`.

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let underlier = <$packed_storage as WithUnderlier>::Underlier::fill_with_bit(
					scalar.to_underlier().into(),
				);
				Self {
					data: array::from_fn(|_| WithUnderlier::from_underlier(underlier)),
				}
			}

			#[inline]
			fn square(self) -> Self {
				let data = array::from_fn(|i| self.data[i].square());
				Self { data }
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let data = array::from_fn(|i| self.data[i].invert_or_zero());
				Self { data }
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

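				// Blocks shorter than a byte (`log_block_len < 3`) stay inside each
				// bit-packed row, so interleave row by row. For byte-aligned blocks,
				// reinterpret the rows as 8b-packed values and reuse the byte-sliced
				// interleave with the block length measured in bytes.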
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].interleave(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) =
						interleave_byte_sliced(self_data, other_data, log_block_len - 3);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].unzip(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					type Packed8b =
						PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) = unzip_byte_sliced::<Packed8b, { Self::HEIGHT_BYTES }, 1>(
						self_data,
						other_data,
						log_block_len - 3,
					);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|i| self.data[i] * rhs.data[i]),
				}
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						self.data[byte_number] + rhs.data[byte_number]
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					*data += *rhs
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, BinaryField1b, $storage_tower_level);

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> =
				IDTransformation;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				_transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				IDTransformation
			}
		}
	};
}
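
// A minimal sketch test (not in the original file) for the 1b special case above:
// `from_fn` should agree with `get`, and the `transpose_to`/`transpose_from` pair
// should round-trip. Assumes only the `PackedField`/`Field` APIs used elsewhere in
// this file.
#[cfg(test)]
mod byte_sliced_1b_sketch {
	use super::*;
	use crate::Field;

	#[test]
	fn get_and_transpose_round_trip() {
		// Non-capturing closure, so it stays usable after being passed to `from_fn`.
		let bit = |i: usize| {
			if i % 3 == 0 {
				BinaryField1b::ONE
			} else {
				BinaryField1b::ZERO
			}
		};
		let value = ByteSliced4x128x1b::from_fn(bit);

		for i in 0..ByteSliced4x128x1b::WIDTH {
			assert_eq!(value.get(i), bit(i));
		}

		let mut ordinary =
			[PackedBinaryField128x1b::default(); ByteSliced4x128x1b::HEIGHT_BYTES];
		value.transpose_to(&mut ordinary);
		assert_eq!(ByteSliced4x128x1b::transpose_from(&ordinary), value);
	}
}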
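/// Interleave blocks of `2^LOG_BLOCK_LEN` rows between `lhs` and `rhs`, where the
/// blocks fit strictly inside the `N`-row arrays (`LOG_BLOCK_LEN < log2(N)`).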
#[inline(always)]
fn interleave_internal_block<P: PackedField, const N: usize, const LOG_BLOCK_LEN: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
) -> ([P; N], [P; N]) {
	debug_assert!(LOG_BLOCK_LEN < checked_log_2(N));

	let result_1 = array::from_fn(|i| {
		let block_index = i >> (LOG_BLOCK_LEN + 1);
		let block_start = block_index << (LOG_BLOCK_LEN + 1);
		let block_offset = i - block_start;

		if block_offset < (1 << LOG_BLOCK_LEN) {
			lhs[i]
		} else {
			rhs[i - (1 << LOG_BLOCK_LEN)]
		}
	});
	let result_2 = array::from_fn(|i| {
		let block_index = i >> (LOG_BLOCK_LEN + 1);
		let block_start = block_index << (LOG_BLOCK_LEN + 1);
		let block_offset = i - block_start;

		if block_offset < (1 << LOG_BLOCK_LEN) {
			lhs[i + (1 << LOG_BLOCK_LEN)]
		} else {
			rhs[i]
		}
	});

	(result_1, result_2)
}
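/// Interleave two byte-sliced values with the given log block length measured in
/// bytes. Blocks smaller than the number of rows are exchanged between the row
/// arrays; larger blocks are delegated row-wise to `PackedField::interleave`.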
#[inline(always)]
fn interleave_byte_sliced<P: PackedField, const N: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
	log_block_len: usize,
) -> ([P; N], [P; N]) {
	debug_assert!(checked_log_2(N) <= 4);

	match log_block_len {
		x if x >= checked_log_2(N) => interleave_big_block::<P, N>(lhs, rhs, log_block_len),
		0 => interleave_internal_block::<P, N, 0>(lhs, rhs),
		1 => interleave_internal_block::<P, N, 1>(lhs, rhs),
		2 => interleave_internal_block::<P, N, 2>(lhs, rhs),
		3 => interleave_internal_block::<P, N, 3>(lhs, rhs),
		_ => unreachable!(),
	}
}
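/// Unzip (de-interleave) two byte-sliced values with the given log block length
/// measured in bytes: even-indexed blocks of the concatenated input go to the
/// first result and odd-indexed blocks to the second. Blocks spanning at least
/// all `N` rows are delegated row-wise to `PackedField::unzip`.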
#[inline(always)]
fn unzip_byte_sliced<P: PackedField, const N: usize, const SCALAR_BYTES: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
	log_block_len: usize,
) -> ([P; N], [P; N]) {
	let mut result1: [P; N] = bytemuck::Zeroable::zeroed();
	let mut result2: [P; N] = bytemuck::Zeroable::zeroed();

	let log_height = checked_log_2(N);
	if log_block_len < log_height {
		let block_size = 1 << log_block_len;
		let half = N / 2;
		for block_offset in (0..half).step_by(block_size) {
			let target_offset = block_offset * 2;

			result1[block_offset..block_offset + block_size]
				.copy_from_slice(&lhs[target_offset..target_offset + block_size]);
			result1[half + block_offset..half + block_offset + block_size]
				.copy_from_slice(&rhs[target_offset..target_offset + block_size]);

			result2[block_offset..block_offset + block_size]
				.copy_from_slice(&lhs[target_offset + block_size..target_offset + 2 * block_size]);
			result2[half + block_offset..half + block_offset + block_size]
				.copy_from_slice(&rhs[target_offset + block_size..target_offset + 2 * block_size]);
		}
	} else {
		for i in 0..N {
			(result1[i], result2[i]) = lhs[i].unzip(rhs[i], log_block_len - log_height);
		}
	}

	(result1, result2)
}
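/// Interleave with a block length of at least the number of rows `N`: each row
/// pair is delegated to `PackedField::interleave` with the block length reduced
/// by `log2(N)`.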
#[inline(always)]
fn interleave_big_block<P: PackedField, const N: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
	log_block_len: usize,
) -> ([P; N], [P; N]) {
	let mut result_1 = <[P; N]>::zeroed();
	let mut result_2 = <[P; N]>::zeroed();

	for i in 0..N {
		(result_1[i], result_2[i]) = lhs[i].interleave(rhs[i], log_block_len - checked_log_2(N));
	}

	(result_1, result_2)
}
// 128 bit
define_byte_sliced_3d!(
	ByteSlicedAES16x128b,
	AESTowerField128b,
	PackedAESBinaryField16x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x128x1b, PackedBinaryField128x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x128x1b, PackedBinaryField128x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x128x1b, PackedBinaryField128x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x128x1b, PackedBinaryField128x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x128x1b, PackedBinaryField128x1b, TowerLevel1);

// 256 bit
define_byte_sliced_3d!(
	ByteSlicedAES32x128b,
	AESTowerField128b,
	PackedAESBinaryField32x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x256x1b, PackedBinaryField256x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x256x1b, PackedBinaryField256x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x256x1b, PackedBinaryField256x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x256x1b, PackedBinaryField256x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x256x1b, PackedBinaryField256x1b, TowerLevel1);

// 512 bit
define_byte_sliced_3d!(
	ByteSlicedAES64x128b,
	AESTowerField128b,
	PackedAESBinaryField64x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x512x1b, PackedBinaryField512x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x512x1b, PackedBinaryField512x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x512x1b, PackedBinaryField512x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x512x1b, PackedBinaryField512x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x512x1b, PackedBinaryField512x1b, TowerLevel1);