Skip to main content

binius_field/
packed.rs

1// Copyright 2023-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4//! Traits for packed field elements which support SIMD implementations.
5//!
6//! Interfaces are derived from [`plonky2`](https://github.com/mir-protocol/plonky2).
7
8use std::{
9	fmt::Debug,
10	iter,
11	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
12};
13
14use binius_utils::{
15	iter::IterExtensions,
16	random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
17};
18use bytemuck::Zeroable;
19
20use super::{PackedExtension, Random, arithmetic_traits::Square};
21use crate::{BinaryField, Divisible, Field, Maskable, WideMul, field::FieldOps};
22
23/// A packed field represents a vector of underlying field elements.
24///
25/// Arithmetic operations on packed field elements can be accelerated with SIMD CPU instructions.
26/// The vector width is a constant, `WIDTH`. This trait requires that the width must be a power of
27/// two.
28pub trait PackedField:
29	Default
30	+ Debug
31	+ Clone
32	+ Copy
33	+ Eq
34	+ Sized
35	+ FieldOps
36	+ Add<Self::Scalar, Output = Self>
37	+ Sub<Self::Scalar, Output = Self>
38	+ Mul<Self::Scalar, Output = Self>
39	+ AddAssign<Self::Scalar>
40	+ SubAssign<Self::Scalar>
41	+ MulAssign<Self::Scalar>
42	+ Send
43	+ Sync
44	+ Zeroable
45	+ Random
46	+ WideMul<Output: Debug + Send + Sync + 'static>
47	+ 'static
48	// A packed field divides into its `WIDTH` scalars. Scalar element access (`get`/`set` and
49	// their `_unchecked` variants), broadcast, and the scalar iterators are all provided by this
50	// supertrait.
51	+ Divisible<Self::Scalar>
52	// A packed field supports branchless per-lane masking over its scalars.
53	+ Maskable<Self::Scalar>
54{
55	/// Base-2 logarithm of the number of field elements packed into one packed element.
56	///
57	/// This is the number of scalars the packed field divides into, i.e. its `Divisible` log-count.
58	const LOG_WIDTH: usize = <Self as Divisible<Self::Scalar>>::LOG_N;
59
60	/// The number of field elements packed into one packed element.
61	///
62	/// WIDTH is guaranteed to equal 2^LOG_WIDTH.
63	const WIDTH: usize = 1 << Self::LOG_WIDTH;
64
65	#[inline]
66	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
67		(0..Self::WIDTH).map_skippable(move |i|
68			// Safety: `i` is always less than `WIDTH`
69			unsafe { self.get_unchecked(i) })
70	}
71
72	#[inline]
73	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
74		(0..Self::WIDTH).map_skippable(move |i|
75			// Safety: `i` is always less than `WIDTH`
76			unsafe { self.get_unchecked(i) })
77	}
78
79	#[inline]
80	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
81		slice.iter().flat_map(Self::iter)
82	}
83
84	/// Initialize zero position with `scalar`, set other elements to zero.
85	#[inline(always)]
86	fn set_single(scalar: Self::Scalar) -> Self {
87		let mut result = Self::default();
88		result.set(0, scalar);
89		result
90	}
91
92	/// Construct a packed field element from a function that returns scalar values by index.
93	fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
94
95	/// Creates a packed field from a fallible function applied to each index.
96	fn try_from_fn<E>(mut f: impl FnMut(usize) -> Result<Self::Scalar, E>) -> Result<Self, E> {
97		let mut result = Self::default();
98		for i in 0..Self::WIDTH {
99			let scalar = f(i)?;
100			unsafe {
101				result.set_unchecked(i, scalar);
102			};
103		}
104		Ok(result)
105	}
106
107	/// Construct a packed field element from a sequence of scalars.
108	///
109	/// If the number of values in the sequence is less than the packing width, the remaining
110	/// elements are set to zero. If greater than the packing width, the excess elements are
111	/// ignored.
112	#[inline]
113	fn from_scalars(values: impl IntoIterator<Item = Self::Scalar>) -> Self {
114		let mut result = Self::default();
115		for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
116			result.set(i, val);
117		}
118		result
119	}
120
121	/// Returns the value to the power `exp`.
122	fn pow(self, exp: u64) -> Self {
123		let mut res = Self::one();
124		for i in (0..64).rev() {
125			res = Square::square(res);
126			if ((exp >> i) & 1) == 1 {
127				res.mul_assign(self)
128			}
129		}
130		res
131	}
132
133	/// Interleaves blocks of this packed vector with another packed vector.
134	///
135	/// The operation can be seen as stacking the two vectors, dividing them into 2x2 matrices of
136	/// blocks, where each block is 2^`log_block_width` elements, and transposing the matrices.
137	///
138	/// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1:
139	///     A = [a0, a1, a2, a3, a4, a5, a6, a7]
140	///     B = [b0, b1, b2, b3, b4, b5, b6, b7]
141	///
142	/// The interleaved result is
143	///     A' = [a0, a1, b0, b1, a4, a5, b4, b5]
144	///     B' = [a2, a3, b2, b3, a6, a7, b6, b7]
145	///
146	/// ## Preconditions
147	/// * `log_block_len` must be strictly less than `LOG_WIDTH`.
148	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
149
150	/// Unzips interleaved blocks of this packed vector with another packed vector.
151	///
152	/// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1:
153	///    A = [a0, a1, b0, b1, a2, a3, b2, b3]
154	///    B = [a4, a5, b4, b5, a6, a7, b6, b7]
155	///
156	/// The transposed result is
157	///    A' = [a0, a1, a2, a3, a4, a5, a6, a7]
158	///    B' = [b0, b1, b2, b3, b4, b5, b6, b7]
159	///
160	/// ## Preconditions
161	/// * `log_block_len` must be strictly less than `LOG_WIDTH`.
162	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
163
164	/// Spread takes a block of elements within a packed field and repeats them to the full packing
165	/// width.
166	///
167	/// Spread can be seen as an extension of the functionality of [`Divisible::broadcast`].
168	///
169	/// ## Examples
170	///
171	/// ```
172	/// use binius_field::{BinaryField1b, PackedField, PackedBinaryField8x1b};
173	///
174	/// let input =
175	///     PackedBinaryField8x1b::from_scalars([0, 1, 0, 1, 0, 1, 0, 1].map(BinaryField1b::from));
176	/// assert_eq!(
177	///     input.spread(0, 1),
178	///     PackedBinaryField8x1b::from_scalars([1, 1, 1, 1, 1, 1, 1, 1].map(BinaryField1b::from))
179	/// );
180	/// assert_eq!(
181	///     input.spread(1, 0),
182	///     PackedBinaryField8x1b::from_scalars([0, 0, 0, 0, 1, 1, 1, 1].map(BinaryField1b::from))
183	/// );
184	/// assert_eq!(
185	///     input.spread(2, 0),
186	///     PackedBinaryField8x1b::from_scalars([0, 0, 1, 1, 0, 0, 1, 1].map(BinaryField1b::from))
187	/// );
188	/// assert_eq!(input.spread(3, 0), input);
189	/// ```
190	///
191	/// ## Preconditions
192	///
193	/// * `log_block_len` must be less than or equal to `LOG_WIDTH`.
194	/// * `block_idx` must be less than `2^(Self::LOG_WIDTH - log_block_len)`.
195	#[inline]
196	fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
197		assert!(log_block_len <= Self::LOG_WIDTH);
198		assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
199
200		// Safety: is guaranteed by the preconditions.
201		unsafe { self.spread_unchecked(log_block_len, block_idx) }
202	}
203
204	/// Unsafe version of [`Self::spread`].
205	///
206	/// # Safety
207	/// The caller must ensure that `log_block_len` is less than or equal to `LOG_WIDTH` and
208	/// `block_idx` is less than `2^(Self::LOG_WIDTH - log_block_len)`.
209	#[inline]
210	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
211		let block_len = 1 << log_block_len;
212		let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
213
214		Self::from_scalars(
215			self.iter()
216				.skip(block_idx * block_len)
217				.take(block_len)
218				.flat_map(|elem| iter::repeat_n(elem, repeat)),
219		)
220	}
221}
222
223/// Iterate over scalar values in a packed field slice.
224///
225/// The iterator skips the first `offset` elements. This is more efficient than skipping elements of
226/// the iterator returned.
227#[inline]
228pub fn iter_packed_slice_with_offset<P: PackedField>(
229	packed: &[P],
230	offset: usize,
231) -> impl Iterator<Item = P::Scalar> + '_ + Send {
232	let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
233		(&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
234	} else {
235		(&[], 0)
236	};
237
238	P::iter_slice(packed).skip(offset)
239}
240
241#[inline(always)]
242pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
243	assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
244
245	unsafe { get_packed_slice_unchecked(packed, i) }
246}
247
248/// Returns the scalar at the given index without bounds checking.
249/// # Safety
250/// The caller must ensure that `i` is less than `P::WIDTH * packed.len()`.
251#[inline(always)]
252pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
253	// TODO: Consider putting a get_in_slice method on Divisible
254
255	// Safety:
256	// - `i / P::WIDTH` is within the bounds of `packed` if `i` is less than
257	//   `len_packed_slice(packed)`
258	// - `i % P::WIDTH` is always less than `P::WIDTH
259	unsafe {
260		packed
261			.get_unchecked(i >> P::LOG_WIDTH)
262			.get_unchecked(i % P::WIDTH)
263	}
264}
265
266/// Sets the scalar at the given index without bounds checking.
267/// # Safety
268/// The caller must ensure that `i` is less than `P::WIDTH * packed.len()`.
269#[inline]
270pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
271	packed: &mut [P],
272	i: usize,
273	scalar: P::Scalar,
274) {
275	// TODO: Consider putting a set_in_slice method on Divisible
276
277	// Safety: if `i` is less than `len_packed_slice(packed)`, then
278	// - `i / P::WIDTH` is within the bounds of `packed`
279	// - `i % P::WIDTH` is always less than `P::WIDTH
280	unsafe {
281		packed
282			.get_unchecked_mut(i >> P::LOG_WIDTH)
283			.set_unchecked(i % P::WIDTH, scalar)
284	}
285}
286
287#[inline]
288pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
289	assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
290
291	unsafe { set_packed_slice_unchecked(packed, i, scalar) }
292}
293
294#[inline(always)]
295pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
296	packed.len() << P::LOG_WIDTH
297}
298
299/// Construct a packed field element from a function that returns scalar values by index with the
300/// given offset in packed elements. E.g. if `offset` is 2, and `WIDTH` is 4, `f(9)` will be used
301/// to set the scalar at index 1 in the packed element.
302#[inline]
303pub fn packed_from_fn_with_offset<P: PackedField>(
304	offset: usize,
305	mut f: impl FnMut(usize) -> P::Scalar,
306) -> P {
307	P::from_fn(|i| f(i + offset * P::WIDTH))
308}
309
310/// Multiply packed field element by a subfield scalar.
311pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
312	P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
313}
314
315/// Pack a slice of scalars into a vector of packed field elements.
316pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
317	scalars
318		.chunks(P::WIDTH)
319		.map(|chunk| P::from_scalars(chunk.iter().copied()))
320		.collect()
321}
322
323/// A slice of packed field elements as a collection of scalars.
324#[derive(Clone)]
325pub struct PackedSlice<'a, P: PackedField> {
326	slice: &'a [P],
327	len: usize,
328}
329
330impl<'a, P: PackedField> PackedSlice<'a, P> {
331	#[inline(always)]
332	pub const fn new(slice: &'a [P]) -> Self {
333		Self {
334			slice,
335			len: len_packed_slice(slice),
336		}
337	}
338
339	#[inline(always)]
340	pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
341		assert!(len <= len_packed_slice(slice));
342
343		Self { slice, len }
344	}
345}
346
347impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
348	#[inline(always)]
349	fn len(&self) -> usize {
350		self.len
351	}
352
353	#[inline(always)]
354	unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
355		unsafe { get_packed_slice_unchecked(self.slice, index) }
356	}
357}
358
359/// A mutable slice of packed field elements as a collection of scalars.
360pub struct PackedSliceMut<'a, P: PackedField> {
361	slice: &'a mut [P],
362	len: usize,
363}
364
365impl<'a, P: PackedField> PackedSliceMut<'a, P> {
366	#[inline(always)]
367	pub const fn new(slice: &'a mut [P]) -> Self {
368		let len = len_packed_slice(slice);
369		Self { slice, len }
370	}
371
372	#[inline(always)]
373	pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
374		assert!(len <= len_packed_slice(slice));
375
376		Self { slice, len }
377	}
378}
379
380impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
381	#[inline(always)]
382	fn len(&self) -> usize {
383		self.len
384	}
385
386	#[inline(always)]
387	unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
388		unsafe { get_packed_slice_unchecked(self.slice, index) }
389	}
390}
391impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
392	#[inline(always)]
393	unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
394		unsafe { set_packed_slice_unchecked(self.slice, index, value) }
395	}
396}
397
398impl<F: Field> PackedField for F {
399	// LOG_WIDTH defaults to `<Self as Divisible<Self>>::LOG_N`, which is 0 for a scalar field.
400	// Scalar element access (`get_unchecked`/`set_unchecked`) is provided by the reflexive
401	// `Divisible<Self>` impl.
402
403	#[inline]
404	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
405		iter::once(*self)
406	}
407
408	#[inline]
409	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
410		iter::once(self)
411	}
412
413	#[inline]
414	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
415		slice.iter().copied()
416	}
417
418	fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
419		panic!("cannot interleave when WIDTH = 1");
420	}
421
422	fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
423		panic!("cannot transpose when WIDTH = 1");
424	}
425
426	#[inline]
427	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
428		f(0)
429	}
430
431	#[inline]
432	unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
433		self
434	}
435}
436
437/// A helper trait to make the generic bounds shorter
438pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}
439
440impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
441
#[cfg(test)]
mod tests {
	use itertools::Itertools;
	use rand::prelude::*;

