1mod invert;
4mod multiply;
5mod packed_byte_sliced;
6mod square;
7mod underlier;
8
9pub use packed_byte_sliced::*;
10pub use underlier::ByteSlicedUnderlier;
11
12#[cfg(test)]
13pub mod tests {
14 use proptest::prelude::*;
15 use rand::{SeedableRng, distributions::Uniform, rngs::StdRng};
16
17 use super::*;
18 use crate::{
19 PackedAESBinaryField1x128b, PackedAESBinaryField2x64b, PackedAESBinaryField2x128b,
20 PackedAESBinaryField4x32b, PackedAESBinaryField4x64b, PackedAESBinaryField4x128b,
21 PackedAESBinaryField8x16b, PackedAESBinaryField8x32b, PackedAESBinaryField8x64b,
22 PackedAESBinaryField16x8b, PackedAESBinaryField16x16b, PackedAESBinaryField16x32b,
23 PackedAESBinaryField32x8b, PackedAESBinaryField32x16b, PackedAESBinaryField64x8b,
24 PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b, PackedField,
25 packed::{get_packed_slice, set_packed_slice},
26 underlier::WithUnderlier,
27 };
28
29 fn scalars_vec_strategy<P: PackedField<Scalar: WithUnderlier<Underlier: Arbitrary>>>()
30 -> impl Strategy<Value = Vec<P::Scalar>> {
31 proptest::collection::vec(
32 any::<<P::Scalar as WithUnderlier>::Underlier>().prop_map(P::Scalar::from_underlier),
33 P::WIDTH..=P::WIDTH,
34 )
35 }
36
37 macro_rules! define_byte_sliced_test {
38 ($module_name:ident, $name:ident, $scalar_type:ty, $associated_packed:ty) => {
39 mod $module_name {
40 use super::*;
41
42 proptest! {
43 #[test]
44 fn check_from_fn(scalar_elems in scalars_vec_strategy::<$name>()) {
45 let bytesliced = <$name>::from_fn(|i| scalar_elems[i]);
46 for i in 0..<$name>::WIDTH {
47 assert_eq!(scalar_elems[i], bytesliced.get(i));
48 }
49 }
50
51 #[test]
52 fn check_add(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
53 let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
54 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
55
56 let bytesliced_result = bytesliced_a + bytesliced_b;
57
58 for i in 0..<$name>::WIDTH {
59 assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_result.get(i));
60 }
61 }
62
63 #[test]
64 fn check_add_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
65 let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
66 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
67
68 bytesliced_a += bytesliced_b;
69
70 for i in 0..<$name>::WIDTH {
71 assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_a.get(i));
72 }
73 }
74
75 #[test]
76 fn check_sub(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
77 let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
78 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
79
80 let bytesliced_result = bytesliced_a - bytesliced_b;
81
82 for i in 0..<$name>::WIDTH {
83 assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_result.get(i));
84 }
85 }
86
87 #[test]
88 fn check_sub_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
89 let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
90 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
91
92 bytesliced_a -= bytesliced_b;
93
94 for i in 0..<$name>::WIDTH {
95 assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_a.get(i));
96 }
97 }
98
99 #[test]
100 fn check_mul(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
101 let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
102 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
103
104 let bytesliced_result = bytesliced_a * bytesliced_b;
105
106 for i in 0..<$name>::WIDTH {
107 assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_result.get(i));
108 }
109 }
110
111 #[test]
112 fn check_mul_assign(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
113 let mut bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
114 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
115
116 bytesliced_a *= bytesliced_b;
117
118 for i in 0..<$name>::WIDTH {
119 assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_a.get(i));
120 }
121 }
122
123 #[test]
124 fn check_inv(scalar_elems in scalars_vec_strategy::<$name>()) {
125 let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
126
127 let bytesliced_result = bytesliced.invert_or_zero();
128
129 for (i, scalar_elem) in scalar_elems.iter().enumerate() {
130 assert_eq!(scalar_elem.invert_or_zero(), bytesliced_result.get(i));
131 }
132 }
133
134 #[test]
135 fn check_square(scalar_elems in scalars_vec_strategy::<$name>()) {
136 let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
137
138 let bytesliced_result = bytesliced.square();
139
140 for (i, scalar_elem) in scalar_elems.iter().enumerate() {
141 assert_eq!(scalar_elem.square(), bytesliced_result.get(i));
142 }
143 }
144
145 #[test]
146 fn check_linear_transformation(scalar_elems in scalars_vec_strategy::<$name>()) {
147 use crate::linear_transformation::{PackedTransformationFactory, FieldLinearTransformation, Transformation};
148 use rand::{rngs::StdRng, SeedableRng};
149
150 let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
151
152 let linear_transformation = FieldLinearTransformation::random(StdRng::seed_from_u64(0));
153 let packed_transformation = <$name>::make_packed_transformation(linear_transformation.clone());
154
155 let bytesliced_result = packed_transformation.transform(&bytesliced);
156
157 for i in 0..<$name>::WIDTH {
158 assert_eq!(linear_transformation.transform(&scalar_elems[i]), bytesliced_result.get(i));
159 }
160 }
161
162 #[test]
163 fn check_interleave(scalar_elems_a in scalars_vec_strategy::<$name>(), scalar_elems_b in scalars_vec_strategy::<$name>()) {
164 let bytesliced_a = <$name>::from_scalars(scalar_elems_a.iter().copied());
165 let bytesliced_b = <$name>::from_scalars(scalar_elems_b.iter().copied());
166
167 for log_block_len in 0..<$name>::LOG_WIDTH {
168 let (bytesliced_c, bytesliced_d) = bytesliced_a.interleave(bytesliced_b, log_block_len);
169
170 let block_len = 1 << log_block_len;
171 for offset in (0..<$name>::WIDTH).step_by(2 * block_len) {
172 for i in 0..block_len {
173 assert_eq!(bytesliced_c.get(offset + i), scalar_elems_a[offset + i]);
174 assert_eq!(bytesliced_c.get(offset + block_len + i), scalar_elems_b[offset + i]);
175
176 assert_eq!(bytesliced_d.get(offset + i), scalar_elems_a[offset + block_len + i]);
177 assert_eq!(bytesliced_d.get(offset + block_len + i), scalar_elems_b[offset + block_len + i]);
178 }
179 }
180 }
181 }
182
183 #[test]
184 fn check_spread(scalar_elems in scalars_vec_strategy::<$name>()) {
185 let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
186
187 let mut rng = StdRng::seed_from_u64(0);
188 for log_block_len in 0..<$name>::LOG_WIDTH {
189 let distribution = Uniform::from(0..(1 << (<$name>::LOG_WIDTH - log_block_len)));
191 let block_index = rng.sample(distribution);
192
193 let bytesliced_result = bytesliced.spread(log_block_len, block_index);
194 let offset = block_index << log_block_len;
195 let repeat_count = 1 << (<$name>::LOG_WIDTH - log_block_len);
196
197 for i in 0..<$name>::WIDTH {
198 assert_eq!(bytesliced_result.get(i), scalar_elems[offset + i / repeat_count]);
199 }
200 }
201 }
202
203 #[test]
204 fn check_transpose_to(scalar_elems in scalars_vec_strategy::<$name>()) {
205 let bytesliced = <$name>::from_scalars(scalar_elems.iter().copied());
206 let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
207 bytesliced.transpose_to(&mut destination);
208
209 for i in 0..<$name>::WIDTH {
210 assert_eq!(scalar_elems[i], get_packed_slice(&destination, i));
211 }
212 }
213
214 #[test]
215 fn check_transpose_from(scalar_elems in scalars_vec_strategy::<$name>()) {
216 let mut destination = [<$associated_packed>::zero(); <$name>::HEIGHT_BYTES];
217 for i in 0..<$name>::WIDTH {
218 set_packed_slice(&mut destination, i, scalar_elems[i]);
219 }
220
221 let bytesliced = <$name>::transpose_from(&destination);
222
223 for i in 0..<$name>::WIDTH {
224 assert_eq!(get_packed_slice(&destination, i), bytesliced.get(i));
225 }
226 }
227 }
228 }
229 };
230 }
231
232 define_byte_sliced_test!(
234 tests_3d_16x128,
235 ByteSlicedAES16x128b,
236 AESTowerField128b,
237 PackedAESBinaryField1x128b
238 );
239 define_byte_sliced_test!(
240 tests_3d_16x64,
241 ByteSlicedAES16x64b,
242 AESTowerField64b,
243 PackedAESBinaryField2x64b
244 );
245 define_byte_sliced_test!(
246 tests_3d_2x16x64,
247 ByteSlicedAES2x16x64b,
248 AESTowerField64b,
249 PackedAESBinaryField2x64b
250 );
251 define_byte_sliced_test!(
252 tests_3d_16x32,
253 ByteSlicedAES16x32b,
254 AESTowerField32b,
255 PackedAESBinaryField4x32b
256 );
257 define_byte_sliced_test!(
258 tests_3d_4x16x32,
259 ByteSlicedAES4x16x32b,
260 AESTowerField32b,
261 PackedAESBinaryField4x32b
262 );
263 define_byte_sliced_test!(
264 tests_3d_16x16,
265 ByteSlicedAES16x16b,
266 AESTowerField16b,
267 PackedAESBinaryField8x16b
268 );
269 define_byte_sliced_test!(
270 tests_3d_8x16x16,
271 ByteSlicedAES8x16x16b,
272 AESTowerField16b,
273 PackedAESBinaryField8x16b
274 );
275 define_byte_sliced_test!(
276 tests_3d_16x8,
277 ByteSlicedAES16x8b,
278 AESTowerField8b,
279 PackedAESBinaryField16x8b
280 );
281 define_byte_sliced_test!(
282 tests_3d_16x16x8,
283 ByteSlicedAES16x16x8b,
284 AESTowerField8b,
285 PackedAESBinaryField16x8b
286 );
287
288 define_byte_sliced_test!(
289 tests_3d_16x128x1,
290 ByteSliced16x128x1b,
291 BinaryField1b,
292 PackedBinaryField128x1b
293 );
294 define_byte_sliced_test!(
295 tests_3d_8x128x1,
296 ByteSliced8x128x1b,
297 BinaryField1b,
298 PackedBinaryField128x1b
299 );
300 define_byte_sliced_test!(
301 tests_3d_4x128x1,
302 ByteSliced4x128x1b,
303 BinaryField1b,
304 PackedBinaryField128x1b
305 );
306 define_byte_sliced_test!(
307 tests_3d_2x128x1,
308 ByteSliced2x128x1b,
309 BinaryField1b,
310 PackedBinaryField128x1b
311 );
312 define_byte_sliced_test!(
313 tests_3d_1x128x1,
314 ByteSliced1x128x1b,
315 BinaryField1b,
316 PackedBinaryField128x1b
317 );
318
319 define_byte_sliced_test!(
321 tests_3d_32x128,
322 ByteSlicedAES32x128b,
323 AESTowerField128b,
324 PackedAESBinaryField2x128b
325 );
326 define_byte_sliced_test!(
327 tests_3d_32x64,
328 ByteSlicedAES32x64b,
329 AESTowerField64b,
330 PackedAESBinaryField4x64b
331 );
332 define_byte_sliced_test!(
333 tests_3d_2x32x64,
334 ByteSlicedAES2x32x64b,
335 AESTowerField64b,
336 PackedAESBinaryField4x64b
337 );
338 define_byte_sliced_test!(
339 tests_3d_32x32,
340 ByteSlicedAES32x32b,
341 AESTowerField32b,
342 PackedAESBinaryField8x32b
343 );
344 define_byte_sliced_test!(
345 tests_3d_4x32x32,
346 ByteSlicedAES4x32x32b,
347 AESTowerField32b,
348 PackedAESBinaryField8x32b
349 );
350 define_byte_sliced_test!(
351 tests_3d_32x16,
352 ByteSlicedAES32x16b,
353 AESTowerField16b,
354 PackedAESBinaryField16x16b
355 );
356 define_byte_sliced_test!(
357 tests_3d_8x32x16,
358 ByteSlicedAES8x32x16b,
359 AESTowerField16b,
360 PackedAESBinaryField16x16b
361 );
362 define_byte_sliced_test!(
363 tests_3d_32x8,
364 ByteSlicedAES32x8b,
365 AESTowerField8b,
366 PackedAESBinaryField32x8b
367 );
368 define_byte_sliced_test!(
369 tests_3d_16x32x8,
370 ByteSlicedAES16x32x8b,
371 AESTowerField8b,
372 PackedAESBinaryField32x8b
373 );
374
375 define_byte_sliced_test!(
376 tests_3d_16x256x1,
377 ByteSliced16x256x1b,
378 BinaryField1b,
379 PackedBinaryField256x1b
380 );
381 define_byte_sliced_test!(
382 tests_3d_8x256x1,
383 ByteSliced8x256x1b,
384 BinaryField1b,
385 PackedBinaryField256x1b
386 );
387 define_byte_sliced_test!(
388 tests_3d_4x256x1,
389 ByteSliced4x256x1b,
390 BinaryField1b,
391 PackedBinaryField256x1b
392 );
393 define_byte_sliced_test!(
394 tests_3d_2x256x1,
395 ByteSliced2x256x1b,
396 BinaryField1b,
397 PackedBinaryField256x1b
398 );
399 define_byte_sliced_test!(
400 tests_3d_1x256x1,
401 ByteSliced1x256x1b,
402 BinaryField1b,
403 PackedBinaryField256x1b
404 );
405
406 define_byte_sliced_test!(
408 tests_3d_64x128,
409 ByteSlicedAES64x128b,
410 AESTowerField128b,
411 PackedAESBinaryField4x128b
412 );
413 define_byte_sliced_test!(
414 tests_3d_64x64,
415 ByteSlicedAES64x64b,
416 AESTowerField64b,
417 PackedAESBinaryField8x64b
418 );
419 define_byte_sliced_test!(
420 tests_3d_2x64x64,
421 ByteSlicedAES2x64x64b,
422 AESTowerField64b,
423 PackedAESBinaryField8x64b
424 );
425 define_byte_sliced_test!(
426 tests_3d_64x32,
427 ByteSlicedAES64x32b,
428 AESTowerField32b,
429 PackedAESBinaryField16x32b
430 );
431 define_byte_sliced_test!(
432 tests_3d_4x64x32,
433 ByteSlicedAES4x64x32b,
434 AESTowerField32b,
435 PackedAESBinaryField16x32b
436 );
437 define_byte_sliced_test!(
438 tests_3d_64x16,
439 ByteSlicedAES64x16b,
440 AESTowerField16b,
441 PackedAESBinaryField32x16b
442 );
443 define_byte_sliced_test!(
444 tests_3d_8x64x16,
445 ByteSlicedAES8x64x16b,
446 AESTowerField16b,
447 PackedAESBinaryField32x16b
448 );
449 define_byte_sliced_test!(
450 tests_3d_64x8,
451 ByteSlicedAES64x8b,
452 AESTowerField8b,
453 PackedAESBinaryField64x8b
454 );
455 define_byte_sliced_test!(
456 tests_3d_16x64x8,
457 ByteSlicedAES16x64x8b,
458 AESTowerField8b,
459 PackedAESBinaryField64x8b
460 );
461
462 define_byte_sliced_test!(
463 tests_3d_16x512x1,
464 ByteSliced16x512x1b,
465 BinaryField1b,
466 PackedBinaryField512x1b
467 );
468 define_byte_sliced_test!(
469 tests_3d_8x512x1,
470 ByteSliced8x512x1b,
471 BinaryField1b,
472 PackedBinaryField512x1b
473 );
474 define_byte_sliced_test!(
475 tests_3d_4x512x1,
476 ByteSliced4x512x1b,
477 BinaryField1b,
478 PackedBinaryField512x1b
479 );
480 define_byte_sliced_test!(
481 tests_3d_2x512x1,
482 ByteSliced2x512x1b,
483 BinaryField1b,
484 PackedBinaryField512x1b
485 );
486 define_byte_sliced_test!(
487 tests_3d_1x512x1,
488 ByteSliced1x512x1b,
489 BinaryField1b,
490 PackedBinaryField512x1b
491 );
492}