// binius_field/arch/portable/packed.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3// This is because derive(bytemuck::TransparentWrapper) adds some type constraints to
4// PackedPrimitiveType in addition to the type constraints we define. Even more, annoying, the
5// allow attribute has to be added to the module, it doesn't work to add it to the struct
6// definition.
7#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10	fmt::Debug,
11	iter::{Product, Sum},
12	marker::PhantomData,
13	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19	Rng,
20	distr::{Distribution, StandardUniform},
21};
22
23use crate::{
24	BinaryField, Divisible, PackedField,
25	arithmetic_traits::{InvertOrZero, MulAlpha, Square},
26	underlier::{NumCast, UnderlierType, UnderlierWithBitOps, WithUnderlier},
27};
28
/// A packed field element: an underlier `U` reinterpreted as a vector of `Scalar`
/// binary-field elements, one lane per `Scalar::N_BITS` bits of `U`.
///
/// `#[repr(transparent)]` guarantees the same layout and ABI as `U`; the
/// zero-sized `PhantomData<Scalar>` only records the scalar type. This is what
/// makes the derived `TransparentWrapper` (and the `WithUnderlier` impl below) sound.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	pub U,
	pub PhantomData<Scalar>,
);
36
37impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
38	pub const WIDTH: usize = {
39		assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
40
41		U::BITS / Scalar::N_BITS
42	};
43
44	pub const LOG_WIDTH: usize = {
45		let result = Self::WIDTH.ilog2();
46
47		assert!(2usize.pow(result) == Self::WIDTH);
48
49		result as usize
50	};
51
52	#[inline]
53	pub const fn from_underlier(val: U) -> Self {
54		Self(val, PhantomData)
55	}
56
57	#[inline]
58	pub const fn to_underlier(self) -> U {
59		self.0
60	}
61}
62
63impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
64	PackedPrimitiveType<U, Scalar>
65{
66	#[inline]
67	pub fn broadcast(scalar: Scalar) -> Self {
68		U::broadcast_subvalue(scalar.to_underlier()).into()
69	}
70}
71
// SAFETY: `PackedPrimitiveType` is `#[repr(transparent)]` over `U` (see the struct
// definition), so every conversion below is a sound bit-level cast delegated to the
// derived `TransparentWrapper` impl.
unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
	for PackedPrimitiveType<U, Scalar>
{
	type Underlier = U;

	// Value, reference, and slice conversions in the wrapper -> underlier direction.
	#[inline(always)]
	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	#[inline(always)]
	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	#[inline(always)]
	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	#[inline(always)]
	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	#[inline(always)]
	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	// Value, reference, and slice conversions in the underlier -> wrapper direction.
	#[inline(always)]
	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	#[inline(always)]
	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	#[inline(always)]
	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	#[inline(always)]
	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	#[inline(always)]
	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}
