binius_field/arch/portable/byte_sliced/
mod.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3mod invert;
4mod multiply;
5mod packed_byte_sliced;
6mod square;
7mod underlier;
8
9pub use packed_byte_sliced::*;
10pub use underlier::ByteSlicedUnderlier;
11
#[cfg(test)]
pub mod tests {
	//! Property-based tests for the byte-sliced packed field types.
	//!
	//! Every generated test builds a byte-sliced value from random scalars and checks
	//! that it agrees element by element with plain scalar arithmetic. The tests also
	//! check that transposition to and from the associated packed representation
	//! preserves every scalar.

	use proptest::prelude::*;

	use super::*;
	use crate::{
		packed::{get_packed_slice, set_packed_slice},
		underlier::WithUnderlier,
		PackedAESBinaryField16x16b, PackedAESBinaryField16x32b, PackedAESBinaryField16x8b,
		PackedAESBinaryField1x128b, PackedAESBinaryField2x128b, PackedAESBinaryField2x64b,
		PackedAESBinaryField32x16b, PackedAESBinaryField32x8b, PackedAESBinaryField4x128b,
		PackedAESBinaryField4x32b, PackedAESBinaryField4x64b, PackedAESBinaryField64x8b,
		PackedAESBinaryField8x16b, PackedAESBinaryField8x32b, PackedAESBinaryField8x64b,
		PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedField,
	};

	/// Strategy yielding exactly `P::WIDTH` random scalars of `P::Scalar`, one per lane.
	///
	/// Each scalar is built with `from_underlier` from an arbitrary underlier value.
	fn scalars_vec_strategy<P: PackedField<Scalar: WithUnderlier<Underlier: Arbitrary>>>(
	) -> impl Strategy<Value = Vec<P::Scalar>> {
		proptest::collection::vec(
			any::<<P::Scalar as WithUnderlier>::Underlier>().prop_map(P::Scalar::from_underlier),
			// A plain `usize` converts to an exact-length `SizeRange`.
			P::WIDTH,
		)
	}

