binius_field/arch/portable/byte_sliced/
mod.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3mod invert;
4mod multiply;
5mod packed_byte_sliced;
6mod square;
7mod underlier;
8
9pub use packed_byte_sliced::*;
10pub use underlier::ByteSlicedUnderlier;
11
12#[cfg(test)]
13pub mod tests {
14	use proptest::prelude::*;
15	use rand::{SeedableRng, distributions::Uniform, rngs::StdRng};
16
17	use super::*;
18	use crate::{
19		PackedAESBinaryField1x128b, PackedAESBinaryField2x64b, PackedAESBinaryField2x128b,
20		PackedAESBinaryField4x32b, PackedAESBinaryField4x64b, PackedAESBinaryField4x128b,
21		PackedAESBinaryField8x16b, PackedAESBinaryField8x32b, PackedAESBinaryField8x64b,
22		PackedAESBinaryField16x8b, PackedAESBinaryField16x16b, PackedAESBinaryField16x32b,
23		PackedAESBinaryField32x8b, PackedAESBinaryField32x16b, PackedAESBinaryField64x8b,
24		PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedField,
25		packed::{get_packed_slice, set_packed_slice},
26		underlier::WithUnderlier,
27	};
28
29	fn scalars_vec_strategy<P: PackedField<Scalar: WithUnderlier<Underlier: Arbitrary>>>()
30	-> impl Strategy<Value = Vec<P::Scalar>> {
31		proptest::collection::vec(
32			any::<<P::Scalar as WithUnderlier>::Underlier>().prop_map(P::Scalar::from_underlier),
33			P::WIDTH..=P::WIDTH,
34		)
35	}
36
37	macro_rules! define_byte_sliced_test {
38		($module_name:ident, $name:ident, $scalar_type:ty, $associated_packed:ty) => {
39			mod $module_name {
40				use super::*;
41
42				proptest! {
43					#[test]
44					fn check_from_fn(scalar_elems in scalars_vec_strategy::<$name>()) {
45						let bytesliced = <$name>::from_fn(|i| scalar_elems[i]);
46						for i in 0..<$name>::WIDTH {
47							assert_eq!(scalar_elems[i], bytesliced.get(i));
48						}
49					}
50
51					#[test]
52					fn check_add(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
53						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
54						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
55
56						let bytesliced_result = bytesliced_a + bytesliced_b;
57
58						for i in 0..<$name>::WIDTH {
59							assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_result.get(i));
60						}
61					}
62
63					#[test]
64					fn check_add_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
65						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
66						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
67
68						bytesliced_a += bytesliced_b;
69
70						for i in 0..<$name>::WIDTH {
71							assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_a.get(i));
72						}
73					}
74
75					#[test]
76					fn check_sub(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
77						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
78						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
79
80						let bytesliced_result = bytesliced_a - bytesliced_b;
81
82						for i in 0..<$name>::WIDTH {
83							assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_result.get(i));
84						}
85					}
86
87					#[test]
88					fn check_sub_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
89						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
90						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
91
92						bytesliced_a -= bytesliced_b;
93
94						for i in 0..<$name>::WIDTH {
95							assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_a.get(i));
96						}
97					}
98
99					#[test]
100					fn check_mul(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
101						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
102						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
103
104						let bytesliced_result = bytesliced_a * bytesliced_b;
105
106						for i in 0..<$name>::WIDTH {
107							assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_result.get(i));
108						}
109					}
110
111					#[test]
112					fn check_mul_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
113						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
114						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
115
116						bytesliced_a *= bytesliced_b;
117
118						for i in 0..<$name>::WIDTH {
119							assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_a.get(i));
120						}
121					}
122
123					#[test]
124					fn check_inv(scalar_elems in scalars_vec_strategy::<$name>()) {
125						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
126
127						let bytesliced_result = bytesliced.invert_or_zero();
128
129						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
130							assert_eq!(scalar_elem.invert_or_zero(), bytesliced_result.get(i));
131						}
132					}
133
134					#[test]
135					fn check_square(scalar_elems in scalars_vec_strategy::<$name>()) {
136						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
137
138						let bytesliced_result = bytesliced.square();
139
140						for (i, scalar_elem) in scalar_elems.iter().enumerate() {
141							assert_eq!(scalar_elem.square(), bytesliced_result.get(i));
142						}
143					}
144
145					#[test]
146					fn check_linear_transformation(scalar_elems in scalars_vec_strategy::<$name>()) {
147						use crate::linear_transformation::{PackedTransformationFactory, FieldLinearTransformation, Transformation};
148						use rand::{rngs::StdRng, SeedableRng};
149
150						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
151
152						let linear_transformation = FieldLinearTransformation::random(StdRng::seed_from_u64(0));
153						let packed_transformation = <$name>::make_packed_transformation(linear_transformation.clone());
154
155						let bytesliced_result = packed_transformation.transform(&bytesliced);
156
157						for i in 0..<$name>::WIDTH {
158							assert_eq!(linear_transformation.transform(&scalar_elems[i]), bytesliced_result.get(i));
159						}
160					}
161
162					#[test]
163					fn check_interleave(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
164						let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
165						let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
166
167						for log_block_len in 0..<$name>::LOG_WIDTH {
168							let (bytesliced_c, bytesliced_d) = bytesliced_a.interleave(bytesliced_b, log_block_len);
169
170							let block_len = 1 << log_block_len;
171							for offset in (0..<$name>::WIDTH).step_by(2 * block_len) {
172								for i in 0..block_len {
173									assert_eq!(bytesliced_c.get(offset + i), scalar_elems_a[offset + i]);
174									assert_eq!(bytesliced_c.get(offset + block_len + i), scalar_elems_b[offset + i]);
175
176									assert_eq!(bytesliced_d.get(offset + i), scalar_elems_a[offset + block_len + i]);
177									assert_eq!(bytesliced_d.get(offset + block_len + i), scalar_elems_b[offset + block_len + i]);
178								}
179							}
180						}
181					}
182
183					#[test]
184					fn check_spread(scalar_elems in scalars_vec_strategy::<$name>()) {
185						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
186
187						let mut rng = StdRng::seed_from_u64(0);
188						for log_block_len in 0..<$name>::LOG_WIDTH {
189							// iterating over all possible block lengths is too slow
190							let distribution = Uniform::from(0..(1 << (<$name>::LOG_WIDTH - log_block_len)));
191							let block_index = rng.sample(distribution);
192
193							let bytesliced_result = bytesliced.spread(log_block_len, block_index);
194							let offset = block_index << log_block_len;
195							let repeat_count = 1 << (<$name>::LOG_WIDTH - log_block_len);
196
197							for i in 0..<$name>::WIDTH {
198								assert_eq!(bytesliced_result.get(i), scalar_elems[offset + i / repeat_count]);
199							}
200						}
201					}
202
203					#[test]
204					fn check_transpose_to(scalar_elems in scalars_vec_strategy::<$name>()) {
205						let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
206						let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
207						bytesliced.transpose_to(&mut destination);
208
209						for i in 0..<$name>::WIDTH {
210							assert_eq!(scalar_elems[i], get_packed_slice(&destination, i));
211						}
212					}
213
214					#[test]
215					fn check_transpose_from(scalar_elems in scalars_vec_strategy::<$name>()) {
216						let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
217						for i in 0..<$name>::WIDTH {
218							set_packed_slice(&mut destination, i, scalar_elems[i]);
219						}
220
221						let bytesliced = <$name>::transpose_from(&destination);
222
223						for i in 0..<$name>::WIDTH {
224							assert_eq!(get_packed_slice(&destination, i), bytesliced.get(i));
225						}
226					}
227				}
228			}
229		};
230	}
231
	// 128-bit byte-sliced
	//
	// Each `define_byte_sliced_test!` invocation generates one module of property tests.
	// Arguments, in order: generated module name, byte-sliced type under test, scalar type
	// (not referenced by the macro expansion), and the packed type used as the array element
	// in the transpose tests.

	// AES tower field scalars.
	define_byte_sliced_test!(
		tests_3d_16x128,
		ByteSlicedAES16x128b,
		AESTowerField128b,
		PackedAESBinaryField1x128b
	);
	define_byte_sliced_test!(
		tests_3d_16x64,
		ByteSlicedAES16x64b,
		AESTowerField64b,
		PackedAESBinaryField2x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x16x64,
		ByteSlicedAES2x16x64b,
		AESTowerField64b,
		PackedAESBinaryField2x64b
	);
	define_byte_sliced_test!(
		tests_3d_16x32,
		ByteSlicedAES16x32b,
		AESTowerField32b,
		PackedAESBinaryField4x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x16x32,
		ByteSlicedAES4x16x32b,
		AESTowerField32b,
		PackedAESBinaryField4x32b
	);
	define_byte_sliced_test!(
		tests_3d_16x16,
		ByteSlicedAES16x16b,
		AESTowerField16b,
		PackedAESBinaryField8x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x16x16,
		ByteSlicedAES8x16x16b,
		AESTowerField16b,
		PackedAESBinaryField8x16b
	);
	define_byte_sliced_test!(
		tests_3d_16x8,
		ByteSlicedAES16x8b,
		AESTowerField8b,
		PackedAESBinaryField16x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x16x8,
		ByteSlicedAES16x16x8b,
		AESTowerField8b,
		PackedAESBinaryField16x8b
	);

	// 1-bit binary field scalars, all transposed through `PackedBinaryField128x1b`.
	define_byte_sliced_test!(
		tests_3d_16x128x1,
		ByteSliced16x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_8x128x1,
		ByteSliced8x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_4x128x1,
		ByteSliced4x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_2x128x1,
		ByteSliced2x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
	define_byte_sliced_test!(
		tests_3d_1x128x1,
		ByteSliced1x128x1b,
		BinaryField1b,
		PackedBinaryField128x1b
	);
318
	// 256-bit byte-sliced
	//
	// Each `define_byte_sliced_test!` invocation generates one module of property tests.
	// Arguments, in order: generated module name, byte-sliced type under test, scalar type
	// (not referenced by the macro expansion), and the packed type used as the array element
	// in the transpose tests.

	// AES tower field scalars.
	define_byte_sliced_test!(
		tests_3d_32x128,
		ByteSlicedAES32x128b,
		AESTowerField128b,
		PackedAESBinaryField2x128b
	);
	define_byte_sliced_test!(
		tests_3d_32x64,
		ByteSlicedAES32x64b,
		AESTowerField64b,
		PackedAESBinaryField4x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x32x64,
		ByteSlicedAES2x32x64b,
		AESTowerField64b,
		PackedAESBinaryField4x64b
	);
	define_byte_sliced_test!(
		tests_3d_32x32,
		ByteSlicedAES32x32b,
		AESTowerField32b,
		PackedAESBinaryField8x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x32x32,
		ByteSlicedAES4x32x32b,
		AESTowerField32b,
		PackedAESBinaryField8x32b
	);
	define_byte_sliced_test!(
		tests_3d_32x16,
		ByteSlicedAES32x16b,
		AESTowerField16b,
		PackedAESBinaryField16x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x32x16,
		ByteSlicedAES8x32x16b,
		AESTowerField16b,
		PackedAESBinaryField16x16b
	);
	define_byte_sliced_test!(
		tests_3d_32x8,
		ByteSlicedAES32x8b,
		AESTowerField8b,
		PackedAESBinaryField32x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x32x8,
		ByteSlicedAES16x32x8b,
		AESTowerField8b,
		PackedAESBinaryField32x8b
	);

	// 1-bit binary field scalars, all transposed through `PackedBinaryField256x1b`.
	define_byte_sliced_test!(
		tests_3d_16x256x1,
		ByteSliced16x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_8x256x1,
		ByteSliced8x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_4x256x1,
		ByteSliced4x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_2x256x1,
		ByteSliced2x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
	define_byte_sliced_test!(
		tests_3d_1x256x1,
		ByteSliced1x256x1b,
		BinaryField1b,
		PackedBinaryField256x1b
	);
405
	// 512-bit byte-sliced
	//
	// Each `define_byte_sliced_test!` invocation generates one module of property tests.
	// Arguments, in order: generated module name, byte-sliced type under test, scalar type
	// (not referenced by the macro expansion), and the packed type used as the array element
	// in the transpose tests.

	// AES tower field scalars.
	define_byte_sliced_test!(
		tests_3d_64x128,
		ByteSlicedAES64x128b,
		AESTowerField128b,
		PackedAESBinaryField4x128b
	);
	define_byte_sliced_test!(
		tests_3d_64x64,
		ByteSlicedAES64x64b,
		AESTowerField64b,
		PackedAESBinaryField8x64b
	);
	define_byte_sliced_test!(
		tests_3d_2x64x64,
		ByteSlicedAES2x64x64b,
		AESTowerField64b,
		PackedAESBinaryField8x64b
	);
	define_byte_sliced_test!(
		tests_3d_64x32,
		ByteSlicedAES64x32b,
		AESTowerField32b,
		PackedAESBinaryField16x32b
	);
	define_byte_sliced_test!(
		tests_3d_4x64x32,
		ByteSlicedAES4x64x32b,
		AESTowerField32b,
		PackedAESBinaryField16x32b
	);
	define_byte_sliced_test!(
		tests_3d_64x16,
		ByteSlicedAES64x16b,
		AESTowerField16b,
		PackedAESBinaryField32x16b
	);
	define_byte_sliced_test!(
		tests_3d_8x64x16,
		ByteSlicedAES8x64x16b,
		AESTowerField16b,
		PackedAESBinaryField32x16b
	);
	define_byte_sliced_test!(
		tests_3d_64x8,
		ByteSlicedAES64x8b,
		AESTowerField8b,
		PackedAESBinaryField64x8b
	);
	define_byte_sliced_test!(
		tests_3d_16x64x8,
		ByteSlicedAES16x64x8b,
		AESTowerField8b,
		PackedAESBinaryField64x8b
	);

	// 1-bit binary field scalars, all transposed through `PackedBinaryField512x1b`.
	define_byte_sliced_test!(
		tests_3d_16x512x1,
		ByteSliced16x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_8x512x1,
		ByteSliced8x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_4x512x1,
		ByteSliced4x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_2x512x1,
		ByteSliced2x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
	define_byte_sliced_test!(
		tests_3d_1x512x1,
		ByteSliced1x512x1b,
		BinaryField1b,
		PackedBinaryField512x1b
	);
492}