	use super::*;
	use crate::{
		AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedAESBinaryField1x8b,
		PackedAESBinaryField16x8b, PackedAESBinaryField32x8b, PackedAESBinaryField64x8b,
		PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b, PackedBinaryField8x1b,
		PackedBinaryField16x1b, PackedBinaryField32x1b, PackedBinaryField64x1b,
		PackedBinaryField128x1b, PackedBinaryField256x1b, PackedBinaryField512x1b,
		PackedBinaryGhash1x128b, PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
	};

	trait PackedFieldTest {
		fn run<P: PackedField>(&self);
	}

	/// Run the test for all the packed fields defined in this crate.
	fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
		// B1
		test.run::<BinaryField1b>();
		test.run::<PackedBinaryField1x1b>();
		test.run::<PackedBinaryField2x1b>();
		test.run::<PackedBinaryField4x1b>();
		test.run::<PackedBinaryField8x1b>();
		test.run::<PackedBinaryField16x1b>();
		test.run::<PackedBinaryField32x1b>();
		test.run::<PackedBinaryField64x1b>();
		test.run::<PackedBinaryField128x1b>();
		test.run::<PackedBinaryField256x1b>();
		test.run::<PackedBinaryField512x1b>();

		// AES
		test.run::<AESTowerField8b>();
		test.run::<PackedAESBinaryField1x8b>();
		test.run::<PackedAESBinaryField16x8b>();
		test.run::<PackedAESBinaryField32x8b>();
		test.run::<PackedAESBinaryField64x8b>();

