binius_field/arch/portable/byte_sliced/
packed_byte_sliced.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	fmt::Debug,
6	iter::{zip, Product, Sum},
7	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_utils::checked_arithmetics::checked_log_2;
11use bytemuck::{Pod, Zeroable};
12
13use super::{invert::invert_or_zero, multiply::mul, square::square};
14use crate::{
15	as_packed_field::{PackScalar, PackedType},
16	binary_field::BinaryField,
17	linear_transformation::{
18		FieldLinearTransformation, IDTransformation, PackedTransformationFactory, Transformation,
19	},
20	packed_aes_field::PackedAESBinaryField32x8b,
21	tower_levels::{TowerLevel, TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8},
22	underlier::{UnderlierWithBitOps, WithUnderlier},
23	AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b,
24	BinaryField1b, ExtensionField, PackedAESBinaryField16x8b, PackedAESBinaryField64x8b,
25	PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedExtension,
26	PackedField,
27};
28
/// Packed transformation for byte-sliced fields with a scalar bigger than 8b.
///
/// `N` is the number of bytes in the scalar. Wraps an `N`x`N` matrix of inner
/// 8b-to-8b transformations that together realize the scalar-level linear map
/// on the byte-sliced representation.
pub struct TransformationWrapperNxN<Inner, const N: usize>([[Inner; N]; N]);
33
34/// Byte-sliced packed field with a fixed size (16x$packed_storage).
35/// For example for 32-bit scalar the data layout is the following:
36/// [ element_0[0], element_4[0], ... ]
37/// [ element_0[1], element_4[1], ... ]
38/// [ element_0[2], element_4[2], ... ]
39/// [ element_0[3], element_4[3], ... ]
40/// [ element_1[0], element_5[0], ... ]
41/// [ element_1[1], element_5[1], ... ]
42///  ...
macro_rules! define_byte_sliced_3d {
	($name:ident, $scalar_type:ty, $packed_storage:ty, $scalar_tower_level: ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			// Outer index: scalar group (`HEIGHT` entries); inner index: byte
			// position within the scalar (`SCALAR_BYTES` entries). Each entry is
			// one packed row of 8b lanes — see the layout sketch above the macro.
			pub(super) data: [[$packed_storage; <$scalar_tower_level as TowerLevel>::WIDTH]; <$storage_tower_level as TowerLevel>::WIDTH / <$scalar_tower_level as TowerLevel>::WIDTH],
		}

		impl $name {
			/// Total size of the packed value in bytes.
			pub const BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			/// Size of a single scalar in bytes.
			const SCALAR_BYTES: usize = <$scalar_type>::N_BITS / 8;
			/// Number of byte rows in the column.
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			/// Number of scalar rows, i.e. byte rows grouped by scalar.
			const HEIGHT: usize = Self::HEIGHT_BYTES / Self::SCALAR_BYTES;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT);

			/// Get the byte at the given index.
			///
			/// # Safety
			/// The caller must ensure that `byte_index` is less than `BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				// Column-major layout: the low part of the index selects the byte
				// row, the high part selects the lane within that packed row.
				let row = byte_index % Self::HEIGHT_BYTES;
				self.data
					.get_unchecked(row / Self::SCALAR_BYTES)
					.get_unchecked(row % Self::SCALAR_BYTES)
					.get_unchecked(byte_index / Self::HEIGHT_BYTES)
					.to_underlier()
			}

			/// Convert the byte-sliced field to an array of "ordinary" packed fields preserving the order of scalars.
			#[inline]
			pub fn transpose_to(&self, out: &mut [<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES]) {
				// Reinterpret the 2D row matrix as a flat underlier array, then
				// transpose bytes back into scalar order.
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = bytemuck::must_cast(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(underliers);
			}

			/// Convert an array of "ordinary" packed fields to a byte-sliced field preserving the order of scalars.
			#[inline]
			pub fn transpose_from(
				underliers: &[<<$packed_storage as WithUnderlier>::Underlier as PackScalar<$scalar_type>>::Packed; Self::HEIGHT_BYTES],
			) -> Self {
				let mut result = Self {
					data: bytemuck::must_cast(*underliers),
				};

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<$storage_tower_level>(bytemuck::must_cast_mut(&mut result.data));

				result
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				// Print scalars in their logical (transposed) order.
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x{}([{}])", Self::WIDTH, <$scalar_type>::N_BITS, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = $scalar_type;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				// `i % HEIGHT` picks the scalar group, `i / HEIGHT` the lane; the
				// scalar is reassembled from its SCALAR_BYTES 8b base components.
				let element_rows = self.data.get_unchecked(i % Self::HEIGHT);
				Self::Scalar::from_bases((0..Self::SCALAR_BYTES).map(|byte_index| {
					element_rows
						.get_unchecked(byte_index)
						.get_unchecked(i / Self::HEIGHT)
				}))
				.expect("byte index is within bounds")
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				// Inverse of `get_unchecked`: scatter the scalar's 8b components
				// over the byte rows of its group.
				let element_rows = self.data.get_unchecked_mut(i % Self::HEIGHT);
				for byte_index in 0..Self::SCALAR_BYTES {
					element_rows
						.get_unchecked_mut(byte_index)
						.set_unchecked(
							i / Self::HEIGHT,
							scalar.get_base_unchecked(byte_index),
						);
				}
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| array::from_fn(|_| <$packed_storage>::random(&mut rng)));
				Self { data }
			}

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				let data: [[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT] = match Self::SCALAR_BYTES {
					// 8b scalar: a single broadcast fills every row.
					1 => {
						let packed_broadcast =
							<$packed_storage>::broadcast(unsafe { scalar.get_base_unchecked(0) });
						array::from_fn(|_| array::from_fn(|_| packed_broadcast))
					}
					// Scalar spans the whole column (HEIGHT == 1): broadcast each
					// scalar byte into its own row.
					Self::HEIGHT_BYTES => array::from_fn(|_| array::from_fn(|byte_index| {
						<$packed_storage>::broadcast(unsafe {
							scalar.get_base_unchecked(byte_index)
						})
					})),
					// General case: broadcast each scalar byte once and copy it to
					// the matching row of every scalar group.
					_ => {
						let mut data = <[[$packed_storage; Self::SCALAR_BYTES]; Self::HEIGHT]>::zeroed();
						for byte_index in 0..Self::SCALAR_BYTES {
							let broadcast = <$packed_storage>::broadcast(unsafe {
								scalar.get_base_unchecked(byte_index)
							});

							for i in 0..Self::HEIGHT {
								data[i][byte_index] = broadcast;
							}
						}

						data
					}
				};

				Self { data }
			}

			#[inline]
			fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
				let mut result = Self::default();

				// TODO: use transposition here as soon as implemented
				for i in 0..Self::WIDTH {
					//SAFETY: i doesn't exceed Self::WIDTH
					unsafe { result.set_unchecked(i, f(i)) };
				}

				result
			}

			#[inline]
			fn square(self) -> Self {
				let mut result = Self::default();

				// Squaring is performed independently per scalar group of rows.
				for i in 0..Self::HEIGHT {
					square::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let mut result = Self::default();

				for i in 0..Self::HEIGHT {
					invert_or_zero::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&mut result.data[i],
					);
				}

				result
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {

				// View the 2D data as a flat column of byte rows; scalar-level
				// block lengths are rescaled to byte rows via SCALAR_BYTES.
				let self_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&self.data);
				let other_data: &[$packed_storage; Self::HEIGHT_BYTES] = bytemuck::must_cast_ref(&other.data);

				// This implementation is faster than using a loop with `copy_from_slice` for the first 4 cases
				let (data_1, data_2) = interleave_byte_sliced(self_data, other_data, log_block_len + checked_log_2(Self::SCALAR_BYTES));
				(
					Self {
						data: bytemuck::must_cast(data_1),
					},
					Self {
						data: bytemuck::must_cast(data_2),
					},
				)
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				// Same flat-column view and block-length rescaling as `interleave`.
				let (result1, result2) = unzip_byte_sliced::<$packed_storage, {Self::HEIGHT_BYTES}, {Self::SCALAR_BYTES}>(bytemuck::must_cast_ref(&self.data), bytemuck::must_cast_ref(&other.data), log_block_len + checked_log_2(Self::SCALAR_BYTES));

				(
					Self {
						data: bytemuck::must_cast(result1),
					},
					Self {
						data: bytemuck::must_cast(result2),
					},
				)
			}
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				let mut result = Self::default();

				// Byte-sliced tower multiplication, one scalar group at a time.
				for i in 0..Self::HEIGHT {
					mul::<$packed_storage, $scalar_tower_level>(
						&self.data[i],
						&rhs.data[i],
						&mut result.data[i],
					);
				}

				result
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				// Row-wise packed addition over the whole matrix.
				Self {
					data: array::from_fn(|byte_number| {
						array::from_fn(|column|
							self.data[byte_number][column] + rhs.data[byte_number][column]
						)
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					for (data, rhs) in zip(data, rhs) {
						*data += *rhs
					}
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, $scalar_type);

		// Applies the NxN matrix of 8b transformations: output byte row `row`
		// accumulates a contribution from every input byte row `col`, for each
		// scalar group independently.
		impl<Inner: Transformation<$packed_storage, $packed_storage>> Transformation<$name, $name> for TransformationWrapperNxN<Inner, {<$scalar_tower_level as TowerLevel>::WIDTH}> {
			fn transform(&self, data: &$name) -> $name {
				let mut result = <$name>::default();

				for row in 0..<$name>::SCALAR_BYTES {
					for col in 0..<$name>::SCALAR_BYTES {
						let transformation = &self.0[col][row];

						for i in 0..<$name>::HEIGHT {
							result.data[i][row] += transformation.transform(&data.data[i][col]);
						}
					}
				}

				result
			}
		}

		impl PackedTransformationFactory<$name> for $name {
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> = TransformationWrapperNxN<<$packed_storage as  PackedTransformationFactory<$packed_storage>>::PackedTransformation::<[AESTowerField8b; 8]>, {<$scalar_tower_level as TowerLevel>::WIDTH}>;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				// Decompose the scalar-level linear transformation into 8x8-bit
				// sub-transformations: take the 8 basis images starting at
				// `row * 8` and extract their `col`-th 8b base component.
				let transformations_8b = array::from_fn(|row| {
					array::from_fn(|col| {
						let row = row * 8;
						let linear_transformation_8b = array::from_fn::<_, 8, _>(|row_8b| unsafe {
							<<$name as PackedField>::Scalar as ExtensionField<AESTowerField8b>>::get_base_unchecked(&transformation.bases()[row + row_8b], col)
						});

						<$packed_storage as PackedTransformationFactory<$packed_storage
						>>::make_packed_transformation(FieldLinearTransformation::new(linear_transformation_8b))
					})
				});

				TransformationWrapperNxN(transformations_8b)
			}
		}
	};
}
349
// Operator and trait impls shared by the multi-byte and the 1b byte-sliced
// types: scalar-operand arithmetic, Sub/Mul assignment forms, iterator folds,
// and the trivial `PackedExtension` over the type's own scalar.
macro_rules! byte_sliced_common {
	($name:ident, $packed_storage:ty, $scalar_type:ty) => {
		// Scalar operands are handled by broadcasting to a full packed value.
		impl Add<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: $scalar_type) -> $name {
				self + Self::broadcast(rhs)
			}
		}

		impl AddAssign<$scalar_type> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: $scalar_type) {
				*self += Self::broadcast(rhs)
			}
		}

		// In characteristic-2 fields subtraction coincides with addition.
		impl Sub<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: $scalar_type) -> $name {
				self.add(rhs)
			}
		}

		impl SubAssign<$scalar_type> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: $scalar_type) {
				self.add_assign(rhs)
			}
		}

		impl Mul<$scalar_type> for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: $scalar_type) -> $name {
				self * Self::broadcast(rhs)
			}
		}

		impl MulAssign<$scalar_type> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: $scalar_type) {
				*self *= Self::broadcast(rhs);
			}
		}

		impl Sub for $name {
			type Output = Self;

			#[inline]
			fn sub(self, rhs: Self) -> Self {
				self.add(rhs)
			}
		}

		impl SubAssign for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: Self) {
				self.add_assign(rhs);
			}
		}

		impl MulAssign for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: Self) {
				*self = *self * rhs;
			}
		}

		impl Product for $name {
			fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
				// Start from the first item when there is one (saves one
				// multiplication); an empty iterator yields the identity.
				let mut result = Self::one();

				let mut is_first_item = true;
				for item in iter {
					if is_first_item {
						result = item;
					} else {
						result *= item;
					}

					is_first_item = false;
				}

				result
			}
		}

		impl Sum for $name {
			fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
				let mut result = Self::zero();

				for item in iter {
					result += item;
				}

				result
			}
		}

		// Every packed field is trivially a packed extension over its own
		// scalar, so all casts here are identities.
		impl PackedExtension<$scalar_type> for $name {
			type PackedSubfield = Self;

			#[inline(always)]
			fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield] {
				packed
			}

			#[inline(always)]
			fn cast_bases_mut(packed: &mut [Self]) -> &mut [Self::PackedSubfield] {
				packed
			}

			#[inline(always)]
			fn cast_exts(packed: &[Self::PackedSubfield]) -> &[Self] {
				packed
			}

			#[inline(always)]
			fn cast_exts_mut(packed: &mut [Self::PackedSubfield]) -> &mut [Self] {
				packed
			}

			#[inline(always)]
			fn cast_base(self) -> Self::PackedSubfield {
				self
			}

			#[inline(always)]
			fn cast_base_ref(&self) -> &Self::PackedSubfield {
				self
			}

			#[inline(always)]
			fn cast_base_mut(&mut self) -> &mut Self::PackedSubfield {
				self
			}

			#[inline(always)]
			fn cast_ext(base: Self::PackedSubfield) -> Self {
				base
			}

			#[inline(always)]
			fn cast_ext_ref(base: &Self::PackedSubfield) -> &Self {
				base
			}

			#[inline(always)]
			fn cast_ext_mut(base: &mut Self::PackedSubfield) -> &mut Self {
				base
			}
		}
	};
}
509
510/// Special case: byte-sliced packed 1b-fields. The order of bytes in the layout matches the one
511/// for a byte-sliced AES field, each byte contains 1b-scalar elements in the natural order.
macro_rules! define_byte_sliced_3d_1b {
	($name:ident, $packed_storage:ty, $storage_tower_level: ty) => {
		#[derive(Clone, Copy, PartialEq, Eq, Pod, Zeroable)]
		#[repr(transparent)]
		pub struct $name {
			// One packed 1b row per byte row of the column; the 8 bits of each
			// byte hold consecutive scalars (see the macro-level doc comment).
			pub(super) data: [$packed_storage; <$storage_tower_level>::WIDTH],
		}

		impl $name {
			/// Total size of the packed value in bytes.
			pub const BYTES: usize =
				<$storage_tower_level as TowerLevel>::WIDTH * <$packed_storage>::WIDTH;

			/// Number of byte rows in the column.
			pub(crate) const HEIGHT_BYTES: usize = <$storage_tower_level as TowerLevel>::WIDTH;
			const LOG_HEIGHT: usize = checked_log_2(Self::HEIGHT_BYTES);

			/// Get the byte at the given index.
			///
			/// # Safety
			/// The caller must ensure that `byte_index` is less than `BYTES`.
			#[allow(clippy::modulo_one)]
			#[inline(always)]
			pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 {
				// Reinterpret the 1b row as 8b lanes to extract a whole byte.
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				Packed8b::cast_ext_ref(self.data.get_unchecked(byte_index % Self::HEIGHT_BYTES))
					.get_unchecked(byte_index / Self::HEIGHT_BYTES)
					.to_underlier()
			}

			/// Convert the byte-sliced field to an array of "ordinary" packed fields preserving the order of scalars.
			#[inline]
			pub fn transpose_to(
				&self,
				out: &mut [PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					     Self::HEIGHT_BYTES],
			) {
				let underliers = WithUnderlier::to_underliers_arr_ref_mut(out);
				*underliers = WithUnderlier::to_underliers_arr(self.data);

				UnderlierWithBitOps::transpose_bytes_from_byte_sliced::<$storage_tower_level>(
					underliers,
				);
			}

			/// Convert an array of "ordinary" packed fields to a byte-sliced field preserving the order of scalars.
			#[inline]
			pub fn transpose_from(
				underliers: &[PackedType<<$packed_storage as WithUnderlier>::Underlier, BinaryField1b>;
					 Self::HEIGHT_BYTES],
			) -> Self {
				let mut underliers = WithUnderlier::to_underliers_arr(*underliers);

				<$packed_storage as WithUnderlier>::Underlier::transpose_bytes_to_byte_sliced::<
					$storage_tower_level,
				>(&mut underliers);

				Self {
					data: WithUnderlier::from_underliers_arr(underliers),
				}
			}
		}

		impl Default for $name {
			fn default() -> Self {
				Self {
					data: bytemuck::Zeroable::zeroed(),
				}
			}
		}

		impl Debug for $name {
			fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
				let values_str = self
					.iter()
					.map(|value| format!("{}", value))
					.collect::<Vec<_>>()
					.join(",");

				write!(f, "ByteSlicedAES{}x1b([{}])", Self::WIDTH, values_str)
			}
		}

		impl PackedField for $name {
			type Scalar = BinaryField1b;

			const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH + Self::LOG_HEIGHT;

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
				// `(i / 8) % HEIGHT_BYTES` selects the byte row; within that row
				// the scalar sits in byte `i / (HEIGHT_BYTES * 8)` at bit `i % 8`.
				self.data
					.get_unchecked((i / 8) % Self::HEIGHT_BYTES)
					.get_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8)
			}

			#[allow(clippy::modulo_one)]
			#[inline(always)]
			unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
				// Mirrors the index math of `get_unchecked`.
				self.data
					.get_unchecked_mut((i / 8) % Self::HEIGHT_BYTES)
					.set_unchecked(8 * (i / (Self::HEIGHT_BYTES * 8)) + i % 8, scalar);
			}

			fn random(mut rng: impl rand::RngCore) -> Self {
				let data = array::from_fn(|_| <$packed_storage>::random(&mut rng));
				Self { data }
			}

			#[allow(unreachable_patterns)]
			#[inline]
			fn broadcast(scalar: Self::Scalar) -> Self {
				// Broadcasting a 1b scalar fills every bit of the underlier with
				// the scalar's bit.
				let underlier = <$packed_storage as WithUnderlier>::Underlier::fill_with_bit(
					scalar.to_underlier().into(),
				);
				Self {
					data: array::from_fn(|_| WithUnderlier::from_underlier(underlier)),
				}
			}

			#[inline]
			fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
				let data = array::from_fn(|row| {
					<$packed_storage>::from_fn(|col| {
						// Inverse of the `get_unchecked` index mapping.
						f(row * 8 + 8 * Self::HEIGHT_BYTES * (col / 8) + col % 8)
					})
				});

				Self { data }
			}

			#[inline]
			fn square(self) -> Self {
				// Delegated to the packed 1b rows independently.
				let data = array::from_fn(|i| self.data[i].clone().square());
				Self { data }
			}

			#[inline]
			fn invert_or_zero(self) -> Self {
				let data = array::from_fn(|i| self.data[i].clone().invert_or_zero());
				Self { data }
			}

			#[inline(always)]
			fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
				type Packed8b =
					PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

				// Blocks shorter than a byte (log_block_len < 3) stay within each
				// row; longer blocks are handled as a byte-sliced interleave on
				// the 8b view, with the block length rescaled from bits to bytes.
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].interleave(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) =
						interleave_byte_sliced(self_data, other_data, log_block_len - 3);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}

			#[inline]
			fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
				// Same split as `interleave`: sub-byte blocks per row, otherwise
				// a byte-level unzip on the 8b view.
				if log_block_len < 3 {
					let mut result1 = Self::default();
					let mut result2 = Self::default();

					for i in 0..Self::HEIGHT_BYTES {
						(result1.data[i], result2.data[i]) =
							self.data[i].unzip(other.data[i], log_block_len);
					}

					(result1, result2)
				} else {
					type Packed8b =
						PackedType<<$packed_storage as WithUnderlier>::Underlier, AESTowerField8b>;

					let self_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&self.data);
					let other_data: &[Packed8b; Self::HEIGHT_BYTES] =
						Packed8b::cast_ext_arr_ref(&other.data);

					let (result1, result2) = unzip_byte_sliced::<Packed8b, { Self::HEIGHT_BYTES }, 1>(
						self_data,
						other_data,
						log_block_len - 3,
					);

					(
						Self {
							data: Packed8b::cast_base_arr(result1),
						},
						Self {
							data: Packed8b::cast_base_arr(result2),
						},
					)
				}
			}
		}

		impl Mul for $name {
			type Output = Self;

			#[inline]
			fn mul(self, rhs: Self) -> Self {
				// Multiplication is delegated to the packed 1b rows.
				Self {
					data: array::from_fn(|i| self.data[i].clone() * rhs.data[i].clone()),
				}
			}
		}

		impl Add for $name {
			type Output = Self;

			#[inline]
			fn add(self, rhs: Self) -> Self {
				Self {
					data: array::from_fn(|byte_number| {
						self.data[byte_number] + rhs.data[byte_number]
					}),
				}
			}
		}

		impl AddAssign for $name {
			#[inline]
			fn add_assign(&mut self, rhs: Self) {
				for (data, rhs) in zip(&mut self.data, &rhs.data) {
					*data += *rhs
				}
			}
		}

		byte_sliced_common!($name, $packed_storage, BinaryField1b);

		impl PackedTransformationFactory<$name> for $name {
			// The only nonzero F2-linear map on a 1b field is the identity, so
			// the packed transformation is always the identity transformation.
			type PackedTransformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync> =
				IDTransformation;

			fn make_packed_transformation<Data: AsRef<[<$name as PackedField>::Scalar]> + Sync>(
				_transformation: FieldLinearTransformation<<$name as PackedField>::Scalar, Data>,
			) -> Self::PackedTransformation<Data> {
				IDTransformation
			}
		}
	};
}
776
777#[inline(always)]
778fn interleave_internal_block<P: PackedField, const N: usize, const LOG_BLOCK_LEN: usize>(
779	lhs: &[P; N],
780	rhs: &[P; N],
781) -> ([P; N], [P; N]) {
782	debug_assert!(LOG_BLOCK_LEN < checked_log_2(N));
783
784	let result_1 = array::from_fn(|i| {
785		let block_index = i >> (LOG_BLOCK_LEN + 1);
786		let block_start = block_index << (LOG_BLOCK_LEN + 1);
787		let block_offset = i - block_start;
788
789		if block_offset < (1 << LOG_BLOCK_LEN) {
790			lhs[i]
791		} else {
792			rhs[i - (1 << LOG_BLOCK_LEN)]
793		}
794	});
795	let result_2 = array::from_fn(|i| {
796		let block_index = i >> (LOG_BLOCK_LEN + 1);
797		let block_start = block_index << (LOG_BLOCK_LEN + 1);
798		let block_offset = i - block_start;
799
800		if block_offset < (1 << LOG_BLOCK_LEN) {
801			lhs[i + (1 << LOG_BLOCK_LEN)]
802		} else {
803			rhs[i]
804		}
805	});
806
807	(result_1, result_2)
808}
809
810#[inline(always)]
811fn interleave_byte_sliced<P: PackedField, const N: usize>(
812	lhs: &[P; N],
813	rhs: &[P; N],
814	log_block_len: usize,
815) -> ([P; N], [P; N]) {
816	debug_assert!(checked_log_2(N) <= 4);
817
818	match log_block_len {
819		x if x >= checked_log_2(N) => interleave_big_block::<P, N>(lhs, rhs, log_block_len),
820		0 => interleave_internal_block::<P, N, 0>(lhs, rhs),
821		1 => interleave_internal_block::<P, N, 1>(lhs, rhs),
822		2 => interleave_internal_block::<P, N, 2>(lhs, rhs),
823		3 => interleave_internal_block::<P, N, 3>(lhs, rhs),
824		_ => unreachable!(),
825	}
826}
827
828#[inline(always)]
829fn unzip_byte_sliced<P: PackedField, const N: usize, const SCALAR_BYTES: usize>(
830	lhs: &[P; N],
831	rhs: &[P; N],
832	log_block_len: usize,
833) -> ([P; N], [P; N]) {
834	let mut result1: [P; N] = bytemuck::Zeroable::zeroed();
835	let mut result2: [P; N] = bytemuck::Zeroable::zeroed();
836
837	let log_height = checked_log_2(N);
838	if log_block_len < log_height {
839		let block_size = 1 << log_block_len;
840		let half = N / 2;
841		for block_offset in (0..half).step_by(block_size) {
842			let target_offset = block_offset * 2;
843
844			result1[block_offset..block_offset + block_size]
845				.copy_from_slice(&lhs[target_offset..target_offset + block_size]);
846			result1[half + target_offset..half + target_offset + block_size]
847				.copy_from_slice(&rhs[target_offset..target_offset + block_size]);
848
849			result2[block_offset..block_offset + block_size]
850				.copy_from_slice(&lhs[target_offset + block_size..target_offset + 2 * block_size]);
851			result2[half + target_offset..half + target_offset + block_size]
852				.copy_from_slice(&rhs[target_offset + block_size..target_offset + 2 * block_size]);
853		}
854	} else {
855		for i in 0..N {
856			(result1[i], result2[i]) = lhs[i].unzip(rhs[i], log_block_len - log_height);
857		}
858	}
859
860	(result1, result2)
861}
862
863#[inline(always)]
864fn interleave_big_block<P: PackedField, const N: usize>(
865	lhs: &[P; N],
866	rhs: &[P; N],
867	log_block_len: usize,
868) -> ([P; N], [P; N]) {
869	let mut result_1 = <[P; N]>::zeroed();
870	let mut result_2 = <[P; N]>::zeroed();
871
872	for i in 0..N {
873		(result_1[i], result_2[i]) = lhs[i].interleave(rhs[i], log_block_len - checked_log_2(N));
874	}
875
876	(result_1, result_2)
877}
878
// 128 bit
// Instantiations over the 128-bit (16-lane) 8b packed storage. Macro arguments:
// type name, scalar field, 8b packed row storage, scalar tower level (bytes per
// scalar) and storage tower level (column height in bytes).
define_byte_sliced_3d!(
	ByteSlicedAES16x128b,
	AESTowerField128b,
	PackedAESBinaryField16x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x16x64b,
	AESTowerField64b,
	PackedAESBinaryField16x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x16x32b,
	AESTowerField32b,
	PackedAESBinaryField16x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x16x16b,
	AESTowerField16b,
	PackedAESBinaryField16x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x16x8b,
	AESTowerField8b,
	PackedAESBinaryField16x8b,
	TowerLevel1,
	TowerLevel16
);

