Skip to main content

binius_field/arch/portable/
packed.rs

1// Copyright 2024-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
// This is because derive(bytemuck::TransparentWrapper) adds some type constraints to
// PackedPrimitiveType in addition to the type constraints we define. Even more annoyingly, the
// allow attribute has to be added to the module; it doesn't work to add it to the struct
// definition.
8#![allow(clippy::multiple_bound_locations)]
9
10use std::{
11	fmt::Debug,
12	iter::{Product, Sum},
13	marker::PhantomData,
14	ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
15};
16
17use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
18use bytemuck::{Pod, TransparentWrapper, Zeroable};
19use rand::{
20	Rng,
21	distr::{Distribution, StandardUniform},
22};
23
24use crate::{
25	BinaryField, Divisible, ExtensionField, Field, PackedField, WideMul,
26	arithmetic_traits::{InvertOrZero, Square},
27	field::FieldOps,
28	underlier::{NumCast, UnderlierType, UnderlierWithBitOps, WithUnderlier},
29};
30
/// A packed vector of binary field scalars stored in a single underlier value.
///
/// The type is a `#[repr(transparent)]` wrapper around `U`. `Scalar` only appears through a
/// zero-sized `PhantomData` marker and names the field in which the bits of `U` are
/// interpreted.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	// Raw underlier bits holding all packed scalars.
	pub U,
	// Zero-sized marker tying the value to its scalar type.
	pub PhantomData<Scalar>,
);
38
impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
	/// Number of scalars packed into one underlier.
	///
	/// Evaluating this constant fails unless `U::BITS` is a multiple of `Scalar::N_BITS`.
	pub const WIDTH: usize = {
		assert!(U::BITS.is_multiple_of(Scalar::N_BITS));

		U::BITS / Scalar::N_BITS
	};

	/// Base-2 logarithm of [`Self::WIDTH`].
	///
	/// Evaluating this constant fails unless `WIDTH` is an exact power of two.
	pub const LOG_WIDTH: usize = {
		let result = Self::WIDTH.ilog2();

		// `ilog2` rounds down, so confirm that the logarithm is exact.
		assert!(2usize.pow(result) == Self::WIDTH);

		result as usize
	};

	/// Wraps a raw underlier value without modifying its bits.
	#[inline]
	pub const fn from_underlier(val: U) -> Self {
		Self(val, PhantomData)
	}

	/// Returns the wrapped underlier value.
	#[inline]
	pub const fn to_underlier(self) -> U {
		self.0
	}
}
64
65impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
66	PackedPrimitiveType<U, Scalar>
67{
68	#[inline]
69	pub fn broadcast(scalar: Scalar) -> Self {
70		U::broadcast_subvalue(scalar.to_underlier()).into()
71	}
72}
73
// SAFETY: `PackedPrimitiveType` is `#[repr(transparent)]` over `U` (its only other field is a
// zero-sized `PhantomData`), and every conversion below is delegated to the
// `bytemuck::TransparentWrapper` implementation derived for the struct.
// NOTE(review): the precise safety contract of `WithUnderlier` is defined on the trait, outside
// this file - confirm it requires nothing beyond layout compatibility.
unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
	for PackedPrimitiveType<U, Scalar>
{
	type Underlier = U;

	/// Unwraps the packed value into its underlier.
	#[inline(always)]
	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	/// Reinterprets a shared reference to the packed value as a reference to its underlier.
	#[inline(always)]
	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	/// Reinterprets a mutable reference to the packed value as a reference to its underlier.
	#[inline(always)]
	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	/// Reinterprets a slice of packed values as a slice of underliers.
	#[inline(always)]
	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	/// Reinterprets a mutable slice of packed values as a mutable slice of underliers.
	#[inline(always)]
	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	/// Wraps an underlier into a packed value.
	#[inline(always)]
	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	/// Reinterprets a shared reference to an underlier as a reference to a packed value.
	#[inline(always)]
	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	/// Reinterprets a mutable reference to an underlier as a reference to a packed value.
	#[inline(always)]
	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	/// Reinterprets a slice of underliers as a slice of packed values.
	#[inline(always)]
	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	/// Reinterprets a mutable slice of underliers as a mutable slice of packed values.
	#[inline(always)]
	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}
