// binius_field/arch/portable/byte_sliced/mod.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3mod invert;
4mod multiply;
5mod packed_byte_sliced;
6mod square;
7
8pub use packed_byte_sliced::*;
9
#[cfg(test)]
pub mod tests {
	/// Generates a property-test module `$module_name` for the byte-sliced packed type `$name`.
	///
	/// Each test packs random arrays of `$scalar_type` into `$name`, applies a packed operation,
	/// and compares every lane of the result with the same operation applied to the
	/// corresponding scalars.
	macro_rules! define_byte_sliced_test {
		($module_name:ident, $name:ident, $scalar_type:ty) => {
			mod $module_name {
				use proptest::prelude::*;
				use crate::{
					$scalar_type, underlier::WithUnderlier, packed::PackedField,
					arch::byte_sliced::$name,
				};

				/// Strategy producing `WIDTH` random scalars, built from arbitrary underlier
				/// values via `from_underlier`.
				fn scalar_array_strategy() -> impl Strategy<Value = [$scalar_type; <$name>::WIDTH]> {
					any::<[<$scalar_type as WithUnderlier>::Underlier; <$name>::WIDTH]>()
						.prop_map(|arr| arr.map(<$scalar_type>::from_underlier))
				}

				/// Checks that lane `i` of `packed` equals `expected(i)` for every lane `i`.
				///
				/// Uses `prop_assert_eq!` so that a mismatch is reported as a proptest failure,
				/// including the index of the first mismatching lane, instead of a panic.
				fn check_lanes(
					packed: &$name,
					expected: impl Fn(usize) -> $scalar_type,
				) -> Result<(), TestCaseError> {
					for i in 0..<$name>::WIDTH {
						prop_assert_eq!(expected(i), packed.get(i), "mismatch in lane {}", i);
					}
					Ok(())
				}

				proptest! {
					#[test]
					fn check_add(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						let bytesliced_result = bytesliced_a + bytesliced_b;

						check_lanes(&bytesliced_result, |i| scalar_elems_a[i] + scalar_elems_b[i])?;
					}

					#[test]
					fn check_add_assign(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						bytesliced_a += bytesliced_b;

						check_lanes(&bytesliced_a, |i| scalar_elems_a[i] + scalar_elems_b[i])?;
					}

					#[test]
					fn check_sub(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						let bytesliced_result = bytesliced_a - bytesliced_b;

						check_lanes(&bytesliced_result, |i| scalar_elems_a[i] - scalar_elems_b[i])?;
					}

					#[test]
					fn check_sub_assign(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						bytesliced_a -= bytesliced_b;

						check_lanes(&bytesliced_a, |i| scalar_elems_a[i] - scalar_elems_b[i])?;
					}

					#[test]
					fn check_mul(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						let bytesliced_result = bytesliced_a * bytesliced_b;

						check_lanes(&bytesliced_result, |i| scalar_elems_a[i] * scalar_elems_b[i])?;
					}

					#[test]
					fn check_mul_assign(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						bytesliced_a *= bytesliced_b;

						check_lanes(&bytesliced_a, |i| scalar_elems_a[i] * scalar_elems_b[i])?;
					}

					#[test]
					fn check_inv(scalar_elems in scalar_array_strategy()) {
						let bytesliced = <$name>::from_scalars(scalar_elems);

						let bytesliced_result = bytesliced.invert_or_zero();

						check_lanes(&bytesliced_result, |i| scalar_elems[i].invert_or_zero())?;
					}

					#[test]
					fn check_square(scalar_elems in scalar_array_strategy()) {
						let bytesliced = <$name>::from_scalars(scalar_elems);

						let bytesliced_result = bytesliced.square();

						check_lanes(&bytesliced_result, |i| scalar_elems[i].square())?;
					}

					#[test]
					fn check_linear_transformation(scalar_elems in scalar_array_strategy(), seed in any::<u64>()) {
						use crate::linear_transformation::{PackedTransformationFactory, FieldLinearTransformation, Transformation};
						use rand::{rngs::StdRng, SeedableRng};

						let bytesliced = <$name>::from_scalars(scalar_elems);

						// The seed is chosen by proptest so that different cases exercise different
						// random transformations rather than one fixed matrix.
						let linear_transformation = FieldLinearTransformation::random(StdRng::seed_from_u64(seed));
						let packed_transformation = <$name>::make_packed_transformation(linear_transformation.clone());

						let bytesliced_result = packed_transformation.transform(&bytesliced);

						check_lanes(&bytesliced_result, |i| linear_transformation.transform(&scalar_elems[i]))?;
					}

					#[test]
					fn check_interleave(scalar_elems_a in scalar_array_strategy(), scalar_elems_b in scalar_array_strategy()) {
						let bytesliced_a = <$name>::from_scalars(scalar_elems_a);
						let bytesliced_b = <$name>::from_scalars(scalar_elems_b);

						for log_block_len in 0..<$name>::LOG_WIDTH {
							let (bytesliced_c, bytesliced_d) = bytesliced_a.interleave(bytesliced_b, log_block_len);

							// Lanes come in windows of `2 * block_len`. A lane is in the first half of its
							// window exactly when the `block_len` bit of its index is clear.
							// `c` holds the first half of `a` followed by the first half of `b`.
							// `d` holds the second half of `a` followed by the second half of `b`.
							let block_len: usize = 1 << log_block_len;
							let in_first_half = |i: usize| (i & block_len) == 0;

							check_lanes(&bytesliced_c, |i| {
								if in_first_half(i) {
									scalar_elems_a[i]
								} else {
									scalar_elems_b[i - block_len]
								}
							})?;
							check_lanes(&bytesliced_d, |i| {
								if in_first_half(i) {
									scalar_elems_a[i + block_len]
								} else {
									scalar_elems_b[i]
								}
							})?;
						}
					}
				}
			}
		};
	}