// 1b variants sharing the byte layout of the AES types above.
define_byte_sliced_3d_1b!(ByteSliced16x128x1b, PackedBinaryField128x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x128x1b, PackedBinaryField128x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x128x1b, PackedBinaryField128x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x128x1b, PackedBinaryField128x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x128x1b, PackedBinaryField128x1b, TowerLevel1);
949
// 256 bit
// Same instantiation pattern as the 128-bit group, over the 256-bit (32-lane)
// 8b packed storage.
define_byte_sliced_3d!(
	ByteSlicedAES32x128b,
	AESTowerField128b,
	PackedAESBinaryField32x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x32x64b,
	AESTowerField64b,
	PackedAESBinaryField32x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x32x32b,
	AESTowerField32b,
	PackedAESBinaryField32x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x32x16b,
	AESTowerField16b,
	PackedAESBinaryField32x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x32x8b,
	AESTowerField8b,
	PackedAESBinaryField32x8b,
	TowerLevel1,
	TowerLevel16
);

// 1b variants sharing the byte layout of the AES types above.
define_byte_sliced_3d_1b!(ByteSliced16x256x1b, PackedBinaryField256x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x256x1b, PackedBinaryField256x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x256x1b, PackedBinaryField256x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x256x1b, PackedBinaryField256x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x256x1b, PackedBinaryField256x1b, TowerLevel1);
1020
// 512 bit
// Same instantiation pattern as above, over the 512-bit (64-lane) 8b packed
// storage.
define_byte_sliced_3d!(
	ByteSlicedAES64x128b,
	AESTowerField128b,
	PackedAESBinaryField64x8b,
	TowerLevel16,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel8
);
define_byte_sliced_3d!(
	ByteSlicedAES2x64x64b,
	AESTowerField64b,
	PackedAESBinaryField64x8b,
	TowerLevel8,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel4
);
define_byte_sliced_3d!(
	ByteSlicedAES4x64x32b,
	AESTowerField32b,
	PackedAESBinaryField64x8b,
	TowerLevel4,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel2
);
define_byte_sliced_3d!(
	ByteSlicedAES8x64x16b,
	AESTowerField16b,
	PackedAESBinaryField64x8b,
	TowerLevel2,
	TowerLevel16
);
define_byte_sliced_3d!(
	ByteSlicedAES64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel1
);
define_byte_sliced_3d!(
	ByteSlicedAES16x64x8b,
	AESTowerField8b,
	PackedAESBinaryField64x8b,
	TowerLevel1,
	TowerLevel16
);

