binius_field/arch/portable/
packed.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3// This is because derive(bytemuck::TransparentWrapper) adds some type constraints to
4// PackedPrimitiveType in addition to the type constraints we define. Even more, annoying, the
5// allow attribute has to be added to the module, it doesn't work to add it to the struct
6// definition.
7#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10	fmt::Debug,
11	iter::{Product, Sum},
12	marker::PhantomData,
13	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19	Rng,
20	distr::{Distribution, StandardUniform},
21};
22
23use super::packed_arithmetic::UnderlierWithBitConstants;
24use crate::{
25	BinaryField, PackedField,
26	arithmetic_traits::{Broadcast, InvertOrZero, MulAlpha, Square},
27	underlier::{
28		IterationMethods, IterationStrategy, NumCast, U1, U2, U4, UnderlierType,
29		UnderlierWithBitOps, WithUnderlier,
30	},
31};
32
/// A portable packed field element: an underlier `U` reinterpreted as a vector of
/// `Scalar` binary field elements, one scalar per `Scalar::N_BITS` bits of `U`.
///
/// The type is `repr(transparent)` over the underlier, and the derived
/// `TransparentWrapper` enables the zero-cost value/reference/slice casts used by
/// the `WithUnderlier` impl in this module.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	pub U,
	pub PhantomData<Scalar>,
);
40
41impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
42	pub const WIDTH: usize = {
43		assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
44
45		U::BITS / Scalar::N_BITS
46	};
47
48	pub const LOG_WIDTH: usize = {
49		let result = Self::WIDTH.ilog2();
50
51		assert!(2usize.pow(result) == Self::WIDTH);
52
53		result as usize
54	};
55
56	#[inline]
57	pub const fn from_underlier(val: U) -> Self {
58		Self(val, PhantomData)
59	}
60
61	#[inline]
62	pub const fn to_underlier(self) -> U {
63		self.0
64	}
65}
66
67unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
68	for PackedPrimitiveType<U, Scalar>
69{
70	type Underlier = U;
71
72	#[inline(always)]
73	fn to_underlier(self) -> Self::Underlier {
74		TransparentWrapper::peel(self)
75	}
76
77	#[inline(always)]
78	fn to_underlier_ref(&self) -> &Self::Underlier {
79		TransparentWrapper::peel_ref(self)
80	}
81
82	#[inline(always)]
83	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
84		TransparentWrapper::peel_mut(self)
85	}
86
87	#[inline(always)]
88	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
89		TransparentWrapper::peel_slice(val)
90	}
91
92	#[inline(always)]
93	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
94		TransparentWrapper::peel_slice_mut(val)
95	}
96
97	#[inline(always)]
98	fn from_underlier(val: Self::Underlier) -> Self {
99		TransparentWrapper::wrap(val)
100	}
101
102	#[inline(always)]
103	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
104		TransparentWrapper::wrap_ref(val)
105	}
106
107	#[inline(always)]
108	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
109		TransparentWrapper::wrap_mut(val)
110	}
111
112	#[inline(always)]
113	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
114		TransparentWrapper::wrap_slice(val)
115	}
116
117	#[inline(always)]
118	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
119		TransparentWrapper::wrap_slice_mut(val)
120	}
121}
122
123impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
124where
125	Self: PackedField,
126{
127	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128		let width = checked_int_div(U::BITS, Scalar::N_BITS);
129		let values_str = self
130			.iter()
131			.map(|value| format!("{value}"))
132			.collect::<Vec<_>>()
133			.join(",");
134
135		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
136	}
137}
138
139impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
140	#[inline]
141	fn from(val: U) -> Self {
142		Self(val, PhantomData)
143	}
144}
145
146impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
147	type Output = Self;
148
149	#[inline]
150	#[allow(clippy::suspicious_arithmetic_impl)]
151	fn add(self, rhs: Self) -> Self::Output {
152		(self.0 ^ rhs.0).into()
153	}
154}
155
156impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
157	type Output = Self;
158
159	#[inline]
160	#[allow(clippy::suspicious_arithmetic_impl)]
161	fn sub(self, rhs: Self) -> Self::Output {
162		(self.0 ^ rhs.0).into()
163	}
164}
165
166impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
167where
168	Self: Add<Output = Self>,
169{
170	fn add_assign(&mut self, rhs: Self) {
171		*self = *self + rhs;
172	}
173}
174
175impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
176where
177	Self: Sub<Output = Self>,
178{
179	fn sub_assign(&mut self, rhs: Self) {
180		*self = *self - rhs;
181	}
182}
183
184impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
185where
186	Self: Mul<Output = Self>,
187{
188	fn mul_assign(&mut self, rhs: Self) {
189		*self = *self * rhs;
190	}
191}
192
193impl<U: UnderlierType, Scalar: BinaryField> Add<Scalar> for PackedPrimitiveType<U, Scalar>
194where
195	Self: Broadcast<Scalar> + Add<Output = Self>,
196{
197	type Output = Self;
198
199	fn add(self, rhs: Scalar) -> Self::Output {
200		self + Self::broadcast(rhs)
201	}
202}
203
204impl<U: UnderlierType, Scalar: BinaryField> Sub<Scalar> for PackedPrimitiveType<U, Scalar>
205where
206	Self: Broadcast<Scalar> + Sub<Output = Self>,
207{
208	type Output = Self;
209
210	fn sub(self, rhs: Scalar) -> Self::Output {
211		self - Self::broadcast(rhs)
212	}
213}
214
215impl<U: UnderlierType, Scalar: BinaryField> Mul<Scalar> for PackedPrimitiveType<U, Scalar>
216where
217	Self: Broadcast<Scalar> + Mul<Output = Self>,
218{
219	type Output = Self;
220
221	fn mul(self, rhs: Scalar) -> Self::Output {
222		self * Self::broadcast(rhs)
223	}
224}
225
226impl<U: UnderlierType, Scalar: BinaryField> AddAssign<Scalar> for PackedPrimitiveType<U, Scalar>
227where
228	Self: Broadcast<Scalar> + AddAssign<Self>,
229{
230	fn add_assign(&mut self, rhs: Scalar) {
231		*self += Self::broadcast(rhs);
232	}
233}
234
235impl<U: UnderlierType, Scalar: BinaryField> SubAssign<Scalar> for PackedPrimitiveType<U, Scalar>
236where
237	Self: Broadcast<Scalar> + SubAssign<Self>,
238{
239	fn sub_assign(&mut self, rhs: Scalar) {
240		*self -= Self::broadcast(rhs);
241	}
242}
243
244impl<U: UnderlierType, Scalar: BinaryField> MulAssign<Scalar> for PackedPrimitiveType<U, Scalar>
245where
246	Self: Broadcast<Scalar> + MulAssign<Self>,
247{
248	fn mul_assign(&mut self, rhs: Scalar) {
249		*self *= Self::broadcast(rhs);
250	}
251}
252
253impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
254where
255	Self: Add<Output = Self>,
256{
257	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
258		iter.fold(Self::from(U::default()), |result, next| result + next)
259	}
260}
261
262impl<U: UnderlierType, Scalar: BinaryField> Product for PackedPrimitiveType<U, Scalar>
263where
264	Self: Broadcast<Scalar> + Mul<Output = Self>,
265{
266	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
267		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
268	}
269}
270
// SAFETY: the struct is `repr(transparent)` over `U` plus a zero-sized
// `PhantomData`, so it is zeroable whenever `U` is.
unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
	for PackedPrimitiveType<U, Scalar>
{
}
275
// SAFETY: `repr(transparent)` over a `Pod` underlier (plus zero-sized
// `PhantomData`), so the type itself satisfies the `Pod` requirements.
unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
277
278impl<U: UnderlierWithBitOps, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
279where
280	Self: Broadcast<Scalar> + Square + InvertOrZero + Mul<Output = Self>,
281	U: UnderlierWithBitConstants + From<Scalar::Underlier> + Send + Sync + 'static,
282	Scalar: BinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
283	Scalar::Underlier: NumCast<U>,
284	IterationMethods<Scalar::Underlier, U>: IterationStrategy<Scalar::Underlier, U>,
285{
286	type Scalar = Scalar;
287
288	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;
289
290	#[inline]
291	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
292		Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
293	}
294
295	#[inline]
296	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
297		unsafe {
298			self.0.set_subvalue(i, scalar.to_underlier());
299		}
300	}
301
302	#[inline]
303	fn zero() -> Self {
304		Self::from_underlier(U::ZERO)
305	}
306
307	#[inline]
308	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
309		IterationMethods::<Scalar::Underlier, U>::ref_iter(&self.0)
310			.map(|underlier| Scalar::from_underlier(underlier))
311	}
312
313	#[inline]
314	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
315		IterationMethods::<Scalar::Underlier, U>::value_iter(self.0)
316			.map(|underlier| Scalar::from_underlier(underlier))
317	}
318
319	#[inline]
320	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
321		IterationMethods::<Scalar::Underlier, U>::slice_iter(Self::to_underliers_ref(slice))
322			.map_skippable(|underlier| Scalar::from_underlier(underlier))
323	}
324
325	#[inline]
326	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
327		assert!(log_block_len < Self::LOG_WIDTH);
328		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
329		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
330		(c.into(), d.into())
331	}
332
333	#[inline]
334	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
335		assert!(log_block_len < Self::LOG_WIDTH);
336		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
337		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
338		(c.into(), d.into())
339	}
340
341	#[inline]
342	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
343		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
344		debug_assert!(
345			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
346			"{} < {}",
347			block_idx,
348			1 << (Self::LOG_WIDTH - log_block_len)
349		);
350
351		unsafe {
352			self.0
353				.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
354				.into()
355		}
356	}
357
358	#[inline]
359	fn broadcast(scalar: Self::Scalar) -> Self {
360		<Self as Broadcast<Self::Scalar>>::broadcast(scalar)
361	}
362
363	#[inline]
364	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
365		U::from_fn(move |i| f(i).to_underlier()).into()
366	}
367
368	#[inline]
369	fn square(self) -> Self {
370		<Self as Square>::square(self)
371	}
372
373	#[inline]
374	fn invert_or_zero(self) -> Self {
375		<Self as InvertOrZero>::invert_or_zero(self)
376	}
377}
378
379impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
380	for StandardUniform
381{
382	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
383		PackedPrimitiveType::from_underlier(U::random(rng))
384	}
385}
386
387/// Multiply `PT1` values by upcasting to wider `PT2` type with the same scalar.
388/// This is useful for the cases when SIMD multiplication is faster.
389#[allow(dead_code)]
390pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
391where
392	PT1: PackedField + WithUnderlier,
393	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
394	PT2::Underlier: From<PT1::Underlier>,
395	PT1::Underlier: NumCast<PT2::Underlier>,
396{
397	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
398	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
399
400	let bigger_result = bigger_lhs * bigger_rhs;
401
402	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
403}
404
405/// Square `PT1` values by upcasting to wider `PT2` type with the same scalar.
406/// This is useful for the cases when SIMD square is faster.
407#[allow(dead_code)]
408pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
409where
410	PT1: PackedField + WithUnderlier,
411	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
412	PT2::Underlier: From<PT1::Underlier>,
413	PT1::Underlier: NumCast<PT2::Underlier>,
414{
415	let bigger_val = PT2::from_underlier(val.to_underlier().into());
416
417	let bigger_result = bigger_val.square();
418
419	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
420}
421
422/// Invert `PT1` values by upcasting to wider `PT2` type with the same scalar.
423/// This is useful for the cases when SIMD invert is faster.
424#[allow(dead_code)]
425pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
426where
427	PT1: PackedField + WithUnderlier,
428	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
429	PT2::Underlier: From<PT1::Underlier>,
430	PT1::Underlier: NumCast<PT2::Underlier>,
431{
432	let bigger_val = PT2::from_underlier(val.to_underlier().into());
433
434	let bigger_result = bigger_val.invert_or_zero();
435
436	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
437}
438
439/// Multiply by alpha `PT1` values by upcasting to wider `PT2` type with the same scalar.
440/// This is useful for the cases when SIMD multiply by alpha is faster.
441#[allow(dead_code)]
442pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
443where
444	PT1: PackedField + WithUnderlier,
445	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
446	PT2::Underlier: From<PT1::Underlier>,
447	PT1::Underlier: NumCast<PT2::Underlier>,
448{
449	let bigger_val = PT2::from_underlier(val.to_underlier().into());
450
451	let bigger_result = bigger_val.mul_alpha();
452
453	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
454}
455
/// Implements `crate::as_packed_field::PackScalar` for an underlier type,
/// selecting [`PackedPrimitiveType`] over that underlier as the packed
/// representation for every binary field `F` it can hold.
macro_rules! impl_pack_scalar {
	($underlier:ty) => {
		impl<F> $crate::as_packed_field::PackScalar<F> for $underlier
		where
			F: BinaryField,
			PackedPrimitiveType<$underlier, F>:
				$crate::packed::PackedField<Scalar = F> + WithUnderlier<Underlier = $underlier>,
		{
			type Packed = PackedPrimitiveType<$underlier, F>;
		}
	};
}
468
// Re-exported so other modules in the crate (e.g. SIMD backends) can implement
// `PackScalar` for their own underliers.
pub(crate) use impl_pack_scalar;

// Register `PackScalar` for every portable underlier width, from the 1-bit `U1`
// up through `u128`.
impl_pack_scalar!(U1);
impl_pack_scalar!(U2);
impl_pack_scalar!(U4);
impl_pack_scalar!(u8);
impl_pack_scalar!(u16);
impl_pack_scalar!(u32);
impl_pack_scalar!(u64);
impl_pack_scalar!(u128);