binius_field/arch/portable/
packed.rs

1// Copyright 2024-2025 Irreducible Inc.
2
// This is because derive(bytemuck::TransparentWrapper) adds some type constraints to
// PackedPrimitiveType in addition to the type constraints we define. Even more annoyingly, the
// allow attribute has to be added to the module; it doesn't work to add it to the struct
// definition.
7#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10	fmt::Debug,
11	iter::{Product, Sum},
12	marker::PhantomData,
13	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19	Rng,
20	distr::{Distribution, StandardUniform},
21};
22
23use super::packed_arithmetic::UnderlierWithBitConstants;
24use crate::{
25	BinaryField, Divisible, PackedField,
26	arithmetic_traits::{Broadcast, InvertOrZero, MulAlpha, Square},
27	underlier::{NumCast, U1, U2, U4, UnderlierType, UnderlierWithBitOps, WithUnderlier},
28};
29
/// A group of `Scalar` binary field elements packed into a single underlier value of type `U`.
///
/// The struct is `#[repr(transparent)]` over `U`, so it shares the underlier's layout. The number
/// of packed scalars is `U::BITS / Scalar::N_BITS` (exposed as the `WIDTH` associated constant).
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	// Raw underlier value that stores the packed scalars.
	pub U,
	// Zero-sized marker that ties the value to its scalar field type.
	pub PhantomData<Scalar>,
);
37
38impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
39	pub const WIDTH: usize = {
40		assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
41
42		U::BITS / Scalar::N_BITS
43	};
44
45	pub const LOG_WIDTH: usize = {
46		let result = Self::WIDTH.ilog2();
47
48		assert!(2usize.pow(result) == Self::WIDTH);
49
50		result as usize
51	};
52
53	#[inline]
54	pub const fn from_underlier(val: U) -> Self {
55		Self(val, PhantomData)
56	}
57
58	#[inline]
59	pub const fn to_underlier(self) -> U {
60		self.0
61	}
62}
63
64unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
65	for PackedPrimitiveType<U, Scalar>
66{
67	type Underlier = U;
68
69	#[inline(always)]
70	fn to_underlier(self) -> Self::Underlier {
71		TransparentWrapper::peel(self)
72	}
73
74	#[inline(always)]
75	fn to_underlier_ref(&self) -> &Self::Underlier {
76		TransparentWrapper::peel_ref(self)
77	}
78
79	#[inline(always)]
80	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
81		TransparentWrapper::peel_mut(self)
82	}
83
84	#[inline(always)]
85	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
86		TransparentWrapper::peel_slice(val)
87	}
88
89	#[inline(always)]
90	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
91		TransparentWrapper::peel_slice_mut(val)
92	}
93
94	#[inline(always)]
95	fn from_underlier(val: Self::Underlier) -> Self {
96		TransparentWrapper::wrap(val)
97	}
98
99	#[inline(always)]
100	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
101		TransparentWrapper::wrap_ref(val)
102	}
103
104	#[inline(always)]
105	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
106		TransparentWrapper::wrap_mut(val)
107	}
108
109	#[inline(always)]
110	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
111		TransparentWrapper::wrap_slice(val)
112	}
113
114	#[inline(always)]
115	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
116		TransparentWrapper::wrap_slice_mut(val)
117	}
118}
119
120impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
121where
122	Self: PackedField,
123{
124	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125		let width = checked_int_div(U::BITS, Scalar::N_BITS);
126		let values_str = self
127			.iter()
128			.map(|value| format!("{value}"))
129			.collect::<Vec<_>>()
130			.join(",");
131
132		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
133	}
134}
135
136impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
137	#[inline]
138	fn from(val: U) -> Self {
139		Self(val, PhantomData)
140	}
141}
142
143impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
144	type Output = Self;
145
146	#[inline]
147	#[allow(clippy::suspicious_arithmetic_impl)]
148	fn add(self, rhs: Self) -> Self::Output {
149		(self.0 ^ rhs.0).into()
150	}
151}
152
153impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
154	type Output = Self;
155
156	#[inline]
157	#[allow(clippy::suspicious_arithmetic_impl)]
158	fn sub(self, rhs: Self) -> Self::Output {
159		(self.0 ^ rhs.0).into()
160	}
161}
162
163impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
164where
165	Self: Add<Output = Self>,
166{
167	fn add_assign(&mut self, rhs: Self) {
168		*self = *self + rhs;
169	}
170}
171
172impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
173where
174	Self: Sub<Output = Self>,
175{
176	fn sub_assign(&mut self, rhs: Self) {
177		*self = *self - rhs;
178	}
179}
180
181impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
182where
183	Self: Mul<Output = Self>,
184{
185	fn mul_assign(&mut self, rhs: Self) {
186		*self = *self * rhs;
187	}
188}
189
190impl<U: UnderlierType, Scalar: BinaryField> Add<Scalar> for PackedPrimitiveType<U, Scalar>
191where
192	Self: Broadcast<Scalar> + Add<Output = Self>,
193{
194	type Output = Self;
195
196	fn add(self, rhs: Scalar) -> Self::Output {
197		self + Self::broadcast(rhs)
198	}
199}
200
201impl<U: UnderlierType, Scalar: BinaryField> Sub<Scalar> for PackedPrimitiveType<U, Scalar>
202where
203	Self: Broadcast<Scalar> + Sub<Output = Self>,
204{
205	type Output = Self;
206
207	fn sub(self, rhs: Scalar) -> Self::Output {
208		self - Self::broadcast(rhs)
209	}
210}
211
212impl<U: UnderlierType, Scalar: BinaryField> Mul<Scalar> for PackedPrimitiveType<U, Scalar>
213where
214	Self: Broadcast<Scalar> + Mul<Output = Self>,
215{
216	type Output = Self;
217
218	fn mul(self, rhs: Scalar) -> Self::Output {
219		self * Self::broadcast(rhs)
220	}
221}
222
223impl<U: UnderlierType, Scalar: BinaryField> AddAssign<Scalar> for PackedPrimitiveType<U, Scalar>
224where
225	Self: Broadcast<Scalar> + AddAssign<Self>,
226{
227	fn add_assign(&mut self, rhs: Scalar) {
228		*self += Self::broadcast(rhs);
229	}
230}
231
232impl<U: UnderlierType, Scalar: BinaryField> SubAssign<Scalar> for PackedPrimitiveType<U, Scalar>
233where
234	Self: Broadcast<Scalar> + SubAssign<Self>,
235{
236	fn sub_assign(&mut self, rhs: Scalar) {
237		*self -= Self::broadcast(rhs);
238	}
239}
240
241impl<U: UnderlierType, Scalar: BinaryField> MulAssign<Scalar> for PackedPrimitiveType<U, Scalar>
242where
243	Self: Broadcast<Scalar> + MulAssign<Self>,
244{
245	fn mul_assign(&mut self, rhs: Scalar) {
246		*self *= Self::broadcast(rhs);
247	}
248}
249
250impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
251where
252	Self: Add<Output = Self>,
253{
254	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
255		iter.fold(Self::from(U::default()), |result, next| result + next)
256	}
257}
258
259impl<U: UnderlierType, Scalar: BinaryField> Product for PackedPrimitiveType<U, Scalar>
260where
261	Self: Broadcast<Scalar> + Mul<Output = Self>,
262{
263	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
264		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
265	}
266}
267
// The struct is `#[repr(transparent)]` over `U` and its only other field is a zero-sized
// `PhantomData`, so the all-zero bit pattern is valid for the wrapper whenever it is valid for
// `U` (required here via `U: Zeroable`).
unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
	for PackedPrimitiveType<U, Scalar>
{
}
272
// The struct is `#[repr(transparent)]` over `U` with an additional zero-sized `PhantomData`
// field, so it shares `U`'s layout; the plain-old-data guarantee is inherited from `U: Pod`.
unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
274
// `PackedField` implementation for the portable packed type. Element access is delegated to the
// underlier's sub-value / `Divisible` operations, while broadcasting, squaring, inversion and
// multiplication come from the trait implementations required in the `where` clause.
impl<U: UnderlierWithBitOps, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + Square + InvertOrZero + Mul<Output = Self>,
	U: UnderlierWithBitConstants
		+ Divisible<Scalar::Underlier>
		+ From<Scalar::Underlier>
		+ Send
		+ Sync
		+ 'static,
	Scalar: BinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
	Scalar::Underlier: NumCast<U>,
{
	type Scalar = Scalar;

	// Base-2 logarithm of the number of scalars per packed value.
	// NOTE(review): unlike the inherent `PackedPrimitiveType::LOG_WIDTH`, this expression does not
	// assert that `U::BITS` is a multiple of `Scalar::N_BITS` or that the ratio is a power of two.
	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;

	/// Reads the `i`-th scalar from the underlier without bounds checking.
	///
	/// # Safety
	///
	/// The caller must satisfy the contract of the underlier's `get_subvalue` for index `i`
	/// (presumably `i` must be less than the packing width; confirm against the underlier trait).
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
	}

	/// Overwrites the `i`-th scalar in the underlier without bounds checking.
	///
	/// # Safety
	///
	/// The caller must satisfy the contract of the underlier's `set_subvalue` for index `i`
	/// (presumably `i` must be less than the packing width; confirm against the underlier trait).
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
		unsafe {
			self.0.set_subvalue(i, scalar.to_underlier());
		}
	}

	/// Returns the packed value wrapping `U::ZERO`.
	#[inline]
	fn zero() -> Self {
		Self::from_underlier(U::ZERO)
	}

	/// Iterates over the packed scalars, borrowing the underlier.
	#[inline]
	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		Divisible::<Scalar::Underlier>::ref_iter(&self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	/// Iterates over the packed scalars, consuming the value.
	#[inline]
	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
		Divisible::<Scalar::Underlier>::value_iter(self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	/// Iterates over the scalars of every packed value in `slice`, by viewing the slice as
	/// underliers and splitting those into scalar-sized pieces.
	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		// `map_skippable` comes from `IterExtensions`; presumably a `map` that keeps element
		// skipping cheap (verify against binius_utils).
		Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
			.map_skippable(|underlier| Scalar::from_underlier(underlier))
	}

	/// Interleaves blocks of `self` and `other` by delegating to the underlier's `interleave`.
	///
	/// # Panics
	///
	/// Panics if `log_block_len >= Self::LOG_WIDTH`.
	#[inline]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		// Offset the block size by log2 of the scalar bit width before handing it to the
		// underlier (presumably converting from scalar units to bit units).
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	/// Unzips blocks of `self` and `other` by delegating to the underlier's `transpose`.
	///
	/// # Panics
	///
	/// Panics if `log_block_len >= Self::LOG_WIDTH`.
	#[inline]
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		// Same offset by log2 of the scalar bit width as in `interleave`.
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	/// Spreads the block selected by `block_idx` by delegating to the underlier's `spread`.
	///
	/// # Safety
	///
	/// The bounds `log_block_len <= Self::LOG_WIDTH` and
	/// `block_idx < 2^(Self::LOG_WIDTH - log_block_len)` are only checked by `debug_assert!`, so
	/// in release builds the caller is responsible for upholding them.
	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
		debug_assert!(
			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
			"{} < {}",
			block_idx,
			1 << (Self::LOG_WIDTH - log_block_len)
		);

		unsafe {
			self.0
				.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
				.into()
		}
	}

	/// Forwards to the `Broadcast` implementation; the qualified path selects that trait's
	/// method rather than this one of the same name.
	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		<Self as Broadcast<Self::Scalar>>::broadcast(scalar)
	}

	/// Builds a packed value by calling `f` for each index and storing the underlier of the
	/// returned scalar via `U::from_fn`.
	#[inline]
	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		U::from_fn(move |i| f(i).to_underlier()).into()
	}

	/// Forwards to the `Square` implementation required in the `where` clause.
	#[inline]
	fn square(self) -> Self {
		<Self as Square>::square(self)
	}

	/// Forwards to the `InvertOrZero` implementation required in the `where` clause.
	#[inline]
	fn invert_or_zero(self) -> Self {
		<Self as InvertOrZero>::invert_or_zero(self)
	}
}
379
380impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
381	for StandardUniform
382{
383	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
384		PackedPrimitiveType::from_underlier(U::random(rng))
385	}
386}
387
388/// Multiply `PT1` values by upcasting to wider `PT2` type with the same scalar.
389/// This is useful for the cases when SIMD multiplication is faster.
390#[allow(dead_code)]
391pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
392where
393	PT1: PackedField + WithUnderlier,
394	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
395	PT2::Underlier: From<PT1::Underlier>,
396	PT1::Underlier: NumCast<PT2::Underlier>,
397{
398	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
399	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
400
401	let bigger_result = bigger_lhs * bigger_rhs;
402
403	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
404}
405
406/// Square `PT1` values by upcasting to wider `PT2` type with the same scalar.
407/// This is useful for the cases when SIMD square is faster.
408#[allow(dead_code)]
409pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
410where
411	PT1: PackedField + WithUnderlier,
412	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
413	PT2::Underlier: From<PT1::Underlier>,
414	PT1::Underlier: NumCast<PT2::Underlier>,
415{
416	let bigger_val = PT2::from_underlier(val.to_underlier().into());
417
418	let bigger_result = bigger_val.square();
419
420	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
421}
422
423/// Invert `PT1` values by upcasting to wider `PT2` type with the same scalar.
424/// This is useful for the cases when SIMD invert is faster.
425#[allow(dead_code)]
426pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
427where
428	PT1: PackedField + WithUnderlier,
429	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
430	PT2::Underlier: From<PT1::Underlier>,
431	PT1::Underlier: NumCast<PT2::Underlier>,
432{
433	let bigger_val = PT2::from_underlier(val.to_underlier().into());
434
435	let bigger_result = bigger_val.invert_or_zero();
436
437	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
438}
439
440/// Multiply by alpha `PT1` values by upcasting to wider `PT2` type with the same scalar.
441/// This is useful for the cases when SIMD multiply by alpha is faster.
442#[allow(dead_code)]
443pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
444where
445	PT1: PackedField + WithUnderlier,
446	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
447	PT2::Underlier: From<PT1::Underlier>,
448	PT1::Underlier: NumCast<PT2::Underlier>,
449{
450	let bigger_val = PT2::from_underlier(val.to_underlier().into());
451
452	let bigger_result = bigger_val.mul_alpha();
453
454	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
455}
456
/// Implements `PackScalar<F>` for the given underlier type.
///
/// The generated impl selects `PackedPrimitiveType<$underlier, F>` as the packed type for every
/// binary field `F` for which that type is a `PackedField` with scalar `F` and has `$underlier`
/// as its `WithUnderlier::Underlier`.
macro_rules! impl_pack_scalar {
	($underlier:ty) => {
		impl<F> $crate::as_packed_field::PackScalar<F> for $underlier
		where
			F: BinaryField,
			PackedPrimitiveType<$underlier, F>:
				$crate::packed::PackedField<Scalar = F> + WithUnderlier<Underlier = $underlier>,
		{
			type Packed = PackedPrimitiveType<$underlier, F>;
		}
	};
}
469
// Make the macro importable by path from other modules of this crate.
pub(crate) use impl_pack_scalar;

// `PackScalar` implementations for the sub-byte underliers and the primitive unsigned integers.
impl_pack_scalar!(U1);
impl_pack_scalar!(U2);
impl_pack_scalar!(U4);
impl_pack_scalar!(u8);
impl_pack_scalar!(u16);
impl_pack_scalar!(u32);
impl_pack_scalar!(u64);
impl_pack_scalar!(u128);