binius_field/arch/portable/byte_sliced/
mod.rs

// Copyright 2024-2025 Irreducible Inc.
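
//! Byte-sliced representations of packed AES tower field elements.
//!
//! The submodules implement multiplication, squaring, and inversion directly on
//! the byte-sliced layout, while `packed_byte_sliced` defines the `ByteSlicedAES*`
//! packed types re-exported below. The tests verify each byte-sliced operation
//! element-wise against the corresponding scalar field operation.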

mod invert;
mod multiply;
mod packed_byte_sliced;
mod square;

pub use packed_byte_sliced::*;

#[cfg(test)]
pub mod tests {
	use super::*;
	use crate::{
		packed::{get_packed_slice, set_packed_slice},
		PackedAESBinaryField16x16b, PackedAESBinaryField16x32b, PackedAESBinaryField16x8b,
		PackedAESBinaryField1x128b, PackedAESBinaryField2x128b, PackedAESBinaryField2x64b,
		PackedAESBinaryField32x16b, PackedAESBinaryField32x8b, PackedAESBinaryField4x128b,
		PackedAESBinaryField4x32b, PackedAESBinaryField4x64b, PackedAESBinaryField64x8b,
		PackedAESBinaryField8x16b, PackedAESBinaryField8x32b, PackedAESBinaryField8x64b,
	};

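	// Generates a proptest module for one byte-sliced packed type `$name` over the
	// scalar field `$scalar_type`. Every arithmetic operation, the packed linear
	// transformation, interleaving, and the transpose round-trips are checked
	// element-wise against scalar reference values; `$associated_packed` is the
	// conventional packed type used as the transpose source/destination.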
	macro_rules! define_byte_sliced_test {
		($module_name:ident, $name:ident, $scalar_type:ty, $associated_packed:ty) => {
			mod $module_name {
				use proptest::prelude::*;

				use super::*;
				use crate::{packed::PackedField, underlier::WithUnderlier, $scalar_type};

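				// Strategy that draws an arbitrary scalar for every lane of the
				// byte-sliced packed type by sampling raw underlier values.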
				fn scalar_array_strategy() -> impl Strategy<Value = [$scalar_type; <$name>::WIDTH]>
				{
					any::<[<$scalar_type as WithUnderlier>::Underlier; <$name>::WIDTH]>()
						.prop_map(|arr| arr.map(<$scalar_type>::from_underlier))
				}

				proptest! {
					#[test]
					fn check_add(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						let bytesliced_result = bytesliced_a + bytesliced_b;

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_result.get(i));
						}
					}

					#[test]
					fn check_add_assign(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						bytesliced_a += bytesliced_b;

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_a.get(i));
						}
					}

					#[test]
					fn check_sub(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						let bytesliced_result = bytesliced_a - bytesliced_b;

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_result.get(i));
						}
					}

					#[test]
					fn check_sub_assign(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						bytesliced_a -= bytesliced_b;

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_a.get(i));
						}
					}

					#[test]
					fn check_mul(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						let bytesliced_result = bytesliced_a * bytesliced_b;

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_result.get(i));
						}
					}

					#[test]
					fn check_mul_assign(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						bytesliced_a *= bytesliced_b;

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_a.get(i));
						}
					}

					#[test]
					fn check_inv(scalar_elems in scalar_array_strategy()) {
						let bytesliced = <$name>::from_scalars(scalar_elems);

						let bytesliced_result = bytesliced.invert_or_zero();

						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
							assert_eq!(scalar_elem.invert_or_zero(), bytesliced_result.get(i));
						}
					}

					#[test]
					fn check_square(scalar_elems in scalar_array_strategy()) {
						let bytesliced = <$name>::from_scalars(scalar_elems);

						let bytesliced_result = bytesliced.square();

						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
							assert_eq!(scalar_elem.square(), bytesliced_result.get(i));
						}
					}

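					// The packed transformation built from a random field linear
					// transformation must agree with applying that transformation to
					// each scalar lane individually.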
					#[test]
					fn check_linear_transformation(scalar_elems in scalar_array_strategy()) {
						use crate::linear_transformation::{PackedTransformationFactory, FieldLinearTransformation, Transformation};
						use rand::{rngs::StdRng, SeedableRng};

						let bytesliced = <$name>::from_scalars(scalar_elems);

						let linear_transformation = FieldLinearTransformation::random(StdRng::seed_from_u64(0));
						let packed_transformation = <$name>::make_packed_transformation(linear_transformation.clone());

						let bytesliced_result = packed_transformation.transform(&bytesliced);

						for i in 0..<$name>::WIDTH {
							assert_eq!(linear_transformation.transform(&scalar_elems[i]), bytesliced_result.get(i));
						}
					}

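					// `interleave` with blocks of 2^log_block_len scalars exchanges the
					// odd-indexed blocks of `a` with the even-indexed blocks of `b`; the
					// assertions check the resulting layout block by block.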
					#[test]
					fn check_interleave(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						for log_block_len in 0..<$name>::LOG_WIDTH {
							let (bytesliced_c, bytesliced_d) = bytesliced_a.interleave(bytesliced_b, log_block_len);

							let block_len = 1 << log_block_len;
							for offset in (0..<$name>::WIDTH).step_by(2 * block_len) {
								for i in 0..block_len {
									assert_eq!(bytesliced_c.get(offset + i), scalar_elems_a[offset + i]);
									assert_eq!(bytesliced_c.get(offset + block_len + i), scalar_elems_b[offset + i]);

									assert_eq!(bytesliced_d.get(offset + i), scalar_elems_a[offset + block_len + i]);
									assert_eq!(bytesliced_d.get(offset + block_len + i), scalar_elems_b[offset + block_len + i]);
								}
							}
						}
					}

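					// `transpose_to` / `transpose_from` convert between the byte-sliced
					// layout and an array of conventional packed elements; both directions
					// must preserve every scalar lane.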
					#[test]
					fn check_transpose_to(scalar_elems in scalar_array_strategy()) {
						let bytesliced = <$name>::from_scalars(scalar_elems);
						let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
						bytesliced.transpose_to(&mut destination);

						for i in 0..<$name>::WIDTH {
							assert_eq!(scalar_elems[i], get_packed_slice(&destination, i));
						}
					}

					#[test]
					fn check_transpose_from(scalar_elems in scalar_array_strategy()) {
						let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
						for i in 0..<$name>::WIDTH {
							set_packed_slice(&mut destination, i, scalar_elems[i]);
						}

						let bytesliced = <$name>::transpose_from(&destination);

						for i in 0..<$name>::WIDTH {
							assert_eq!(get_packed_slice(&destination, i), bytesliced.get(i));
						}
					}
				}
			}
		};
	}

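	// Instantiate the test suite for every byte-sliced type, grouped below by the
	// bit width of the underlying byte-sliced representation.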
	// 128-bit byte-sliced
	define_byte_sliced_test!(
		tests_3d_16x128,
		ByteSlicedAES16x128b,
		AESTowerField128b,
		PackedAESBinaryField1x128b
	);
	define_byte_sliced_test!(
		tests_3d_16x64,
		ByteSlicedAES16x64b,
		AESTowerField64b,
		PackedAESBinaryField2x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x16x64,
		ByteSlicedAES2x16x64b,
		AESTowerField64b,
		PackedAESBinaryField2x64b
	);
	define_byte_sliced_test!(
		tests_3d_16x32,
		ByteSlicedAES16x32b,
		AESTowerField32b,
		PackedAESBinaryField4x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x16x32,
		ByteSlicedAES4x16x32b,
		AESTowerField32b,
		PackedAESBinaryField4x32b
	);
	define_byte_sliced_test!(
		tests_3d_16x16,
		ByteSlicedAES16x16b,
		AESTowerField16b,
		PackedAESBinaryField8x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x16x16,
		ByteSlicedAES8x16x16b,
		AESTowerField16b,
		PackedAESBinaryField8x16b
	);
	define_byte_sliced_test!(
		tests_3d_16x8,
		ByteSlicedAES16x8b,
		AESTowerField8b,
		PackedAESBinaryField16x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x16x8,
		ByteSlicedAES16x16x8b,
		AESTowerField8b,
		PackedAESBinaryField16x8b
	);

	// 256-bit byte-sliced
	define_byte_sliced_test!(
		tests_3d_32x128,
		ByteSlicedAES32x128b,
		AESTowerField128b,
		PackedAESBinaryField2x128b
	);
	define_byte_sliced_test!(
		tests_3d_32x64,
		ByteSlicedAES32x64b,
		AESTowerField64b,
		PackedAESBinaryField4x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x32x64,
		ByteSlicedAES2x32x64b,
		AESTowerField64b,
		PackedAESBinaryField4x64b
	);
	define_byte_sliced_test!(
		tests_3d_32x32,
		ByteSlicedAES32x32b,
		AESTowerField32b,
		PackedAESBinaryField8x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x32x32,
		ByteSlicedAES4x32x32b,
		AESTowerField32b,
		PackedAESBinaryField8x32b
	);
	define_byte_sliced_test!(
		tests_3d_32x16,
		ByteSlicedAES32x16b,
		AESTowerField16b,
		PackedAESBinaryField16x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x32x16,
		ByteSlicedAES8x32x16b,
		AESTowerField16b,
		PackedAESBinaryField16x16b
	);
	define_byte_sliced_test!(
		tests_3d_32x8,
		ByteSlicedAES32x8b,
		AESTowerField8b,
		PackedAESBinaryField32x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x32x8,
		ByteSlicedAES16x32x8b,
		AESTowerField8b,
		PackedAESBinaryField32x8b
	);

	// 512-bit byte-sliced
	define_byte_sliced_test!(
		tests_3d_64x128,
		ByteSlicedAES64x128b,
		AESTowerField128b,
		PackedAESBinaryField4x128b
	);
	define_byte_sliced_test!(
		tests_3d_64x64,
		ByteSlicedAES64x64b,
		AESTowerField64b,
		PackedAESBinaryField8x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x64x64,
		ByteSlicedAES2x64x64b,
		AESTowerField64b,
		PackedAESBinaryField8x64b
	);
	define_byte_sliced_test!(
		tests_3d_64x32,
		ByteSlicedAES64x32b,
		AESTowerField32b,
		PackedAESBinaryField16x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x64x32,
		ByteSlicedAES4x64x32b,
		AESTowerField32b,
		PackedAESBinaryField16x32b
	);
	define_byte_sliced_test!(
		tests_3d_64x16,
		ByteSlicedAES64x16b,
		AESTowerField16b,
		PackedAESBinaryField32x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x64x16,
		ByteSlicedAES8x64x16b,
		AESTowerField16b,
		PackedAESBinaryField32x16b
	);
	define_byte_sliced_test!(
		tests_3d_64x8,
		ByteSlicedAES64x8b,
		AESTowerField8b,
		PackedAESBinaryField64x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x64x8,
		ByteSlicedAES16x64x8b,
		AESTowerField8b,
		PackedAESBinaryField64x8b
	);
}