	// 128-bit byte-sliced
	define_byte_sliced_test!(tests_16x128, ByteSlicedAES16x128b, AESTowerField128b);
	define_byte_sliced_test!(tests_16x64, ByteSlicedAES16x64b, AESTowerField64b);
	define_byte_sliced_test!(tests_16x32, ByteSlicedAES16x32b, AESTowerField32b);
	define_byte_sliced_test!(tests_16x16, ByteSlicedAES16x16b, AESTowerField16b);
	define_byte_sliced_test!(tests_16x8, ByteSlicedAES16x8b, AESTowerField8b);

	define_byte_sliced_test!(tests_16x16x8b, ByteSlicedAES16x16x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_8x16x8b, ByteSlicedAES8x16x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_4x16x8b, ByteSlicedAES4x16x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_2x16x8b, ByteSlicedAES2x16x8b, AESTowerField8b);

	// 256-bit byte-sliced
	define_byte_sliced_test!(tests_32x128, ByteSlicedAES32x128b, AESTowerField128b);
	define_byte_sliced_test!(tests_32x64, ByteSlicedAES32x64b, AESTowerField64b);
	define_byte_sliced_test!(tests_32x32, ByteSlicedAES32x32b, AESTowerField32b);
	define_byte_sliced_test!(tests_32x16, ByteSlicedAES32x16b, AESTowerField16b);
	define_byte_sliced_test!(tests_32x8, ByteSlicedAES32x8b, AESTowerField8b);

	define_byte_sliced_test!(tests_16x32x8b, ByteSlicedAES16x32x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_8x32x8b, ByteSlicedAES8x32x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_4x32x8b, ByteSlicedAES4x32x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_2x32x8b, ByteSlicedAES2x32x8b, AESTowerField8b);

	// 512-bit byte-sliced
	define_byte_sliced_test!(tests_64x128, ByteSlicedAES64x128b, AESTowerField128b);
	define_byte_sliced_test!(tests_64x64, ByteSlicedAES64x64b, AESTowerField64b);
	define_byte_sliced_test!(tests_64x32, ByteSlicedAES64x32b, AESTowerField32b);
	define_byte_sliced_test!(tests_64x16, ByteSlicedAES64x16b, AESTowerField16b);
	define_byte_sliced_test!(tests_64x8, ByteSlicedAES64x8b, AESTowerField8b);

	define_byte_sliced_test!(tests_16x64x8b, ByteSlicedAES16x64x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_8x64x8b, ByteSlicedAES8x64x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_4x64x8b, ByteSlicedAES4x64x8b, AESTowerField8b);
	define_byte_sliced_test!(tests_2x64x8b, ByteSlicedAES2x64x8b, AESTowerField8b);
}