// 1b variants sharing the byte layout of the AES types above.
define_byte_sliced_3d_1b!(ByteSliced16x512x1b, PackedBinaryField512x1b, TowerLevel16);
define_byte_sliced_3d_1b!(ByteSliced8x512x1b, PackedBinaryField512x1b, TowerLevel8);
define_byte_sliced_3d_1b!(ByteSliced4x512x1b, PackedBinaryField512x1b, TowerLevel4);
define_byte_sliced_3d_1b!(ByteSliced2x512x1b, PackedBinaryField512x1b, TowerLevel2);
define_byte_sliced_3d_1b!(ByteSliced1x512x1b, PackedBinaryField512x1b, TowerLevel1);
1091
// Implements `PackedExtension` between byte-sliced types that share the same
// underlier layout: the representations are bit-compatible, so every cast is a
// `bytemuck` reinterpretation.
macro_rules! impl_packed_extension{
	($packed_ext:ty, $packed_base:ty,) => {
		impl PackedExtension<<$packed_base as PackedField>::Scalar> for $packed_ext {
			type PackedSubfield = $packed_base;

			fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield] {
				bytemuck::must_cast_slice(packed)
			}

			fn cast_bases_mut(packed: &mut [Self]) -> &mut [Self::PackedSubfield] {
				bytemuck::must_cast_slice_mut(packed)
			}

			fn cast_exts(packed: &[Self::PackedSubfield]) -> &[Self] {
				bytemuck::must_cast_slice(packed)
			}

			fn cast_exts_mut(packed: &mut [Self::PackedSubfield]) -> &mut [Self] {
				bytemuck::must_cast_slice_mut(packed)
			}

			fn cast_base(self) -> Self::PackedSubfield {
				bytemuck::must_cast(self)
			}

			fn cast_base_ref(&self) -> &Self::PackedSubfield {
				bytemuck::must_cast_ref(self)
			}

			fn cast_base_mut(&mut self) -> &mut Self::PackedSubfield {
				bytemuck::must_cast_mut(self)
			}

			fn cast_ext(base: Self::PackedSubfield) -> Self {
				bytemuck::must_cast(base)
			}

			fn cast_ext_ref(base: &Self::PackedSubfield) -> &Self {
				bytemuck::must_cast_ref(base)
			}

			fn cast_ext_mut(base: &mut Self::PackedSubfield) -> &mut Self {
				bytemuck::must_cast_mut(base)
			}
		}
	};
	// Recursive expansion: `@pairs` implements the head against every element
	// of the tail, and the list rule then recurses on the tail, so each type in
	// the list becomes a packed extension of all types that follow it.
	(@pairs $head:ty, $next:ty,) => {
		impl_packed_extension!($head, $next,);
	};
	(@pairs $head:ty, $next:ty, $($tail:ty,)*) => {
		impl_packed_extension!($head, $next,);
		impl_packed_extension!(@pairs $head, $($tail,)*);
	};
	($head:ty, $next:ty, $($tail:ty,)*) => {
		impl_packed_extension!(@pairs $head, $next, $($tail,)*);
		impl_packed_extension!($next, $($tail,)*);
	};
}
1150
// Packed-extension chains per storage width: within each list, every type is a
// packed extension of all types that follow it (same byte layout, decreasing
// scalar size from 128b down to 1b).
impl_packed_extension!(
	ByteSlicedAES16x128b,
	ByteSlicedAES2x16x64b,
	ByteSlicedAES4x16x32b,
	ByteSlicedAES8x16x16b,
	ByteSlicedAES16x16x8b,
	ByteSliced16x128x1b,
);

impl_packed_extension!(
	ByteSlicedAES32x128b,
	ByteSlicedAES2x32x64b,
	ByteSlicedAES4x32x32b,
	ByteSlicedAES8x32x16b,
	ByteSlicedAES16x32x8b,
	ByteSliced16x256x1b,
);

impl_packed_extension!(
	ByteSlicedAES64x128b,
	ByteSlicedAES2x64x64b,
	ByteSlicedAES4x64x32b,
	ByteSlicedAES8x64x16b,
	ByteSlicedAES16x64x8b,
	ByteSliced16x512x1b,
);