129
130// Note: this bound is deliberately phrased in terms of `Divisible` rather than `Self:
131// PackedField`. `PackedField` now has `WideMul<Output: Debug>` as a parent trait, and the trivial
132// `WideMul` impl sets `Output = Self`, so a `Self: PackedField` bound here would make `Debug`
133// depend on itself and overflow trait resolution.
134impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Debug
135	for PackedPrimitiveType<U, Scalar>
136{
137	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138		let width = checked_int_div(U::BITS, Scalar::N_BITS);
139		let values_str = (0..width)
140			// Safety: `i` ranges over `0..width`, the number of scalars packed in `U`.
141			.map(|i| Scalar::from_underlier(unsafe { self.0.get_subvalue(i) }))
142			.map(|value| format!("{value}"))
143			.collect::<Vec<_>>()
144			.join(",");
145
146		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
147	}
148}
149
150impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
151	#[inline]
152	fn from(val: U) -> Self {
153		Self(val, PhantomData)
154	}
155}
156
impl<U: UnderlierWithBitOps, Scalar: BinaryField> Neg for PackedPrimitiveType<U, Scalar> {
	type Output = Self;

	/// Negation is the identity: addition on this type is XOR of the underliers, so every value
	/// is its own additive inverse.
	#[inline]
	fn neg(self) -> Self::Output {
		self
	}
}
165
166impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
167	type Output = Self;
168
169	#[inline]
170	#[allow(clippy::suspicious_arithmetic_impl)]
171	fn add(self, rhs: Self) -> Self::Output {
172		(self.0 ^ rhs.0).into()
173	}
174}
175
176impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add<&Self> for PackedPrimitiveType<U, Scalar> {
177	type Output = Self;
178
179	#[inline]
180	#[allow(clippy::suspicious_arithmetic_impl)]
181	fn add(self, rhs: &Self) -> Self::Output {
182		(self.0 ^ rhs.0).into()
183	}
184}
185
186impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
187	type Output = Self;
188
189	#[inline]
190	#[allow(clippy::suspicious_arithmetic_impl)]
191	fn sub(self, rhs: Self) -> Self::Output {
192		(self.0 ^ rhs.0).into()
193	}
194}
195
196impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub<&Self> for PackedPrimitiveType<U, Scalar> {
197	type Output = Self;
198
199	#[inline]
200	#[allow(clippy::suspicious_arithmetic_impl)]
201	fn sub(self, rhs: &Self) -> Self::Output {
202		(self.0 ^ rhs.0).into()
203	}
204}
205
206impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
207where
208	Self: Add<Output = Self>,
209{
210	fn add_assign(&mut self, rhs: Self) {
211		*self = *self + rhs;
212	}
213}
214
215impl<U: UnderlierType, Scalar: BinaryField> AddAssign<&Self> for PackedPrimitiveType<U, Scalar>
216where
217	Self: for<'a> Add<&'a Self, Output = Self>,
218{
219	fn add_assign(&mut self, rhs: &Self) {
220		*self = *self + rhs;
221	}
222}
223
224impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
225where
226	Self: Sub<Output = Self>,
227{
228	fn sub_assign(&mut self, rhs: Self) {
229		*self = *self - rhs;
230	}
231}
232
233impl<U: UnderlierType, Scalar: BinaryField> SubAssign<&Self> for PackedPrimitiveType<U, Scalar>
234where
235	Self: for<'a> Sub<&'a Self, Output = Self>,
236{
237	fn sub_assign(&mut self, rhs: &Self) {
238		*self = *self - rhs;
239	}
240}
241
242impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
243where
244	Self: Mul<Output = Self>,
245{
246	fn mul_assign(&mut self, rhs: Self) {
247		*self = *self * rhs;
248	}
249}
250
251impl<U: UnderlierType, Scalar: BinaryField> MulAssign<&Self> for PackedPrimitiveType<U, Scalar>
252where
253	Self: for<'a> Mul<&'a Self, Output = Self>,
254{
255	fn mul_assign(&mut self, rhs: &Self) {
256		*self = *self * rhs;
257	}
258}
259
260impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
261	for PackedPrimitiveType<U, Scalar>
262{
263	type Output = Self;
264
265	fn add(self, rhs: Scalar) -> Self::Output {
266		self + Self::broadcast(rhs)
267	}
268}
269
270impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
271	for PackedPrimitiveType<U, Scalar>
272{
273	type Output = Self;
274
275	fn sub(self, rhs: Scalar) -> Self::Output {
276		self - Self::broadcast(rhs)
277	}
278}
279
280impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
281	for PackedPrimitiveType<U, Scalar>
282where
283	Self: Mul<Output = Self>,
284{
285	type Output = Self;
286
287	fn mul(self, rhs: Scalar) -> Self::Output {
288		self * Self::broadcast(rhs)
289	}
290}
291
292impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
293	for PackedPrimitiveType<U, Scalar>
294{
295	fn add_assign(&mut self, rhs: Scalar) {
296		*self += Self::broadcast(rhs);
297	}
298}
299
300impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
301	for PackedPrimitiveType<U, Scalar>
302{
303	fn sub_assign(&mut self, rhs: Scalar) {
304		*self -= Self::broadcast(rhs);
305	}
306}
307
308impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
309	for PackedPrimitiveType<U, Scalar>
310where
311	Self: MulAssign<Self>,
312{
313	fn mul_assign(&mut self, rhs: Scalar) {
314		*self *= Self::broadcast(rhs);
315	}
316}
317
318impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
319where
320	Self: Add<Output = Self>,
321{
322	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
323		iter.fold(Self::from(U::default()), |result, next| result + next)
324	}
325}
326
327impl<'a, U: UnderlierType, Scalar: BinaryField> Sum<&'a Self> for PackedPrimitiveType<U, Scalar>
328where
329	Self: Add<Output = Self>,
330{
331	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
332		iter.fold(Self::from(U::default()), |result, next| result + *next)
333	}
334}
335
336impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
337	for PackedPrimitiveType<U, Scalar>
338where
339	Self: Mul<Output = Self>,
340{
341	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
342		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
343	}
344}
345
346impl<'a, U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
347	Product<&'a Self> for PackedPrimitiveType<U, Scalar>
348where
349	Self: Mul<Output = Self>,
350{
351	fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
352		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * *next)
353	}
354}
355
356impl<U: UnderlierType, Scalar: BinaryField> Mul<&Self> for PackedPrimitiveType<U, Scalar>
357where
358	Self: Mul<Output = Self>,
359{
360	type Output = Self;
361
362	#[inline]
363	fn mul(self, rhs: &Self) -> Self::Output {
364		self * *rhs
365	}
366}
367
// SAFETY: `PackedPrimitiveType` is `#[repr(transparent)]` over `U`, and its only other field is
// a zero-sized `PhantomData`, so the all-zero bit pattern is valid whenever it is valid for `U`.
unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
	for PackedPrimitiveType<U, Scalar>
{
}
372
// SAFETY: `PackedPrimitiveType` is `#[repr(transparent)]` over `U`, and its only other field is
// a zero-sized `PhantomData`, so it has the same layout and valid bit patterns as `U: Pod`.
unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
374
375impl<U, Scalar> FieldOps for PackedPrimitiveType<U, Scalar>
376where
377	Self: Square + InvertOrZero + Mul<Output = Self>,
378	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
379	Scalar: BinaryField,
380{
381	type Scalar = Scalar;
382
383	#[inline]
384	fn zero() -> Self {
385		Self::from_underlier(U::ZERO)
386	}
387
388	#[inline]
389	fn one() -> Self {
390		Self::broadcast(Scalar::ONE)
391	}
392
393	fn square_transpose<FSub: Field>(elems: &mut [Self])
394	where
395		Scalar: ExtensionField<FSub>,
396	{
397		let log_degree = <Scalar as ExtensionField<FSub>>::LOG_DEGREE;
398		let degree = <Scalar as ExtensionField<FSub>>::DEGREE;
399		assert_eq!(elems.len(), degree);
400
401		let log_sub_bits = Scalar::N_BITS.ilog2() as usize - log_degree;
402
403		// See Hacker's Delight, Section 7-3.
404		for i in 0..log_degree {
405			for j in 0..1 << (log_degree - i - 1) {
406				for k in 0..1 << i {
407					let idx0 = (j << (i + 1)) | k;
408					let idx1 = idx0 | (1 << i);
409					let (u0, u1) = elems[idx0].0.interleave(elems[idx1].0, i + log_sub_bits);
410					elems[idx0] = u0.into();
411					elems[idx1] = u1.into();
412				}
413			}
414		}
415	}
416}
417
/// `PackedField` implementation that forwards every operation to the sub-value, interleave,
/// transpose and spread primitives of the underlier `U`.
impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
where
	Self:
		Square + InvertOrZero + Mul<Output = Self> + WideMul<Output: Debug + Send + Sync + 'static>,
	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
	Scalar: BinaryField,
{
	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;

	/// Reads the `i`-th packed scalar from the underlier.
	///
	/// # Safety
	///
	/// NOTE(review): presumably `i` must be below the packing width; the exact contract is
	/// defined on the trait, outside this file.
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
	}

	/// Overwrites the `i`-th packed scalar in the underlier with `scalar`.
	///
	/// # Safety
	///
	/// NOTE(review): presumably `i` must be below the packing width; the exact contract is
	/// defined on the trait, outside this file.
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
		unsafe {
			self.0.set_subvalue(i, scalar.to_underlier());
		}
	}

	/// Iterates over the scalars of a borrowed packed value.
	#[inline]
	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		Divisible::<Scalar::Underlier>::ref_iter(&self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	/// Iterates over the scalars of an owned packed value.
	#[inline]
	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
		Divisible::<Scalar::Underlier>::value_iter(self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	/// Iterates over the scalars of every packed value in `slice`.
	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
			.map_skippable(|underlier| Scalar::from_underlier(underlier))
	}

	/// Interleaves blocks of `2^log_block_len` scalars of `self` and `other`.
	///
	/// # Panics
	///
	/// Panics unless `log_block_len < LOG_WIDTH`.
	#[inline]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		// The underlier works on bit blocks, so convert the block length from scalars to bits.
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	/// Applies the underlier's `transpose` to `self` and `other` on blocks of
	/// `2^log_block_len` scalars.
	///
	/// # Panics
	///
	/// Panics unless `log_block_len < LOG_WIDTH`.
	#[inline]
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		// The underlier works on bit blocks, so convert the block length from scalars to bits.
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	/// Applies the underlier's `spread` for the block `block_idx` of `2^log_block_len` scalars.
	///
	/// # Safety
	///
	/// The bounds `log_block_len <= LOG_WIDTH` and
	/// `block_idx < 2^(LOG_WIDTH - log_block_len)` are only verified in debug builds.
	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
		debug_assert!(
			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
			"{} < {}",
			block_idx,
			1 << (Self::LOG_WIDTH - log_block_len)
		);

		unsafe {
			self.0
				.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
				.into()
		}
	}

	/// Returns a packed value with every position set to `scalar`.
	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		// Inherent associated functions take precedence over trait ones, so this calls the
		// inherent `broadcast` rather than recursing into this method.
		Self::broadcast(scalar)
	}

	/// Builds a packed value by calling `f` with each scalar index.
	#[inline]
	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		U::from_fn(move |i| f(i).to_underlier()).into()
	}
}
500
501impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
502	for StandardUniform
503{
504	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
505		PackedPrimitiveType::from_underlier(U::random(rng))
506	}
507}
508
509/// Multiply `PT1` values by upcasting to wider `PT2` type with the same scalar.
510/// This is useful for the cases when SIMD multiplication is faster.
511#[allow(dead_code)]
512pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
513where
514	PT1: PackedField + WithUnderlier,
515	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
516	PT2::Underlier: From<PT1::Underlier>,
517	PT1::Underlier: NumCast<PT2::Underlier>,
518{
519	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
520	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
521
522	let bigger_result = bigger_lhs * bigger_rhs;
523
524	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
525}
526
527/// Square `PT1` values by upcasting to wider `PT2` type with the same scalar.
528/// This is useful for the cases when SIMD square is faster.
529#[allow(dead_code)]
530pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
531where
532	PT1: PackedField + WithUnderlier,
533	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
534	PT2::Underlier: From<PT1::Underlier>,
535	PT1::Underlier: NumCast<PT2::Underlier>,
536{
537	let bigger_val = PT2::from_underlier(val.to_underlier().into());
538
539	let bigger_result = bigger_val.square();
540
541	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
542}
543
544/// Invert `PT1` values by upcasting to wider `PT2` type with the same scalar.
545/// This is useful for the cases when SIMD invert is faster.
546#[allow(dead_code)]
547pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
548where
549	PT1: PackedField + WithUnderlier,
550	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
551	PT2::Underlier: From<PT1::Underlier>,
552	PT1::Underlier: NumCast<PT2::Underlier>,
553{
554	let bigger_val = PT2::from_underlier(val.to_underlier().into());
555
556	let bigger_result = bigger_val.invert_or_zero();
557
558	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
559}