binius_field/arch/portable/byte_sliced/packed_byte_sliced.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	array,
	fmt::Debug,
	iter::{Product, Sum, zip},
	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{Pod, Zeroable};

use super::{invert::invert_or_zero, multiply::mul, square::square};
use crate::{
	AESTowerField8b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField128b,
	BinaryField1b, ExtensionField, PackedAESBinaryField16x8b, PackedAESBinaryField64x8b,
	PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedExtension,
	PackedField,
	arch::{
		byte_sliced::underlier::ByteSlicedUnderlier, portable::packed_scaled::ScaledPackedField,
	},
	as_packed_field::{PackScalar, PackedType},
	binary_field::BinaryField,
	linear_transformation::{
		FieldLinearTransformation, IDTransformation, PackedTransformationFactory, Transformation,
	},
	packed_aes_field::PackedAESBinaryField32x8b,
	tower_levels::{TowerLevel, TowerLevel1, TowerLevel2, TowerLevel4, TowerLevel8, TowerLevel16},
	underlier::{UnderlierWithBitOps, WithUnderlier},
};

/// Packed transformation for byte-sliced fields with a scalar bigger than 8b.
///
/// `N` is the number of bytes in the scalar.
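/// Conceptually, entry `[j][i]` is the 8b linear map applied to input byte `j`
/// when accumulating output byte `i` (see the `Transformation` impl generated by
/// `define_byte_sliced_3d!` below).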
pub struct TransformationWrapperNxN<Inner, const N: usize>([[Inner; N]; N]);

/// Byte-sliced packed field with a fixed size (`HEIGHT_BYTES` rows of `$packed_storage`).
/// For example, for a 32-bit scalar the data layout is as follows:
///
/// ```plain
/// [ element_0[0], element_4[0], ... ]
/// [ element_0[1], element_4[1], ... ]
/// [ element_0[2], element_4[2], ... ]
/// [ element_0[3], element_4[3], ... ]
/// [ element_1[0], element_5[0], ... ]
/// [ element_1[1], element_5[1], ... ]
/// ...
/// ```
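///
/// In general (matching `get_unchecked`/`set_unchecked` below), byte `b` of scalar
/// `i` is stored in `data[i % HEIGHT][b]`, at lane `i / HEIGHT` of the packed 8b
/// storage.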
macro_rules! define_byte_sliced_3d {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $scalar_tower_level: ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			pub(super) data: [[$packed_storage; <$scalar_tower_level as TowerLevel>::WIDTH]; <$storage_tower_level as TowerLevel>::WIDTH / <$scalar_tower_level as TowerLevel>::WIDTH],
		}

		impl $name {
			pub const BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			const SCALAR_BYTES: usize = <$scalar_type>::N_BITS / 8;
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const HEIGHT: usize = Self::HEIGHT_BYTES / Self::SCALAR_BYTES;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT);

			/// Get the byte at the given index.
			///
			/// # Safety
			/// The caller must ensure that `byte_index` is less than `BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
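				// Bytes are indexed in transposed order: `byte_index % HEIGHT_BYTES`
				// selects the byte row, `byte_index / HEIGHT_BYTES` the lane within
				// the packed 8b storage.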
				let row = byte_index % Self::HEIGHT_BYTES;
				unsafe {
					self.data
						.get_unchecked(row / Self::SCALAR_BYTES)
						.get_unchecked(row % Self::SCALAR_BYTES)
						.get_unchecked(byte_index / Self::HEIGHT_BYTES)
						.to_underlier()
				}
			}

			/// Convert the byte-sliced field to an array of "ordinary" packed fields preserving the order of scalars.
			#[inline]
			pub fn transpose_to(&self, out: &mut [<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES]) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = bytemuck::must_cast(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);
			}

			/// Convert an array of "ordinary" packed fields to a byte-sliced field preserving the order of scalars.
			#[inline]
			pub fn transpose_from(
				underliers: &[<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES],
			) -> Self {
				let mut result = Self {
					data: bytemuck::must_cast(*underliers),
				};

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<$storage_tower_level>(bytemuck::must_cast_mut(&mut result.data));

				result
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x{}([{}])", Self::WIDTH, <$scalar_type>::N_BITS, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = $scalar_type;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				let element_rows = unsafe { self.data.get_unchecked(i % Self::HEIGHT) };
				Self::Scalar::from_bases((0..Self::SCALAR_BYTES).map(|byte_index| unsafe {
					element_rows
						.get_unchecked(byte_index)
						.get_unchecked(i / Self::HEIGHT)
				}))
				.expect("byte index is within bounds")
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				let element_rows = unsafe { self.data.get_unchecked_mut(i % Self::HEIGHT) };
				for byte_index in 0..Self::SCALAR_BYTES {
					unsafe {
						element_rows
							.get_unchecked_mut(byte_index)
							.set_unchecked(
								i / Self::HEIGHT,
								scalar.get_base_unchecked(byte_index),
							);
					}
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| array::from_fn(|_| <$packed_storage>::random(&mut rng)));
				Self { data }
			}

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let data: [[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT] = match Self::SCALAR_BYTES {
					1 => {
						let packed_broadcast =
							<$packed_storage>::broadcast(unsafe { scalar.get_base_unchecked(0) });
						array::from_fn(|_| array::from_fn(|_| packed_broadcast))
					}
					Self::HEIGHT_BYTES => array::from_fn(|_| array::from_fn(|byte_index| {
						<$packed_storage>::broadcast(unsafe {
							scalar.get_base_unchecked(byte_index)
						})
					})),
					_ => {
						let mut data = <[[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT]>::zeroed();
						for byte_index in 0..Self::SCALAR_BYTES {
							let broadcast = <$packed_storage>::broadcast(unsafe {
								scalar.get_base_unchecked(byte_index)
							});

							for i in 0..Self::HEIGHT {
								data[i][byte_index] = broadcast;
							}
						}

						data
					}
				};

				Self { data }
			}

			impl_init_with_transpose!($packed_storage, $scalar_type, $storage_tower_level);

			// Transposing values first is faster than the default per-scalar extraction.
			// Also we know that the scalar is at least 8b, so we can cast the transposed
			// array to the array of scalars.
			#[inline]
			fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Clone + Send + '_ {
				type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

				let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
				let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(&mut underliers);

				let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
				data.into_iter()
			}

			#[inline]
			fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Clone + Send {
				type TransposePacked = <<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

				let mut data: [TransposePacked; Self::HEIGHT_BYTES] = bytemuck::must_cast(self.data);
				let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(&mut underliers);

				let data: [Self::Scalar; {Self::WIDTH}] = bytemuck::must_cast(data);
				data.into_iter()
			}

			#[inline]
			fn square(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					square::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					invert_or_zero::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				let self_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&self.data);
				let other_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&other.data);

				// This implementation is faster than using a loop with `copy_from_slice` for the first 4 cases
				let (data_1, data_2) = interleave_byte_sliced(self_data, other_data, log_block_len + checked_log_2(Self::SCALAR_BYTES));
				(
					Self {
						data: bytemuck::must_cast(data_1),
					},
					Self {
						data: bytemuck::must_cast(data_2),
					},
				)
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				let (result1, result2) = unzip_byte_sliced::<$packed_storage, {Self::HEIGHT_BYTES}, {Self::SCALAR_BYTES}>(bytemuck::must_cast_ref(&self.data), bytemuck::must_cast_ref(&other.data), log_block_len + checked_log_2(Self::SCALAR_BYTES));

				(
					Self {
						data: bytemuck::must_cast(result1),
					},
					Self {
						data: bytemuck::must_cast(result2),
					},
				)
			}

			impl_spread!($name, $scalar_type, $packed_storage, $storage_tower_level);
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					mul::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&rhs.data[i],
						&mut result.data[i],
					);
				}

				result
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						array::from_fn(|column|
							self.data[byte_number][column] + rhs.data[byte_number][column]
						)
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					for (data, rhs) in zip(data, rhs) {
						*data += *rhs
					}
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, $scalar_type, $storage_tower_level);

		impl<Inner: Transformation<$packed_storage, $packed_storage>> Transformation<$name, $name> for TransformationWrapperNxN<Inner, {<$scalar_tower_level as TowerLevel>::WIDTH}> {
			fn transform(&self, data: &$name) -> $name {
				let mut result = <$name>::default();

				for row in 0..<$name>::SCALAR_BYTES {
					for col in 0..<$name>::SCALAR_BYTES {
						let transformation = &self.0[col][row];

						for i in 0..<$name>::HEIGHT {
							result.data[i][row] += transformation.transform(&data.data[i][col]);
						}
					}
				}

				result
			}
		}

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> = TransformationWrapperNxN<<$packed_storage as PackedTransformationFactory<$packed_storage>>::PackedTransformation::<[AESTowerField8b; 8]>, {<$scalar_tower_level as TowerLevel>::WIDTH}>;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				let transformations_8b = array::from_fn(|row| {
					array::from_fn(|col| {
						let row = row * 8;
						let linear_transformation_8b = array::from_fn::<_, 8, _>(|row_8b| unsafe {
							<<$name as PackedField>::Scalar as ExtensionField<AESTowerField8b>>::get_base_unchecked(&transformation.bases()[row + row_8b], col)
						});

						<$packed_storage as PackedTransformationFactory<$packed_storage
						>>::make_packed_transformation(FieldLinearTransformation::new(linear_transformation_8b))
					})
				});

				TransformationWrapperNxN(transformations_8b)
			}
		}
	};
}

/// This macro implements the `from_fn` and `from_scalars` methods for byte-sliced packed fields
/// using transpose operations. This is faster both for 1b and non-1b scalars.
macro_rules! impl_init_with_transpose {
	($packed_storage:ty, $scalar_type:ty, $storage_tower_level:ty) => {
		#[inline]
		fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
			type TransposePacked =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

			let mut data: [TransposePacked; Self::HEIGHT_BYTES] = array::from_fn(|i| {
				PackedField::from_fn(|j| f((i << TransposePacked::LOG_WIDTH) + j))
			});
			let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
				&mut underliers,
			);

			Self {
				data: bytemuck::must_cast(data),
			}
		}

		#[inline]
		fn from_scalars(scalars: impl IntoIterator<Item = Self::Scalar>) -> Self {
			type TransposePacked =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;

			let mut data = [TransposePacked::default(); Self::HEIGHT_BYTES];
			let mut iter = scalars.into_iter();
			for v in data.iter_mut() {
				*v = TransposePacked::from_scalars(&mut iter);
			}

			let mut underliers = WithUnderlier::to_underliers_arr_ref_mut(&mut data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(
				&mut underliers,
			);

			Self {
				data: bytemuck::must_cast(data),
			}
		}
	};
}

macro_rules! impl_spread {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $storage_tower_level:ty) => {
		#[inline]
		unsafe fn spread_unchecked(mut self, log_block_len: usize, block_idx: usize) -> Self {
			{
				let data: &mut [$packed_storage; Self::HEIGHT_BYTES] =
					bytemuck::must_cast_mut(&mut self.data);
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(data);
				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			type PackedSequentialField =
				<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed;
			let scaled_data: ScaledPackedField<PackedSequentialField, { Self::HEIGHT_BYTES }> =
				bytemuck::must_cast(self);

			let mut result = unsafe { scaled_data.spread_unchecked(log_block_len, block_idx) };
			let data: &mut [$packed_storage; Self::HEIGHT_BYTES] =
				bytemuck::must_cast_mut(&mut result);
			let underliers = WithUnderlier::to_underliers_arr_ref_mut(data);
			UnderlierWithBitOps::transpose_bytes_to_byte_sliced::<$storage_tower_level>(underliers);
			bytemuck::must_cast(result)
		}
	};
}

macro_rules! byte_sliced_common {
	($name:ident, $packed_storage:ty, $scalar_type:ty, $storage_level:ty) => {
		impl Add<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: $scalar_type) -> $name {
				self + Self::broadcast(rhs)
			}
		}

		impl AddAssign<$scalar_type> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: $scalar_type) {
				*self += Self::broadcast(rhs)
			}
		}

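		// In these binary (characteristic-2) fields subtraction coincides with
		// addition, so the `Sub`/`SubAssign` impls below simply delegate to `add`.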
		impl Sub<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: $scalar_type) -> $name {
				self.add(rhs)
			}
		}

		impl SubAssign<$scalar_type> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: $scalar_type) {
				self.add_assign(rhs)
			}
		}

		impl Mul<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: $scalar_type) -> $name {
				self * Self::broadcast(rhs)
			}
		}

		impl MulAssign<$scalar_type> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: $scalar_type) {
				*self *= Self::broadcast(rhs);
			}
		}

		impl Sub for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: Self) -> Self {
				self.add(rhs)
			}
		}

		impl SubAssign for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: Self) {
				self.add_assign(rhs);
			}
		}

		impl MulAssign for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: Self) {
				*self = *self * rhs;
			}
		}

		impl Product for $name {
			fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
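				// Use the first item as the seed to skip one multiplication by `one()`.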
				let mut result = Self::one();

				let mut is_first_item = true;
				for item in iter {
					if is_first_item {
						result = item;
					} else {
						result *= item;
					}

					is_first_item = false;
				}

				result
			}
		}

		impl Sum for $name {
			fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::zero();

				for item in iter {
					result += item;
				}

				result
			}
		}

		unsafe impl WithUnderlier for $name {
			type Underlier = ByteSlicedUnderlier<
				<$packed_storage as WithUnderlier>::Underlier,
				{ <$storage_level as TowerLevel>::WIDTH },
			>;

			#[inline(always)]
			fn to_underlier(self) -> Self::Underlier {
				bytemuck::must_cast(self)
			}

			#[inline(always)]
			fn to_underlier_ref(&self) -> &Self::Underlier {
				bytemuck::must_cast_ref(self)
			}

			#[inline(always)]
			fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
				bytemuck::must_cast_mut(self)
			}

			#[inline(always)]
			fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
				bytemuck::must_cast_slice(val)
			}

			#[inline(always)]
			fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
				bytemuck::must_cast_slice_mut(val)
			}

			#[inline(always)]
			fn from_underlier(val: Self::Underlier) -> Self {
				bytemuck::must_cast(val)
			}

			#[inline(always)]
			fn from_underlier_ref(val: &Self::Underlier) -> &Self {
				bytemuck::must_cast_ref(val)
			}

			#[inline(always)]
			fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
				bytemuck::must_cast_mut(val)
			}

			#[inline(always)]
			fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
				bytemuck::must_cast_slice(val)
			}

			#[inline(always)]
			fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
				bytemuck::must_cast_slice_mut(val)
			}
		}

		impl PackScalar<$scalar_type>
			for ByteSlicedUnderlier<
				<$packed_storage as WithUnderlier>::Underlier,
				{ <$storage_level as TowerLevel>::WIDTH },
			>
		{
			type Packed = $name;
		}
	};
}

/// Special case: byte-sliced packed 1b fields. The byte order in the layout matches
/// that of a byte-sliced AES field; each byte holds 1b scalars in the natural order.
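///
/// Concretely (matching `get_unchecked` below), bit `i` lives in byte row
/// `(i / 8) % HEIGHT_BYTES`, at bit position `8 * (i / (8 * HEIGHT_BYTES)) + i % 8`
/// within that row's packed storage.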
macro_rules! define_byte_sliced_3d_1b {
	($name:ident, $packed_storage:ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			pub(super) data: [$packed_storage; <$storage_tower_level>::WIDTH],
		}

		impl $name {
			pub const BYTES: usize =
				<$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT_BYTES);

			/// Get the byte at the given index.
			///
			/// # Safety
			/// The caller must ensure that `byte_index` is less than `BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				unsafe {
					Packed8b::cast_ext_ref(self.data.get_unchecked(byte_index % Self::HEIGHT_BYTES))
						.get_unchecked(byte_index / Self::HEIGHT_BYTES)
						.to_underlier()
				}
			}

			/// Convert the byte-sliced field to an array of "ordinary" packed fields preserving the
			/// order of scalars.
			#[inline]
			pub fn transpose_to(
				&self,
				out: &mut [PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					     Self::HEIGHT_BYTES],
			) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = WithUnderlier::to_underliers_arr(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			/// Convert an array of "ordinary" packed fields to a byte-sliced field preserving the
			/// order of scalars.
			#[inline]
			pub fn transpose_from(
				underliers: &[PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					 Self::HEIGHT_BYTES],
			) -> Self {
				let mut underliers = WithUnderlier::to_underliers_arr(*underliers);

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<
					$storage_tower_level,
				>(&mut underliers);

				Self {
					data: WithUnderlier::from_underliers_arr(underliers),
				}
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x1b([{}])", Self::WIDTH, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = BinaryField1b;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				unsafe {
					self.data
						.get_unchecked((i / 8) % Self::HEIGHT_BYTES)
						.get_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8)
				}
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				unsafe {
					self.data
						.get_unchecked_mut((i / 8) % Self::HEIGHT_BYTES)
						.set_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8, scalar);
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| <$packed_storage>::random(&mut rng));
				Self { data }
			}

			impl_init_with_transpose!($packed_storage, BinaryField1b, $storage_tower_level);

			// Benchmarks show that transposing before the iteration makes it slower for the 1b
			// case, so do not override the default implementations of `iter` and `into_iter`.

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let underlier = <$packed_storage as WithUnderlier>::Underlier::fill_with_bit(
					scalar.to_underlier().into(),
				);
				Self {
					data: array::from_fn(|_| WithUnderlier::from_underlier(underlier)),
				}
			}

			#[inline]
			fn square(self) -> Self {
				let data = array::from_fn(|i| self.data[i].clone().square());
				Self { data }
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let data = array::from_fn(|i| self.data[i].clone().invert_or_zero());
				Self { data }
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].interleave(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) =
						interleave_byte_sliced(self_data, other_data, log_block_len - 3);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].unzip(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					type Packed8b =
						PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) = unzip_byte_sliced::<Packed8b, { Self::HEIGHT_BYTES }, 1>(
						self_data,
						other_data,
						log_block_len - 3,
					);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			impl_spread!($name, BinaryField1b, $packed_storage, $storage_tower_level);
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|i| self.data[i].clone() * rhs.data[i].clone()),
				}
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						self.data[byte_number] + rhs.data[byte_number]
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					*data += *rhs
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, BinaryField1b, $storage_tower_level);

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> =
				IDTransformation;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				_transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				IDTransformation
			}
		}
	};
}

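/// Interleaves `lhs` and `rhs` in blocks of `1 << LOG_BLOCK_LEN` byte rows.
///
/// For example (worked out from the index arithmetic below), with `N = 4` and
/// `LOG_BLOCK_LEN = 0` the results are `[lhs[0], rhs[0], lhs[2], rhs[2]]` and
/// `[lhs[1], rhs[1], lhs[3], rhs[3]]`.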
#[inline(always)]
fn interleave_internal_block<P: PackedField, const N: usize, const LOG_BLOCK_LEN: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
) -> ([P; N], [P; N]) {
	debug_assert!(LOG_BLOCK_LEN < checked_log_2(N));

	let result_1 = array::from_fn(|i| {
		let block_index = i >> (LOG_BLOCK_LEN + 1);
		let block_start = block_index << (LOG_BLOCK_LEN + 1);
		let block_offset = i - block_start;

		if block_offset < (1 << LOG_BLOCK_LEN) {
			lhs[i]
		} else {
			rhs[i - (1 << LOG_BLOCK_LEN)]
		}
	});
	let result_2 = array::from_fn(|i| {
		let block_index = i >> (LOG_BLOCK_LEN + 1);
		let block_start = block_index << (LOG_BLOCK_LEN + 1);
		let block_offset = i - block_start;

		if block_offset < (1 << LOG_BLOCK_LEN) {
			lhs[i + (1 << LOG_BLOCK_LEN)]
		} else {
			rhs[i]
		}
	});

	(result_1, result_2)
}

#[inline(always)]
fn interleave_byte_sliced<P: PackedField, const N: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
	log_block_len: usize,
) -> ([P; N], [P; N]) {
	debug_assert!(checked_log_2(N) <= 4);

	match log_block_len {
		x if x >= checked_log_2(N) => interleave_big_block::<P, N>(lhs, rhs, log_block_len),
		0 => interleave_internal_block::<P, N, 0>(lhs, rhs),
		1 => interleave_internal_block::<P, N, 1>(lhs, rhs),
		2 => interleave_internal_block::<P, N, 2>(lhs, rhs),
		3 => interleave_internal_block::<P, N, 3>(lhs, rhs),
		_ => unreachable!(),
	}
}

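/// Splits `lhs` and `rhs` into even- and odd-indexed blocks of `1 << log_block_len`
/// rows: `result1` gathers the even blocks (those of `lhs` in its first half, those
/// of `rhs` in its second half), `result2` the odd blocks.
///
/// For example (worked out from the copy loop below), with `N = 4` and
/// `log_block_len = 0` the results are `[lhs[0], lhs[2], rhs[0], rhs[2]]` and
/// `[lhs[1], lhs[3], rhs[1], rhs[3]]`.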
#[inline(always)]
fn unzip_byte_sliced<P: PackedField, const N: usize, const SCALAR_BYTES: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
	log_block_len: usize,
) -> ([P; N], [P; N]) {
	let mut result1: [P; N] = bytemuck::Zeroable::zeroed();
	let mut result2: [P; N] = bytemuck::Zeroable::zeroed();

	let log_height = checked_log_2(N);
	if log_block_len < log_height {
		let block_size = 1 << log_block_len;
		let half = N / 2;
		for block_offset in (0..half).step_by(block_size) {
			let target_offset = block_offset * 2;

			result1[block_offset..block_offset + block_size]
				.copy_from_slice(&lhs[target_offset..target_offset + block_size]);
			result1[half + block_offset..half + block_offset + block_size]
				.copy_from_slice(&rhs[target_offset..target_offset + block_size]);

			result2[block_offset..block_offset + block_size]
				.copy_from_slice(&lhs[target_offset + block_size..target_offset + 2 * block_size]);
			result2[half + block_offset..half + block_offset + block_size]
				.copy_from_slice(&rhs[target_offset + block_size..target_offset + 2 * block_size]);
		}
	} else {
		for i in 0..N {
			(result1[i], result2[i]) = lhs[i].unzip(rhs[i], log_block_len - log_height);
		}
	}

	(result1, result2)
}

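/// Handles `log_block_len >= log2(N)`: the interleaved blocks span whole byte-sliced
/// columns, so each pair of corresponding rows is interleaved directly, with the
/// block length reduced by `log2(N)` to account for the byte-sliced height.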
#[inline(always)]
fn interleave_big_block<P: PackedField, const N: usize>(
	lhs: &[P; N],
	rhs: &[P; N],
	log_block_len: usize,
) -> ([P; N], [P; N]) {
	let mut result_1 = <[P; N]>::zeroed();
	let mut result_2 = <[P; N]>::zeroed();

	for i in 0..N {
		(result_1[i], result_2[i]) = lhs[i].interleave(rhs[i], log_block_len - checked_log_2(N));
	}

	(result_1, result_2)
}

// 128 bit
define_byte_sliced_3d!(
	ByteSlicedAES16x128b,
	AESTowerField128b,
	PackedAESBinaryField16x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x128x1b, PackedBinaryField128x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x128x1b, PackedBinaryField128x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x128x1b, PackedBinaryField128x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x128x1b, PackedBinaryField128x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x128x1b, PackedBinaryField128x1b, TowerLevel1);

// 256 bit
define_byte_sliced_3d!(
	ByteSlicedAES32x128b,
	AESTowerField128b,
	PackedAESBinaryField32x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x256x1b, PackedBinaryField256x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x256x1b, PackedBinaryField256x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x256x1b, PackedBinaryField256x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x256x1b, PackedBinaryField256x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x256x1b, PackedBinaryField256x1b, TowerLevel1);

// 512 bit
define_byte_sliced_3d!(
	ByteSlicedAES64x128b,
	AESTowerField128b,
	PackedAESBinaryField64x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel16
);

define_byte_sliced_3d_1b!(ByteSliced16x512x1b, PackedBinaryField512x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x512x1b, PackedBinaryField512x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x512x1b, PackedBinaryField512x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x512x1b, PackedBinaryField512x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x512x1b, PackedBinaryField512x1b, TowerLevel1);
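
// A minimal sanity-check sketch (assuming the crate-root `Field` trait with its
// `ZERO`/`ONE` constants): `from_fn` followed by `iter` should reproduce the scalars
// in order, regardless of the byte-sliced layout.
#[cfg(test)]
mod byte_sliced_layout_sketch {
	use super::*;
	use crate::Field;

	#[test]
	fn from_fn_iter_roundtrip() {
		// Deterministic 0/1 pattern over all scalar positions.
		let packed = ByteSlicedAES32x32b::from_fn(|i| {
			if i % 3 == 0 {
				AESTowerField32b::ONE
			} else {
				AESTowerField32b::ZERO
			}
		});

		for (i, scalar) in packed.iter().enumerate() {
			let expected = if i % 3 == 0 {
				AESTowerField32b::ONE
			} else {
				AESTowerField32b::ZERO
			};
			assert_eq!(scalar, expected, "mismatch at scalar {i}");
		}
	}
}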