binius_field/packed.rs

// Copyright 2023-2025 Irreducible Inc.

//! Traits for packed field elements which support SIMD implementations.
//!
//! Interfaces are derived from [`plonky2`](https://github.com/mir-protocol/plonky2).

use std::{
	fmt::Debug,
	iter::{self, Product, Sum},
	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::{
	iter::IterExtensions,
	random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
};
use bytemuck::Zeroable;
use rand::RngCore;

use super::{
	Error,
	arithmetic_traits::{Broadcast, MulAlpha, Square},
	binary_field_arithmetic::TowerFieldArithmetic,
};
use crate::{
	BinaryField, Field, PackedExtension, arithmetic_traits::InvertOrZero,
	is_packed_field_indexable, underlier::WithUnderlier, unpack_if_possible_mut,
};

/// A packed field represents a vector of underlying field elements.
///
/// Arithmetic operations on packed field elements can be accelerated with SIMD CPU instructions.
/// The vector width is a constant, `WIDTH`. This trait requires that the width be a power of
/// two.
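///
/// ## Example
///
/// A small usage sketch; the concrete packed type is one of the implementations defined in this
/// crate:
///
/// ```
/// use binius_field::{BinaryField16b, PackedField, PackedBinaryField8x16b};
///
/// let mut p = PackedBinaryField8x16b::broadcast(BinaryField16b::new(3));
/// assert_eq!(p.get(7), BinaryField16b::new(3));
/// p.set(0, BinaryField16b::new(5));
/// assert_eq!(p.iter().filter(|&x| x == BinaryField16b::new(5)).count(), 1);
/// ```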
pub trait PackedField:
	Default
	+ Debug
	+ Clone
	+ Copy
	+ Eq
	+ Sized
	+ Add<Output = Self>
	+ Sub<Output = Self>
	+ Mul<Output = Self>
	+ AddAssign
	+ SubAssign
	+ MulAssign
	+ Add<Self::Scalar, Output = Self>
	+ Sub<Self::Scalar, Output = Self>
	+ Mul<Self::Scalar, Output = Self>
	+ AddAssign<Self::Scalar>
	+ SubAssign<Self::Scalar>
	+ MulAssign<Self::Scalar>
	// TODO: Get rid of Sum and Product. It's confusing with nested impls of Packed.
	+ Sum
	+ Product
	+ Send
	+ Sync
	+ Zeroable
	+ 'static
{
	type Scalar: Field;

	/// Base-2 logarithm of the number of field elements packed into one packed element.
	const LOG_WIDTH: usize;

	/// The number of field elements packed into one packed element.
	///
	/// WIDTH is guaranteed to equal 2^LOG_WIDTH.
	const WIDTH: usize = 1 << Self::LOG_WIDTH;

	/// Get the scalar at a given index without bounds checking.
	/// # Safety
	/// The caller must ensure that `i` is less than `WIDTH`.
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;

	/// Set the scalar at a given index without bounds checking.
	/// # Safety
	/// The caller must ensure that `i` is less than `WIDTH`.
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);

	/// Get the scalar at a given index.
	#[inline]
	fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
		(i < Self::WIDTH)
			.then_some(unsafe { self.get_unchecked(i) })
			.ok_or(Error::IndexOutOfRange {
				index: i,
				max: Self::WIDTH,
			})
	}

	/// Set the scalar at a given index.
	#[inline]
	fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
		(i < Self::WIDTH)
			.then(|| unsafe { self.set_unchecked(i, scalar) })
			.ok_or(Error::IndexOutOfRange {
				index: i,
				max: Self::WIDTH,
			})
	}

	/// Get the scalar at a given index.
	#[inline]
	fn get(&self, i: usize) -> Self::Scalar {
		self.get_checked(i).expect("index must be less than width")
	}

	/// Set the scalar at a given index.
	#[inline]
	fn set(&mut self, i: usize, scalar: Self::Scalar) {
		self.set_checked(i, scalar).expect("index must be less than width")
	}

	#[inline]
	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
		(0..Self::WIDTH).map_skippable(move |i|
			// Safety: `i` is always less than `WIDTH`
			unsafe { self.get_unchecked(i) })
	}

	#[inline]
	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		(0..Self::WIDTH).map_skippable(move |i|
			// Safety: `i` is always less than `WIDTH`
			unsafe { self.get_unchecked(i) })
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		slice.iter().flat_map(Self::iter)
	}

	#[inline]
	fn zero() -> Self {
		Self::broadcast(Self::Scalar::ZERO)
	}

	#[inline]
	fn one() -> Self {
		Self::broadcast(Self::Scalar::ONE)
	}

	/// Initialize position zero with `scalar`; all other elements are set to zero.
	#[inline(always)]
	fn set_single(scalar: Self::Scalar) -> Self {
		let mut result = Self::default();
		result.set(0, scalar);

		result
	}

	fn random(rng: impl RngCore) -> Self;
	fn broadcast(scalar: Self::Scalar) -> Self;

	/// Construct a packed field element from a function that returns scalar values by index.
	fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;

	/// Creates a packed field from a fallible function applied to each index.
	fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
		let mut result = Self::default();
		for i in 0..Self::WIDTH {
			let scalar = f(i)?;
			unsafe {
				result.set_unchecked(i, scalar);
			}
		}
		Ok(result)
	}

	/// Construct a packed field element from a sequence of scalars.
	///
	/// If the number of values in the sequence is less than the packing width, the remaining
	/// elements are set to zero. If greater than the packing width, the excess elements are
	/// ignored.
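	///
	/// ```
	/// use binius_field::{BinaryField16b, Field, PackedField, PackedBinaryField8x16b};
	///
	/// // Three scalars fill the low indices; the rest default to zero.
	/// let p = PackedBinaryField8x16b::from_scalars([1, 2, 3].map(BinaryField16b::new));
	/// assert_eq!(p.get(2), BinaryField16b::new(3));
	/// assert_eq!(p.get(3), BinaryField16b::ZERO);
	/// ```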
	fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
		let mut result = Self::default();
		for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
			result.set(i, val);
		}
		result
	}

	/// Returns the value multiplied by itself.
	fn square(self) -> Self;

	/// Returns the value to the power `exp`.
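	///
	/// Computed by square-and-multiply over all 64 bits of `exp`, most significant bit first.
	///
	/// ```
	/// use binius_field::{BinaryField16b, PackedField, PackedBinaryField8x16b};
	///
	/// let x = PackedBinaryField8x16b::broadcast(BinaryField16b::new(3));
	/// assert_eq!(x.pow(3), x * x * x);
	/// assert_eq!(x.pow(0), PackedBinaryField8x16b::one());
	/// ```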
	fn pow(self, exp: u64) -> Self {
		let mut res = Self::one();
		for i in (0..64).rev() {
			res = res.square();
			if ((exp >> i) & 1) == 1 {
				res.mul_assign(self)
			}
		}
		res
	}

	/// Returns the packed inverse values or zeroes at indices where `self` is zero.
	fn invert_or_zero(self) -> Self;

	/// Interleaves blocks of this packed vector with another packed vector.
	///
	/// The operation can be seen as stacking the two vectors, dividing them into 2x2 matrices of
	/// blocks, where each block is 2^`log_block_len` elements, and transposing the matrices.
	///
	/// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1:
	///     A = [a0, a1, a2, a3, a4, a5, a6, a7]
	///     B = [b0, b1, b2, b3, b4, b5, b6, b7]
	///
	/// The interleaved result is
	///     A' = [a0, a1, b0, b1, a4, a5, b4, b5]
	///     B' = [a2, a3, b2, b3, a6, a7, b6, b7]
	///
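	/// The same example as a runnable sketch, with `a_i = i` and `b_i = 8 + i`:
	///
	/// ```
	/// use binius_field::{BinaryField16b, PackedField, PackedBinaryField8x16b};
	///
	/// let a = PackedBinaryField8x16b::from_scalars([0, 1, 2, 3, 4, 5, 6, 7].map(BinaryField16b::new));
	/// let b = PackedBinaryField8x16b::from_scalars([8, 9, 10, 11, 12, 13, 14, 15].map(BinaryField16b::new));
	/// let (a_prime, b_prime) = a.interleave(b, 1);
	/// assert_eq!(
	///     a_prime,
	///     PackedBinaryField8x16b::from_scalars([0, 1, 8, 9, 4, 5, 12, 13].map(BinaryField16b::new))
	/// );
	/// assert_eq!(
	///     b_prime,
	///     PackedBinaryField8x16b::from_scalars([2, 3, 10, 11, 6, 7, 14, 15].map(BinaryField16b::new))
	/// );
	/// ```
	///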
	/// ## Preconditions
	/// * `log_block_len` must be strictly less than `LOG_WIDTH`.
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);

	/// Unzips interleaved blocks of this packed vector with another packed vector.
	///
	/// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1:
	///     A = [a0, a1, b0, b1, a2, a3, b2, b3]
	///     B = [a4, a5, b4, b5, a6, a7, b6, b7]
	///
	/// The unzipped result is
	///     A' = [a0, a1, a2, a3, a4, a5, a6, a7]
	///     B' = [b0, b1, b2, b3, b4, b5, b6, b7]
	///
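	/// The same example as a runnable sketch, with `a_i = i` and `b_i = 8 + i`:
	///
	/// ```
	/// use binius_field::{BinaryField16b, PackedField, PackedBinaryField8x16b};
	///
	/// let a = PackedBinaryField8x16b::from_scalars([0, 1, 8, 9, 2, 3, 10, 11].map(BinaryField16b::new));
	/// let b = PackedBinaryField8x16b::from_scalars([4, 5, 12, 13, 6, 7, 14, 15].map(BinaryField16b::new));
	/// let (a_prime, b_prime) = a.unzip(b, 1);
	/// assert_eq!(
	///     a_prime,
	///     PackedBinaryField8x16b::from_scalars([0, 1, 2, 3, 4, 5, 6, 7].map(BinaryField16b::new))
	/// );
	/// assert_eq!(
	///     b_prime,
	///     PackedBinaryField8x16b::from_scalars([8, 9, 10, 11, 12, 13, 14, 15].map(BinaryField16b::new))
	/// );
	/// ```
	///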
	/// ## Preconditions
	/// * `log_block_len` must be strictly less than `LOG_WIDTH`.
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);

	/// Spread takes a block of elements within a packed field and repeats them to the full packing
	/// width.
	///
	/// Spread can be seen as an extension of the functionality of [`Self::broadcast`].
	///
	/// ## Examples
	///
	/// ```
	/// use binius_field::{BinaryField16b, PackedField, PackedBinaryField8x16b};
	///
	/// let input =
	///     PackedBinaryField8x16b::from_scalars([0, 1, 2, 3, 4, 5, 6, 7].map(BinaryField16b::new));
	/// assert_eq!(
	///     input.spread(0, 5),
	///     PackedBinaryField8x16b::from_scalars([5, 5, 5, 5, 5, 5, 5, 5].map(BinaryField16b::new))
	/// );
	/// assert_eq!(
	///     input.spread(1, 2),
	///     PackedBinaryField8x16b::from_scalars([4, 4, 4, 4, 5, 5, 5, 5].map(BinaryField16b::new))
	/// );
	/// assert_eq!(
	///     input.spread(2, 1),
	///     PackedBinaryField8x16b::from_scalars([4, 4, 5, 5, 6, 6, 7, 7].map(BinaryField16b::new))
	/// );
	/// assert_eq!(input.spread(3, 0), input);
	/// ```
	///
	/// ## Preconditions
	///
	/// * `log_block_len` must be less than or equal to `LOG_WIDTH`.
	/// * `block_idx` must be less than `2^(Self::LOG_WIDTH - log_block_len)`.
	#[inline]
	fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
		assert!(log_block_len <= Self::LOG_WIDTH);
		assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));

		// Safety: guaranteed by the preconditions.
		unsafe {
			self.spread_unchecked(log_block_len, block_idx)
		}
	}

	/// Unsafe version of [`Self::spread`].
	///
	/// # Safety
	/// The caller must ensure that `log_block_len` is less than or equal to `LOG_WIDTH` and that
	/// `block_idx` is less than `2^(Self::LOG_WIDTH - log_block_len)`.
	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		let block_len = 1 << log_block_len;
		let repeat = 1 << (Self::LOG_WIDTH - log_block_len);

		Self::from_scalars(
			self.iter()
				.skip(block_idx * block_len)
				.take(block_len)
				.flat_map(|elem| iter::repeat_n(elem, repeat)),
		)
	}
}

/// Iterate over scalar values in a packed field slice.
///
/// The iterator skips the first `offset` scalars. This is more efficient than calling `skip` on
/// the iterator returned by [`PackedField::iter_slice`].
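///
/// ## Example
///
/// A usage sketch; the `binius_field::packed` module path is assumed to be public:
///
/// ```
/// use binius_field::{
/// 	BinaryField16b, PackedBinaryField8x16b,
/// 	packed::{iter_packed_slice_with_offset, pack_slice},
/// };
///
/// let scalars = (0..16u16).map(BinaryField16b::new).collect::<Vec<_>>();
/// let packed = pack_slice::<PackedBinaryField8x16b>(&scalars);
/// let tail: Vec<_> = iter_packed_slice_with_offset(&packed, 10).collect();
/// assert_eq!(tail, scalars[10..]);
/// ```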
#[inline]
pub fn iter_packed_slice_with_offset<P: PackedField>(
	packed: &[P],
	offset: usize,
) -> impl Iterator<Item = P::Scalar> + '_ + Send {
	let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
		(&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
	} else {
		(&[], 0)
	};

	P::iter_slice(packed).skip(offset)
}

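/// Returns the scalar at the given index of a packed slice.
///
/// ## Panics
/// Panics if `i` is not less than `P::WIDTH * packed.len()`.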
#[inline(always)]
pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
	assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");

	unsafe { get_packed_slice_unchecked(packed, i) }
}

/// Returns the scalar at the given index without bounds checking.
/// # Safety
/// The caller must ensure that `i` is less than `P::WIDTH * packed.len()`.
#[inline(always)]
pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
	if is_packed_field_indexable::<P>() {
		// Safety:
		//  - We can safely cast the pointer to `P::Scalar` because `P` is `PackedFieldIndexable`
		//  - `i` is guaranteed to be less than `len_packed_slice(packed)`
		unsafe { *(packed.as_ptr() as *const P::Scalar).add(i) }
	} else {
		// Safety:
		// - `i / P::WIDTH` is within the bounds of `packed` if `i` is less than
		//   `len_packed_slice(packed)`
		// - `i % P::WIDTH` is always less than `P::WIDTH`
		unsafe {
			packed
				.get_unchecked(i >> P::LOG_WIDTH)
				.get_unchecked(i % P::WIDTH)
		}
	}
}

#[inline]
pub fn get_packed_slice_checked<P: PackedField>(
	packed: &[P],
	i: usize,
) -> Result<P::Scalar, Error> {
	if i >> P::LOG_WIDTH < packed.len() {
		// Safety: `i` is guaranteed to be less than `len_packed_slice(packed)`
		Ok(unsafe { get_packed_slice_unchecked(packed, i) })
	} else {
		Err(Error::IndexOutOfRange {
			index: i,
			max: len_packed_slice(packed),
		})
	}
}

/// Sets the scalar at the given index without bounds checking.
/// # Safety
/// The caller must ensure that `i` is less than `P::WIDTH * packed.len()`.
#[inline]
pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
	packed: &mut [P],
	i: usize,
	scalar: P::Scalar,
) {
	if is_packed_field_indexable::<P>() {
		// Safety:
		//  - We can safely cast the pointer to `P::Scalar` because `P` is `PackedFieldIndexable`
		//  - `i` is guaranteed to be less than `len_packed_slice(packed)`
		unsafe {
			*(packed.as_mut_ptr() as *mut P::Scalar).add(i) = scalar;
		}
	} else {
		// Safety: if `i` is less than `len_packed_slice(packed)`, then
		// - `i / P::WIDTH` is within the bounds of `packed`
		// - `i % P::WIDTH` is always less than `P::WIDTH`
		unsafe {
			packed
				.get_unchecked_mut(i >> P::LOG_WIDTH)
				.set_unchecked(i % P::WIDTH, scalar)
		}
	}
}

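/// Sets the scalar at the given index of a packed slice.
///
/// ## Panics
/// Panics if `i` is not less than `P::WIDTH * packed.len()`.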
#[inline]
pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
	assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");

	unsafe { set_packed_slice_unchecked(packed, i, scalar) }
}

#[inline]
pub fn set_packed_slice_checked<P: PackedField>(
	packed: &mut [P],
	i: usize,
	scalar: P::Scalar,
) -> Result<(), Error> {
	if i >> P::LOG_WIDTH < packed.len() {
		// Safety: `i` is guaranteed to be less than `len_packed_slice(packed)`
		unsafe { set_packed_slice_unchecked(packed, i, scalar) };
		Ok(())
	} else {
		Err(Error::IndexOutOfRange {
			index: i,
			max: len_packed_slice(packed),
		})
	}
}

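/// Returns the number of scalar elements in a packed slice: `packed.len() * P::WIDTH`.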
#[inline(always)]
pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
	packed.len() << P::LOG_WIDTH
}

/// Construct a packed field element from a function that returns scalar values by index, with a
/// given offset in packed elements. E.g. if `offset` is 2 and `WIDTH` is 4, `f(9)` will be used
/// to set the scalar at index 1 in the packed element.
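///
/// A sketch as a doctest; the `binius_field::packed` module path is assumed to be public:
///
/// ```
/// use binius_field::{BinaryField16b, PackedBinaryField8x16b, PackedField, packed::packed_from_fn_with_offset};
///
/// // With `WIDTH = 8` and `offset = 1`, the element is filled from `f(8)`, ..., `f(15)`.
/// let p: PackedBinaryField8x16b = packed_from_fn_with_offset(1, |i| BinaryField16b::new(i as u16));
/// assert_eq!(p, PackedBinaryField8x16b::from_scalars((8..16u16).map(BinaryField16b::new)));
/// ```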
#[inline]
pub fn packed_from_fn_with_offset<P: PackedField>(
	offset: usize,
	mut f: impl FnMut(usize) -> P::Scalar,
) -> P {
	P::from_fn(|i| f(i + offset * P::WIDTH))
}

/// Multiply a packed field element by a subfield scalar.
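///
/// A small sketch; it assumes `PackedBinaryField8x16b: PackedExtension<BinaryField1b>` and that
/// the `binius_field::packed` module path is public:
///
/// ```
/// use binius_field::{
/// 	BinaryField1b, BinaryField16b, Field, PackedBinaryField8x16b, PackedField,
/// 	packed::mul_by_subfield_scalar,
/// };
///
/// let x = PackedBinaryField8x16b::broadcast(BinaryField16b::new(7));
/// // Multiplying by the subfield identity and by zero behave as expected.
/// assert_eq!(mul_by_subfield_scalar(x, BinaryField1b::ONE), x);
/// assert_eq!(mul_by_subfield_scalar(x, BinaryField1b::ZERO), PackedBinaryField8x16b::zero());
/// ```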
pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
	use crate::underlier::UnderlierType;

	// This is a workaround to avoid making the multiplication slower in certain cases.
	// TODO: implement an efficient strategy to multiply a packed field by a subfield scalar.
	let subfield_bits = FS::Underlier::BITS;
	let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;

	if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
		P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
	} else {
		P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
	}
}

/// Pack a slice of scalars into a vector of packed field elements.
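///
/// If the number of scalars is not a multiple of `P::WIDTH`, the trailing positions of the last
/// packed element are set to zero. A sketch (the public `binius_field::packed` path is assumed):
///
/// ```
/// use binius_field::{
/// 	BinaryField16b, Field, PackedBinaryField8x16b, PackedField, packed::pack_slice,
/// };
///
/// let scalars = (0..10u16).map(BinaryField16b::new).collect::<Vec<_>>();
/// let packed = pack_slice::<PackedBinaryField8x16b>(&scalars);
/// assert_eq!(packed.len(), 2);
/// assert_eq!(packed[1].get(1), BinaryField16b::new(9));
/// assert_eq!(packed[1].get(2), BinaryField16b::ZERO);
/// ```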
pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
	scalars
		.chunks(P::WIDTH)
		.map(|chunk| P::from_scalars(chunk.iter().copied()))
		.collect()
}

/// Copy scalar elements to a slice of packed field elements.
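///
/// The first `src.len()` scalar positions of `dst` are overwritten; the remaining positions are
/// left unchanged. `src.len()` must not exceed `len_packed_slice(dst)`.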
pub fn copy_packed_from_scalars_slice<P: PackedField>(src: &[P::Scalar], dst: &mut [P]) {
	unpack_if_possible_mut(
		dst,
		|scalars| {
			scalars[0..src.len()].copy_from_slice(src);
		},
		|packed| {
			let chunks = src.chunks_exact(P::WIDTH);
			let remainder = chunks.remainder();
			for (chunk, packed) in chunks.zip(packed.iter_mut()) {
				*packed = P::from_scalars(chunk.iter().copied());
			}

			if !remainder.is_empty() {
				// Index of the packed element that holds the partial trailing chunk.
				let offset = src.len() >> P::LOG_WIDTH;
				let packed = &mut packed[offset];
				for (i, scalar) in remainder.iter().enumerate() {
					// Safety: `i` is guaranteed to be less than `P::WIDTH`
					unsafe { packed.set_unchecked(i, *scalar) };
				}
			}
		},
	);
}

/// A slice of packed field elements as a collection of scalars.
#[derive(Clone)]
pub struct PackedSlice<'a, P: PackedField> {
	slice: &'a [P],
	len: usize,
}

impl<'a, P: PackedField> PackedSlice<'a, P> {
	#[inline(always)]
	pub fn new(slice: &'a [P]) -> Self {
		Self {
			slice,
			len: len_packed_slice(slice),
		}
	}

	#[inline(always)]
	pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
		assert!(len <= len_packed_slice(slice));

		Self { slice, len }
	}
}

impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
	#[inline(always)]
	fn len(&self) -> usize {
		self.len
	}

	#[inline(always)]
	unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
		unsafe { get_packed_slice_unchecked(self.slice, index) }
	}
}

/// A mutable slice of packed field elements as a collection of scalars.
pub struct PackedSliceMut<'a, P: PackedField> {
	slice: &'a mut [P],
	len: usize,
}

impl<'a, P: PackedField> PackedSliceMut<'a, P> {
	#[inline(always)]
	pub fn new(slice: &'a mut [P]) -> Self {
		let len = len_packed_slice(slice);
		Self { slice, len }
	}

	#[inline(always)]
	pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
		assert!(len <= len_packed_slice(slice));

		Self { slice, len }
	}
}

impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
	#[inline(always)]
	fn len(&self) -> usize {
		self.len
	}

	#[inline(always)]
	unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
		unsafe { get_packed_slice_unchecked(self.slice, index) }
	}
}

impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
	#[inline(always)]
	unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
		unsafe { set_packed_slice_unchecked(self.slice, index, value) }
	}
}

impl<F: Field> Broadcast<F> for F {
	fn broadcast(scalar: F) -> Self {
		scalar
	}
}

impl<T: TowerFieldArithmetic> MulAlpha for T {
	#[inline]
	fn mul_alpha(self) -> Self {
		<Self as TowerFieldArithmetic>::multiply_alpha(self)
	}
}

impl<F: Field> PackedField for F {
	type Scalar = F;

	const LOG_WIDTH: usize = 0;

	#[inline]
	unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
		*self
	}

	#[inline]
	unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
		*self = scalar;
	}

	#[inline]
	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		iter::once(*self)
	}

	#[inline]
	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
		iter::once(self)
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		slice.iter().copied()
	}

	fn random(rng: impl RngCore) -> Self {
		<Self as Field>::random(rng)
	}

	fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
		panic!("cannot interleave when WIDTH = 1");
	}

	fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
		panic!("cannot unzip when WIDTH = 1");
	}

	fn broadcast(scalar: Self::Scalar) -> Self {
		scalar
	}

	fn square(self) -> Self {
		<Self as Square>::square(self)
	}

	fn invert_or_zero(self) -> Self {
		<Self as InvertOrZero>::invert_or_zero(self)
	}

	#[inline]
	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		f(0)
	}

	#[inline]
	unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
		self
	}
}

/// A helper trait to make the generic bounds shorter.
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}

#[cfg(test)]
mod tests {
	use itertools::Itertools;
	use rand::{
		SeedableRng,
		distributions::{Distribution, Uniform},
		rngs::StdRng,
	};

	use super::*;
	use crate::{
		AESTowerField8b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField128b,
		BinaryField1b, BinaryField2b, BinaryField4b, BinaryField8b, BinaryField16b, BinaryField32b,
		BinaryField64b, BinaryField128b, BinaryField128bPolyval, PackedField,
		arch::{
			byte_sliced::*, packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*,
			packed_32::*, packed_64::*, packed_128::*, packed_256::*, packed_512::*,
			packed_aes_8::*, packed_aes_16::*, packed_aes_32::*, packed_aes_64::*,
			packed_aes_128::*, packed_aes_256::*, packed_aes_512::*, packed_polyval_128::*,
			packed_polyval_256::*, packed_polyval_512::*,
		},
	};

	trait PackedFieldTest {
		fn run<P: PackedField>(&self);
	}

	/// Run the test for all the packed fields defined in this crate.
	fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
		// canonical tower
		test.run::<BinaryField1b>();
		test.run::<BinaryField2b>();
		test.run::<BinaryField4b>();
		test.run::<BinaryField8b>();
		test.run::<BinaryField16b>();
		test.run::<BinaryField32b>();
		test.run::<BinaryField64b>();
		test.run::<BinaryField128b>();

		// packed canonical tower
		test.run::<PackedBinaryField1x1b>();
		test.run::<PackedBinaryField2x1b>();
		test.run::<PackedBinaryField1x2b>();
		test.run::<PackedBinaryField4x1b>();
		test.run::<PackedBinaryField2x2b>();
		test.run::<PackedBinaryField1x4b>();
		test.run::<PackedBinaryField8x1b>();
		test.run::<PackedBinaryField4x2b>();
		test.run::<PackedBinaryField2x4b>();
		test.run::<PackedBinaryField1x8b>();
		test.run::<PackedBinaryField16x1b>();
		test.run::<PackedBinaryField8x2b>();
		test.run::<PackedBinaryField4x4b>();
		test.run::<PackedBinaryField2x8b>();
		test.run::<PackedBinaryField1x16b>();
		test.run::<PackedBinaryField32x1b>();
		test.run::<PackedBinaryField16x2b>();
		test.run::<PackedBinaryField8x4b>();
		test.run::<PackedBinaryField4x8b>();
		test.run::<PackedBinaryField2x16b>();
		test.run::<PackedBinaryField1x32b>();
		test.run::<PackedBinaryField64x1b>();
		test.run::<PackedBinaryField32x2b>();
		test.run::<PackedBinaryField16x4b>();
		test.run::<PackedBinaryField8x8b>();
		test.run::<PackedBinaryField4x16b>();
		test.run::<PackedBinaryField2x32b>();
		test.run::<PackedBinaryField1x64b>();
		test.run::<PackedBinaryField128x1b>();
		test.run::<PackedBinaryField64x2b>();
		test.run::<PackedBinaryField32x4b>();
		test.run::<PackedBinaryField16x8b>();
		test.run::<PackedBinaryField8x16b>();
		test.run::<PackedBinaryField4x32b>();
		test.run::<PackedBinaryField2x64b>();
		test.run::<PackedBinaryField1x128b>();
		test.run::<PackedBinaryField256x1b>();
		test.run::<PackedBinaryField128x2b>();
		test.run::<PackedBinaryField64x4b>();
		test.run::<PackedBinaryField32x8b>();
		test.run::<PackedBinaryField16x16b>();
		test.run::<PackedBinaryField8x32b>();
		test.run::<PackedBinaryField4x64b>();
		test.run::<PackedBinaryField2x128b>();
		test.run::<PackedBinaryField512x1b>();
		test.run::<PackedBinaryField256x2b>();
		test.run::<PackedBinaryField128x4b>();
		test.run::<PackedBinaryField64x8b>();
		test.run::<PackedBinaryField32x16b>();
		test.run::<PackedBinaryField16x32b>();
		test.run::<PackedBinaryField8x64b>();
		test.run::<PackedBinaryField4x128b>();

		// AES tower
		test.run::<AESTowerField8b>();
		test.run::<AESTowerField16b>();
		test.run::<AESTowerField32b>();
		test.run::<AESTowerField64b>();
		test.run::<AESTowerField128b>();

		// packed AES tower
		test.run::<PackedAESBinaryField1x8b>();
		test.run::<PackedAESBinaryField2x8b>();
		test.run::<PackedAESBinaryField1x16b>();
		test.run::<PackedAESBinaryField4x8b>();
		test.run::<PackedAESBinaryField2x16b>();
		test.run::<PackedAESBinaryField1x32b>();
		test.run::<PackedAESBinaryField8x8b>();
		test.run::<PackedAESBinaryField4x16b>();
		test.run::<PackedAESBinaryField2x32b>();
		test.run::<PackedAESBinaryField1x64b>();
		test.run::<PackedAESBinaryField16x8b>();
		test.run::<PackedAESBinaryField8x16b>();
		test.run::<PackedAESBinaryField4x32b>();
		test.run::<PackedAESBinaryField2x64b>();
		test.run::<PackedAESBinaryField1x128b>();
		test.run::<PackedAESBinaryField32x8b>();
		test.run::<PackedAESBinaryField16x16b>();
		test.run::<PackedAESBinaryField8x32b>();
		test.run::<PackedAESBinaryField4x64b>();
		test.run::<PackedAESBinaryField2x128b>();
		test.run::<PackedAESBinaryField64x8b>();
		test.run::<PackedAESBinaryField32x16b>();
		test.run::<PackedAESBinaryField16x32b>();
		test.run::<PackedAESBinaryField8x64b>();
		test.run::<PackedAESBinaryField4x128b>();

		// Byte-sliced AES tower
		test.run::<ByteSlicedAES16x128b>();
		test.run::<ByteSlicedAES16x64b>();
		test.run::<ByteSlicedAES2x16x64b>();
		test.run::<ByteSlicedAES16x32b>();
		test.run::<ByteSlicedAES4x16x32b>();
		test.run::<ByteSlicedAES16x16b>();
		test.run::<ByteSlicedAES8x16x16b>();
		test.run::<ByteSlicedAES16x8b>();
		test.run::<ByteSlicedAES16x16x8b>();

		test.run::<ByteSliced16x128x1b>();
		test.run::<ByteSliced8x128x1b>();
		test.run::<ByteSliced4x128x1b>();
		test.run::<ByteSliced2x128x1b>();
		test.run::<ByteSliced1x128x1b>();

		test.run::<ByteSlicedAES32x128b>();
		test.run::<ByteSlicedAES32x64b>();
		test.run::<ByteSlicedAES2x32x64b>();
		test.run::<ByteSlicedAES32x32b>();
		test.run::<ByteSlicedAES4x32x32b>();
		test.run::<ByteSlicedAES32x16b>();
		test.run::<ByteSlicedAES8x32x16b>();
		test.run::<ByteSlicedAES32x8b>();
		test.run::<ByteSlicedAES16x32x8b>();

		test.run::<ByteSliced16x256x1b>();
		test.run::<ByteSliced8x256x1b>();
		test.run::<ByteSliced4x256x1b>();
		test.run::<ByteSliced2x256x1b>();
		test.run::<ByteSliced1x256x1b>();

		test.run::<ByteSlicedAES64x128b>();
		test.run::<ByteSlicedAES64x64b>();
		test.run::<ByteSlicedAES2x64x64b>();
		test.run::<ByteSlicedAES64x32b>();
		test.run::<ByteSlicedAES4x64x32b>();
		test.run::<ByteSlicedAES64x16b>();
		test.run::<ByteSlicedAES8x64x16b>();
		test.run::<ByteSlicedAES64x8b>();
		test.run::<ByteSlicedAES16x64x8b>();

		test.run::<ByteSliced16x512x1b>();
		test.run::<ByteSliced8x512x1b>();
		test.run::<ByteSliced4x512x1b>();
		test.run::<ByteSliced2x512x1b>();
		test.run::<ByteSliced1x512x1b>();

		// polyval tower
		test.run::<BinaryField128bPolyval>();

		// packed polyval tower
		test.run::<PackedBinaryPolyval1x128b>();
		test.run::<PackedBinaryPolyval2x128b>();
		test.run::<PackedBinaryPolyval4x128b>();
	}

	fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
		let packed = P::random(&mut rng);
		let mut iter = packed.into_iter();
		for i in 0..P::WIDTH {
			assert_eq!(packed.get(i), iter.next().unwrap());
		}
		assert!(iter.next().is_none());
	}

	fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
		let packed = P::random(&mut rng);
		let mut iter = packed.iter();
		for i in 0..P::WIDTH {
			assert_eq!(packed.get(i), iter.next().unwrap());
		}
		assert!(iter.next().is_none());
	}

	fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
		for len in [0, 1, 5] {
			let packed = std::iter::repeat_with(|| P::random(&mut rng))
				.take(len)
				.collect::<Vec<_>>();

			let elements_count = len * P::WIDTH;
			for offset in [
				0,
				1,
				Uniform::new(0, elements_count.max(1)).sample(&mut rng),
				elements_count.saturating_sub(1),
				elements_count,
			] {
				let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
				let expected = (offset..elements_count)
					.map(|i| get_packed_slice(&packed, i))
					.collect::<Vec<_>>();

				assert_eq!(actual, expected);
			}
		}
	}

	struct PackedFieldIterationTest;

	impl PackedFieldTest for PackedFieldIterationTest {
		fn run<P: PackedField>(&self) {
			let mut rng = StdRng::seed_from_u64(0);

			check_value_iteration::<P>(&mut rng);
			check_ref_iteration::<P>(&mut rng);
			check_slice_iteration::<P>(&mut rng);
		}
	}

	#[test]
	fn test_iteration() {
		run_for_all_packed_fields(&PackedFieldIterationTest);
	}

	fn check_copy_from_scalars<P: PackedField>(mut rng: impl RngCore) {
		let scalars = (0..100)
			.map(|_| <<P as PackedField>::Scalar as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		let mut packed_copy = vec![P::zero(); 100];

		for len in [0, 2, 4, 8, 12, 16] {
			copy_packed_from_scalars_slice(&scalars[0..len], &mut packed_copy);

			for (i, &scalar) in scalars[0..len].iter().enumerate() {
				assert_eq!(get_packed_slice(&packed_copy, i), scalar);
			}
			for i in len..100 {
				assert_eq!(get_packed_slice(&packed_copy, i), P::Scalar::ZERO);
			}
		}
	}

	#[test]
	fn test_copy_from_scalars() {
		let mut rng = StdRng::seed_from_u64(0);

		check_copy_from_scalars::<PackedBinaryField16x8b>(&mut rng);
		check_copy_from_scalars::<PackedBinaryField32x4b>(&mut rng);
	}

	fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
		assert_eq!(collection.len(), expected.len());

		for (i, v) in expected.iter().enumerate() {
			assert_eq!(&collection.get(i), v);
			assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
		}
	}

	fn check_collection_get_set<F: Field>(
		collection: &mut impl RandomAccessSequenceMut<F>,
		r#gen: &mut impl FnMut() -> F,
	) {
		for i in 0..collection.len() {
			let value = r#gen();
			collection.set(i, value);
			assert_eq!(collection.get(i), value);
			assert_eq!(unsafe { collection.get_unchecked(i) }, value);
		}
	}

	#[test]
	fn check_packed_slice() {
		let slice: &[PackedBinaryField16x8b] = &[];
		let packed_slice = PackedSlice::new(slice);
		check_collection(&packed_slice, &[]);
		let packed_slice = PackedSlice::new_with_len(slice, 0);
		check_collection(&packed_slice, &[]);

		let mut rng = StdRng::seed_from_u64(0);
		let slice: &[PackedBinaryField16x8b] = &[
			PackedBinaryField16x8b::random(&mut rng),
			PackedBinaryField16x8b::random(&mut rng),
		];
		let packed_slice = PackedSlice::new(slice);
		check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

		let packed_slice = PackedSlice::new_with_len(slice, 3);
		check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
	}

	#[test]
	fn check_packed_slice_mut() {
		let mut rng = StdRng::seed_from_u64(0);
		let mut r#gen = || <BinaryField8b as Field>::random(&mut rng);

		let slice: &mut [PackedBinaryField16x8b] = &mut [];
		let packed_slice = PackedSliceMut::new(slice);
		check_collection(&packed_slice, &[]);
		let packed_slice = PackedSliceMut::new_with_len(slice, 0);
		check_collection(&packed_slice, &[]);

		let mut rng = StdRng::seed_from_u64(0);
		let slice: &mut [PackedBinaryField16x8b] = &mut [
			PackedBinaryField16x8b::random(&mut rng),
			PackedBinaryField16x8b::random(&mut rng),
		];
		let values = PackedField::iter_slice(slice).collect_vec();
		let mut packed_slice = PackedSliceMut::new(slice);
		check_collection(&packed_slice, &values);
		check_collection_get_set(&mut packed_slice, &mut r#gen);

		let values = PackedField::iter_slice(slice).collect_vec();
		let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
		check_collection(&packed_slice, &values[..3]);
		check_collection_get_set(&mut packed_slice, &mut r#gen);
	}
}