127
128impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
129where
130	Self: PackedField,
131{
132	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133		let width = checked_int_div(U::BITS, Scalar::N_BITS);
134		let values_str = self
135			.iter()
136			.map(|value| format!("{value}"))
137			.collect::<Vec<_>>()
138			.join(",");
139
140		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
141	}
142}
143
144impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
145	#[inline]
146	fn from(val: U) -> Self {
147		Self(val, PhantomData)
148	}
149}
150
151impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
152	type Output = Self;
153
154	#[inline]
155	#[allow(clippy::suspicious_arithmetic_impl)]
156	fn add(self, rhs: Self) -> Self::Output {
157		(self.0 ^ rhs.0).into()
158	}
159}
160
161impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
162	type Output = Self;
163
164	#[inline]
165	#[allow(clippy::suspicious_arithmetic_impl)]
166	fn sub(self, rhs: Self) -> Self::Output {
167		(self.0 ^ rhs.0).into()
168	}
169}
170
171impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
172where
173	Self: Add<Output = Self>,
174{
175	fn add_assign(&mut self, rhs: Self) {
176		*self = *self + rhs;
177	}
178}
179
180impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
181where
182	Self: Sub<Output = Self>,
183{
184	fn sub_assign(&mut self, rhs: Self) {
185		*self = *self - rhs;
186	}
187}
188
189impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
190where
191	Self: Mul<Output = Self>,
192{
193	fn mul_assign(&mut self, rhs: Self) {
194		*self = *self * rhs;
195	}
196}
197
198impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
199	for PackedPrimitiveType<U, Scalar>
200{
201	type Output = Self;
202
203	fn add(self, rhs: Scalar) -> Self::Output {
204		self + Self::broadcast(rhs)
205	}
206}
207
208impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
209	for PackedPrimitiveType<U, Scalar>
210{
211	type Output = Self;
212
213	fn sub(self, rhs: Scalar) -> Self::Output {
214		self - Self::broadcast(rhs)
215	}
216}
217
218impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
219	for PackedPrimitiveType<U, Scalar>
220where
221	Self: Mul<Output = Self>,
222{
223	type Output = Self;
224
225	fn mul(self, rhs: Scalar) -> Self::Output {
226		self * Self::broadcast(rhs)
227	}
228}
229
230impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
231	for PackedPrimitiveType<U, Scalar>
232{
233	fn add_assign(&mut self, rhs: Scalar) {
234		*self += Self::broadcast(rhs);
235	}
236}
237
238impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
239	for PackedPrimitiveType<U, Scalar>
240{
241	fn sub_assign(&mut self, rhs: Scalar) {
242		*self -= Self::broadcast(rhs);
243	}
244}
245
246impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
247	for PackedPrimitiveType<U, Scalar>
248where
249	Self: MulAssign<Self>,
250{
251	fn mul_assign(&mut self, rhs: Scalar) {
252		*self *= Self::broadcast(rhs);
253	}
254}
255
256impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
257where
258	Self: Add<Output = Self>,
259{
260	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
261		iter.fold(Self::from(U::default()), |result, next| result + next)
262	}
263}
264
265impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
266	for PackedPrimitiveType<U, Scalar>
267where
268	Self: Mul<Output = Self>,
269{
270	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
271		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
272	}
273}
274
// SAFETY: the all-zero bit pattern is valid for `U` (the bound requires `U: Zeroable`)
// and `PhantomData` is zero-sized, so it is also valid for this transparent wrapper.
unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
	for PackedPrimitiveType<U, Scalar>
{
}
279
// SAFETY: `#[repr(transparent)]` over a `Pod` underlier with only a zero-sized
// `PhantomData` tail means the wrapper has no padding and every bit pattern is valid.
unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
281
impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
where
	Self: Square + InvertOrZero + Mul<Output = Self>,
	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
	Scalar: BinaryField,
{
	type Scalar = Scalar;

	// NOTE(review): assumes U::BITS / Scalar::N_BITS is a power of two; the inherent
	// `LOG_WIDTH` const enforces that at compile time, this one does not.
	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;

	// Reads lane `i` out of the underlier. Caller must guarantee `i < Self::WIDTH`
	// (contract inherited from `get_subvalue`).
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
	}

	// Writes lane `i` of the underlier. Caller must guarantee `i < Self::WIDTH`.
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
		unsafe {
			self.0.set_subvalue(i, scalar.to_underlier());
		}
	}

	// The additive identity: every lane zero, i.e. the all-zero underlier.
	#[inline]
	fn zero() -> Self {
		Self::from_underlier(U::ZERO)
	}

	// Iterates the lanes of a borrowed packed value.
	#[inline]
	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		Divisible::<Scalar::Underlier>::ref_iter(&self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	// Iterates the lanes of an owned packed value.
	#[inline]
	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
		Divisible::<Scalar::Underlier>::value_iter(self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	// Iterates all lanes across a slice of packed values by viewing the slice as
	// raw underliers (sound via `WithUnderlier`/`repr(transparent)`).
	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
			.map_skippable(|underlier| Scalar::from_underlier(underlier))
	}

	// Interleaves blocks of 2^log_block_len lanes between `self` and `other`.
	// Converts the lane-granular block length to bit granularity before delegating
	// to the underlier's bit-level interleave.
	#[inline]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	// Inverse-style block operation to `interleave`, delegated to the underlier's
	// bit-level `transpose` at the same bit-granular block length.
	#[inline]
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	// Spreads block `block_idx` of 2^log_block_len lanes across the whole value.
	// Caller must uphold the debug-asserted preconditions even in release builds.
	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
		debug_assert!(
			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
			"{} < {}",
			block_idx,
			1 << (Self::LOG_WIDTH - log_block_len)
		);

		unsafe {
			self.0
				.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
				.into()
		}
	}

	// Delegates to the inherent `broadcast` (inherent methods take precedence over
	// this trait method in resolution, so this is not self-recursive).
	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		Self::broadcast(scalar)
	}

	// Builds a packed value by evaluating `f` for each lane index.
	#[inline]
	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		U::from_fn(move |i| f(i).to_underlier()).into()
	}

	// Delegates to the `Square` bound; disambiguates from `PackedField::square`.
	#[inline]
	fn square(self) -> Self {
		<Self as Square>::square(self)
	}

	// Delegates to the `InvertOrZero` bound.
	#[inline]
	fn invert_or_zero(self) -> Self {
		<Self as InvertOrZero>::invert_or_zero(self)
	}
}
380
381impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
382	for StandardUniform
383{
384	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
385		PackedPrimitiveType::from_underlier(U::random(rng))
386	}
387}
388
389/// Multiply `PT1` values by upcasting to wider `PT2` type with the same scalar.
390/// This is useful for the cases when SIMD multiplication is faster.
391#[allow(dead_code)]
392pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
393where
394	PT1: PackedField + WithUnderlier,
395	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
396	PT2::Underlier: From<PT1::Underlier>,
397	PT1::Underlier: NumCast<PT2::Underlier>,
398{
399	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
400	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
401
402	let bigger_result = bigger_lhs * bigger_rhs;
403
404	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
405}
406
407/// Square `PT1` values by upcasting to wider `PT2` type with the same scalar.
408/// This is useful for the cases when SIMD square is faster.
409#[allow(dead_code)]
410pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
411where
412	PT1: PackedField + WithUnderlier,
413	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
414	PT2::Underlier: From<PT1::Underlier>,
415	PT1::Underlier: NumCast<PT2::Underlier>,
416{
417	let bigger_val = PT2::from_underlier(val.to_underlier().into());
418
419	let bigger_result = bigger_val.square();
420
421	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
422}
423
424/// Invert `PT1` values by upcasting to wider `PT2` type with the same scalar.
425/// This is useful for the cases when SIMD invert is faster.
426#[allow(dead_code)]
427pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
428where
429	PT1: PackedField + WithUnderlier,
430	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
431	PT2::Underlier: From<PT1::Underlier>,
432	PT1::Underlier: NumCast<PT2::Underlier>,
433{
434	let bigger_val = PT2::from_underlier(val.to_underlier().into());
435
436	let bigger_result = bigger_val.invert_or_zero();
437
438	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
439}
440
441/// Multiply by alpha `PT1` values by upcasting to wider `PT2` type with the same scalar.
442/// This is useful for the cases when SIMD multiply by alpha is faster.
443#[allow(dead_code)]
444pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
445where
446	PT1: PackedField + WithUnderlier,
447	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
448	PT2::Underlier: From<PT1::Underlier>,
449	PT1::Underlier: NumCast<PT2::Underlier>,
450{
451	let bigger_val = PT2::from_underlier(val.to_underlier().into());
452
453	let bigger_result = bigger_val.mul_alpha();
454
455	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
456}