		// GHASH
		test.run::<BinaryField128bGhash>();
		test.run::<PackedBinaryGhash1x128b>();
		test.run::<PackedBinaryGhash2x128b>();
		test.run::<PackedBinaryGhash4x128b>();
	}

	/// By-value iteration (`into_iter`) yields exactly the lanes in index order.
	fn check_value_iteration<P: PackedField>(mut rng: impl Rng) {
		let packed = P::random(&mut rng);
		let mut iter = packed.into_iter();
		for i in 0..P::WIDTH {
			assert_eq!(packed.get(i), iter.next().unwrap());
		}
		assert!(iter.next().is_none());
	}

	/// By-reference iteration (`iter`) yields exactly the lanes in index order.
	fn check_ref_iteration<P: PackedField>(mut rng: impl Rng) {
		let packed = P::random(&mut rng);
		let mut iter = packed.iter();
		for i in 0..P::WIDTH {
			assert_eq!(packed.get(i), iter.next().unwrap());
		}
		assert!(iter.next().is_none());
	}

	fn check_slice_iteration<P: PackedField>(mut rng: impl Rng) {
		for len in [0, 1, 5] {
			let packed = std::iter::repeat_with(|| P::random(&mut rng))
				.take(len)
				.collect::<Vec<_>>();

			let elements_count = len * P::WIDTH;
			for offset in [
				0,
				1,
				rng.random_range(0..elements_count.max(1)),
				elements_count.saturating_sub(1),
				elements_count,
			] {
				let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
				let expected = (offset..elements_count)
					.map(|i| get_packed_slice(&packed, i))
					.collect::<Vec<_>>();

				assert_eq!(actual, expected);
			}
		}
	}

	struct PackedFieldIterationTest;

	impl PackedFieldTest for PackedFieldIterationTest {
		fn run<P: PackedField>(&self) {
			let mut rng = StdRng::seed_from_u64(0);

			check_value_iteration::<P>(&mut rng);
			check_ref_iteration::<P>(&mut rng);
			check_slice_iteration::<P>(&mut rng);
		}
	}

	#[test]
	fn test_iteration() {
		run_for_all_packed_fields(&PackedFieldIterationTest);
	}

	/// `pow` agrees with repeated multiplication for small exponents. It also honors the
	/// highest exponent bits, checked via `x^(a + b) == x^a * x^b` with non-overlapping bits.
	fn check_pow<P: PackedField>(mut rng: impl Rng) {
		let base = P::random(&mut rng);

		let mut expected = P::one();
		for exp in 0..20u64 {
			assert_eq!(PackedField::pow(base, exp), expected);
			expected.mul_assign(base);
		}

		let high = 1u64 << 63;
		let mut product = PackedField::pow(base, high);
		product.mul_assign(PackedField::pow(base, 3));
		assert_eq!(PackedField::pow(base, high | 3), product);
	}

	/// `spread` fills lane `i` with scalar `i >> (LOG_WIDTH - log_block_len)` of the selected
	/// block, for every valid block size and block index.
	fn check_spread<P: PackedField>(mut rng: impl Rng) {
		let packed = P::random(&mut rng);
		for log_block_len in 0..=P::LOG_WIDTH {
			let block_len = 1 << log_block_len;
			let log_repeat = P::LOG_WIDTH - log_block_len;
			for block_idx in 0..(1 << log_repeat) {
				let spread = packed.spread(log_block_len, block_idx);
				for i in 0..P::WIDTH {
					let src = block_idx * block_len + (i >> log_repeat);
					assert_eq!(spread.get(i), packed.get(src));
				}
			}
		}
	}

	/// `pack_slice` round-trips scalar prefixes of various lengths, including partial final
	/// elements.
	fn check_pack_slice<P: PackedField>(mut rng: impl Rng) {
		let packed = std::iter::repeat_with(|| P::random(&mut rng))
			.take(3)
			.collect::<Vec<_>>();
		let scalars = P::iter_slice(&packed).collect::<Vec<_>>();

		for len in [0, 1, P::WIDTH, scalars.len() - 1, scalars.len()] {
			let repacked = pack_slice::<P>(&scalars[..len]);
			assert_eq!(repacked.len(), len.div_ceil(P::WIDTH));
			for (i, expected) in scalars[..len].iter().enumerate() {
				assert_eq!(&get_packed_slice(&repacked, i), expected);
			}
		}
	}

	struct PackedFieldOpsTest;

	impl PackedFieldTest for PackedFieldOpsTest {
		fn run<P: PackedField>(&self) {
			let mut rng = StdRng::seed_from_u64(0);

			check_pow::<P>(&mut rng);
			check_spread::<P>(&mut rng);
			check_pack_slice::<P>(&mut rng);
		}
	}

	#[test]
	fn test_pow_spread_and_pack() {
		run_for_all_packed_fields(&PackedFieldOpsTest);
	}

	fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
		assert_eq!(collection.len(), expected.len());

		for (i, v) in expected.iter().enumerate() {
			assert_eq!(&collection.get(i), v);
			assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
		}
	}

	fn check_collection_get_set<F: Field>(
		collection: &mut impl RandomAccessSequenceMut<F>,
		random: &mut impl FnMut() -> F,
	) {
		for i in 0..collection.len() {
			let value = random();
			collection.set(i, value);
			assert_eq!(collection.get(i), value);
			assert_eq!(unsafe { collection.get_unchecked(i) }, value);
		}
	}

	#[test]
	fn check_packed_slice() {
		let slice: &[PackedAESBinaryField16x8b] = &[];
		let packed_slice = PackedSlice::new(slice);
		check_collection(&packed_slice, &[]);
		let packed_slice = PackedSlice::new_with_len(slice, 0);
		check_collection(&packed_slice, &[]);

		let mut rng = StdRng::seed_from_u64(0);
		let slice: &[PackedAESBinaryField16x8b] = &[
			PackedAESBinaryField16x8b::random(&mut rng),
			PackedAESBinaryField16x8b::random(&mut rng),
		];
		let packed_slice = PackedSlice::new(slice);
		check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

		let packed_slice = PackedSlice::new_with_len(slice, 3);
		check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
	}

	#[test]
	fn check_packed_slice_mut() {
		// Separate seed from the slice contents so the written values are an independent stream.
		let mut value_rng = StdRng::seed_from_u64(1);
		let mut random = || AESTowerField8b::random(&mut value_rng);

		let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
		let packed_slice = PackedSliceMut::new(slice);
		check_collection(&packed_slice, &[]);
		let packed_slice = PackedSliceMut::new_with_len(slice, 0);
		check_collection(&packed_slice, &[]);

		let mut rng = StdRng::seed_from_u64(0);
		let slice: &mut [PackedAESBinaryField16x8b] = &mut [
			PackedAESBinaryField16x8b::random(&mut rng),
			PackedAESBinaryField16x8b::random(&mut rng),
		];
		let values = PackedField::iter_slice(slice).collect_vec();
		let mut packed_slice = PackedSliceMut::new(slice);
		check_collection(&packed_slice, &values);
		check_collection_get_set(&mut packed_slice, &mut random);

		let values = PackedField::iter_slice(slice).collect_vec();
		let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
		check_collection(&packed_slice, &values[..3]);
		check_collection_get_set(&mut packed_slice, &mut random);
	}
}