	/// Generates a module `$module_name` of proptests for the byte-sliced type `$name`.
	///
	/// * `$module_name` — name of the generated test module.
	/// * `$name` — the byte-sliced packed field type under test.
	/// * `$scalar_type` — the scalar field of `$name`. It is accepted so the invocation list
	///   documents each type, but the generated tests do not reference it.
	/// * `$associated_packed` — the packed type `$name` transposes to and from. An array of
	///   `<$name>::HEIGHT_BYTES` of them is indexed for `<$name>::WIDTH` scalars.
	macro_rules! define_byte_sliced_test {
		($module_name:ident, $name:ident, $scalar_type:ty, $associated_packed:ty) => {
			mod $module_name {
				use super::*;

				proptest! {
					// `from_fn` must place the i-th produced scalar in lane i.
					#[test]
					fn check_from_fn(scalar_elems in scalars_vec_strategy::<$name>()) {
						let bytesliced = <$name>::from_fn(|i| scalar_elems[i]);
						for (i, &expected) in scalar_elems.iter().enumerate() {
							assert_eq!(expected, bytesliced.get(i));
						}
					}

					// Lane-wise `+` must match scalar addition.
					#[test]
					fn check_add(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						let bytesliced_result = bytesliced_a + bytesliced_b;

						for (i, (&a, &b)) in scalar_elems_a.iter().zip(&scalar_elems_b).enumerate() {
							assert_eq!(a + b, bytesliced_result.get(i));
						}
					}

					// Lane-wise `+=` must match scalar addition.
					#[test]
					fn check_add_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						bytesliced_a += bytesliced_b;

						for (i, (&a, &b)) in scalar_elems_a.iter().zip(&scalar_elems_b).enumerate() {
							assert_eq!(a + b, bytesliced_a.get(i));
						}
					}

					// Lane-wise `-` must match scalar subtraction.
					#[test]
					fn check_sub(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						let bytesliced_result = bytesliced_a - bytesliced_b;

						for (i, (&a, &b)) in scalar_elems_a.iter().zip(&scalar_elems_b).enumerate() {
							assert_eq!(a - b, bytesliced_result.get(i));
						}
					}

					// Lane-wise `-=` must match scalar subtraction.
					#[test]
					fn check_sub_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						bytesliced_a -= bytesliced_b;

						for (i, (&a, &b)) in scalar_elems_a.iter().zip(&scalar_elems_b).enumerate() {
							assert_eq!(a - b, bytesliced_a.get(i));
						}
					}

					// Lane-wise `*` must match scalar multiplication.
					#[test]
					fn check_mul(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						let bytesliced_result = bytesliced_a * bytesliced_b;

						for (i, (&a, &b)) in scalar_elems_a.iter().zip(&scalar_elems_b).enumerate() {
							assert_eq!(a * b, bytesliced_result.get(i));
						}
					}

					// Lane-wise `*=` must match scalar multiplication.
					#[test]
					fn check_mul_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						bytesliced_a *= bytesliced_b;

						for (i, (&a, &b)) in scalar_elems_a.iter().zip(&scalar_elems_b).enumerate() {
							assert_eq!(a * b, bytesliced_a.get(i));
						}
					}

					// Lane-wise `invert_or_zero` must match the scalar version (zero lanes included).
					#[test]
					fn check_inv(scalar_elems in scalars_vec_strategy::<$name>()) {
						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());

						let bytesliced_result = bytesliced.invert_or_zero();

						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
							assert_eq!(scalar_elem.invert_or_zero(), bytesliced_result.get(i));
						}
					}

					// Lane-wise `square` must match the scalar version.
					#[test]
					fn check_square(scalar_elems in scalars_vec_strategy::<$name>()) {
						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());

						let bytesliced_result = bytesliced.square();

						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
							assert_eq!(scalar_elem.square(), bytesliced_result.get(i));
						}
					}

					// A packed transformation derived from a (seeded, hence reproducible) random
					// scalar linear transformation must act on each lane like the scalar one.
					#[test]
					fn check_linear_transformation(scalar_elems in scalars_vec_strategy::<$name>()) {
						use crate::linear_transformation::{PackedTransformationFactory, FieldLinearTransformation, Transformation};
						use rand::{rngs::StdRng, SeedableRng};

						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());

						let linear_transformation = FieldLinearTransformation::random(StdRng::seed_from_u64(0));
						let packed_transformation = <$name>::make_packed_transformation(linear_transformation.clone());

						let bytesliced_result = packed_transformation.transform(&bytesliced);

						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
							assert_eq!(linear_transformation.transform(scalar_elem), bytesliced_result.get(i));
						}
					}

					// For every block length 2^k with k < LOG_WIDTH, split the lanes into consecutive
					// pairs of blocks. `c` must hold a's first block followed by b's first block, and
					// `d` must hold a's second block followed by b's second block.
					#[test]
					fn check_interleave(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());

						for log_block_len in 0..<$name>::LOG_WIDTH {
							let (bytesliced_c, bytesliced_d) = bytesliced_a.interleave(bytesliced_b, log_block_len);

							let block_len = 1 << log_block_len;
							for offset in (0..<$name>::WIDTH).step_by(2 * block_len) {
								for i in 0..block_len {
									assert_eq!(bytesliced_c.get(offset + i), scalar_elems_a[offset + i]);
									assert_eq!(bytesliced_c.get(offset + block_len + i), scalar_elems_b[offset + i]);

									assert_eq!(bytesliced_d.get(offset + i), scalar_elems_a[offset + block_len + i]);
									assert_eq!(bytesliced_d.get(offset + block_len + i), scalar_elems_b[offset + block_len + i]);
								}
							}
						}
					}

					// `transpose_to` must lay lane i out as scalar i of the associated packed slice.
					#[test]
					fn check_transpose_to(scalar_elems in scalars_vec_strategy::<$name>()) {
						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
						let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
						bytesliced.transpose_to(&mut destination);

						for (i, &expected) in scalar_elems.iter().enumerate() {
							assert_eq!(expected, get_packed_slice(&destination, i));
						}
					}

					// `transpose_from` must read scalar i of the associated packed slice into lane i.
					#[test]
					fn check_transpose_from(scalar_elems in scalars_vec_strategy::<$name>()) {
						let mut source = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
						for (i, &scalar) in scalar_elems.iter().enumerate() {
							set_packed_slice(&mut source, i, scalar);
						}

						let bytesliced = <$name>::transpose_from(&source);

						// Compare against the generated scalars rather than reading `source` back,
						// so a faulty `set_packed_slice` cannot make this check pass vacuously.
						for (i, &expected) in scalar_elems.iter().enumerate() {
							assert_eq!(expected, bytesliced.get(i));
						}
					}

					// `transpose_to` followed by `transpose_from` must be the identity.
					#[test]
					fn check_transpose_round_trip(scalar_elems in scalars_vec_strategy::<$name>()) {
						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
						let mut buffer = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
						bytesliced.transpose_to(&mut buffer);

						let round_tripped = <$name>::transpose_from(&buffer);

						for i in 0..<$name>::WIDTH {
							assert_eq!(bytesliced.get(i), round_tripped.get(i));
						}
					}
				}
			}
		};
	}

	// 128-bit byte-sliced
	define_byte_sliced_test!(
		tests_3d_16x128,
		ByteSlicedAES16x128b,
		AESTowerField128b,
		PackedAESBinaryField1x128b
	);
	define_byte_sliced_test!(
		tests_3d_16x64,
		ByteSlicedAES16x64b,
		AESTowerField64b,
		PackedAESBinaryField2x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x16x64,
		ByteSlicedAES2x16x64b,
		AESTowerField64b,
		PackedAESBinaryField2x64b
	);
	define_byte_sliced_test!(
		tests_3d_16x32,
		ByteSlicedAES16x32b,
		AESTowerField32b,
		PackedAESBinaryField4x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x16x32,
		ByteSlicedAES4x16x32b,
		AESTowerField32b,
		PackedAESBinaryField4x32b
	);
	define_byte_sliced_test!(
		tests_3d_16x16,
		ByteSlicedAES16x16b,
		AESTowerField16b,
		PackedAESBinaryField8x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x16x16,
		ByteSlicedAES8x16x16b,
		AESTowerField16b,
		PackedAESBinaryField8x16b
	);
	define_byte_sliced_test!(
		tests_3d_16x8,
		ByteSlicedAES16x8b,
		AESTowerField8b,
		PackedAESBinaryField16x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x16x8,
		ByteSlicedAES16x16x8b,
		AESTowerField8b,
		PackedAESBinaryField16x8b
	);

	define_byte_sliced_test!(
		tests_3d_16x128x1,
		ByteSliced16x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_8x128x1,
		ByteSliced8x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_4x128x1,
		ByteSliced4x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_2x128x1,
		ByteSliced2x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_1x128x1,
		ByteSliced1x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);

	// 256-bit byte-sliced
	define_byte_sliced_test!(
		tests_3d_32x128,
		ByteSlicedAES32x128b,
		AESTowerField128b,
		PackedAESBinaryField2x128b
	);
	define_byte_sliced_test!(
		tests_3d_32x64,
		ByteSlicedAES32x64b,
		AESTowerField64b,
		PackedAESBinaryField4x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x32x64,
		ByteSlicedAES2x32x64b,
		AESTowerField64b,
		PackedAESBinaryField4x64b
	);
	define_byte_sliced_test!(
		tests_3d_32x32,
		ByteSlicedAES32x32b,
		AESTowerField32b,
		PackedAESBinaryField8x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x32x32,
		ByteSlicedAES4x32x32b,
		AESTowerField32b,
		PackedAESBinaryField8x32b
	);
	define_byte_sliced_test!(
		tests_3d_32x16,
		ByteSlicedAES32x16b,
		AESTowerField16b,
		PackedAESBinaryField16x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x32x16,
		ByteSlicedAES8x32x16b,
		AESTowerField16b,
		PackedAESBinaryField16x16b
	);
	define_byte_sliced_test!(
		tests_3d_32x8,
		ByteSlicedAES32x8b,
		AESTowerField8b,
		PackedAESBinaryField32x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x32x8,
		ByteSlicedAES16x32x8b,
		AESTowerField8b,
		PackedAESBinaryField32x8b
	);

	define_byte_sliced_test!(
		tests_3d_16x256x1,
		ByteSliced16x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_8x256x1,
		ByteSliced8x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_4x256x1,
		ByteSliced4x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_2x256x1,
		ByteSliced2x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_1x256x1,
		ByteSliced1x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);

	// 512-bit byte-sliced
	define_byte_sliced_test!(
		tests_3d_64x128,
		ByteSlicedAES64x128b,
		AESTowerField128b,
		PackedAESBinaryField4x128b
	);
	define_byte_sliced_test!(
		tests_3d_64x64,
		ByteSlicedAES64x64b,
		AESTowerField64b,
		PackedAESBinaryField8x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x64x64,
		ByteSlicedAES2x64x64b,
		AESTowerField64b,
		PackedAESBinaryField8x64b
	);
	define_byte_sliced_test!(
		tests_3d_64x32,
		ByteSlicedAES64x32b,
		AESTowerField32b,
		PackedAESBinaryField16x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x64x32,
		ByteSlicedAES4x64x32b,
		AESTowerField32b,
		PackedAESBinaryField16x32b
	);
	define_byte_sliced_test!(
		tests_3d_64x16,
		ByteSlicedAES64x16b,
		AESTowerField16b,
		PackedAESBinaryField32x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x64x16,
		ByteSlicedAES8x64x16b,
		AESTowerField16b,
		PackedAESBinaryField32x16b
	);
	define_byte_sliced_test!(
		tests_3d_64x8,
		ByteSlicedAES64x8b,
		AESTowerField8b,
		PackedAESBinaryField64x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x64x8,
		ByteSlicedAES16x64x8b,
		AESTowerField8b,
		PackedAESBinaryField64x8b
	);

	define_byte_sliced_test!(
		tests_3d_16x512x1,
		ByteSliced16x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_8x512x1,
		ByteSliced8x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_4x512x1,
		ByteSliced4x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_2x512x1,
		ByteSliced2x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_1x512x1,
		ByteSliced1x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
}