binius_field/arch/portable/
packed.rs

1// Copyright 2024-2025 Irreducible Inc.
2
// This is because `derive(bytemuck::TransparentWrapper)` adds some type constraints to
// `PackedPrimitiveType` in addition to the type constraints we define. Even more annoyingly, the
// allow attribute has to be added to the module; adding it to the struct definition doesn't
// work.
7#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10	fmt::Debug,
11	iter::{Product, Sum},
12	marker::PhantomData,
13	ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19	Rng,
20	distr::{Distribution, StandardUniform},
21};
22
23use crate::{
24	BinaryField, Divisible, PackedField,
25	arithmetic_traits::{InvertOrZero, MulAlpha, Square},
26	field::FieldOps,
27	underlier::{NumCast, UnderlierType, UnderlierWithBitOps, WithUnderlier},
28};
29
/// A packed value holding several `Scalar` field elements inside a single underlier of type `U`.
///
/// Individual lanes are read and written through the underlier's subvalue accessors
/// (see the `PackedField` implementation). The struct is `#[repr(transparent)]` over `U`; the
/// second field is a zero-sized `PhantomData` marker. This is what lets the
/// `bytemuck::TransparentWrapper` derive reinterpret values, references and slices as `U`.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	/// Raw underlier storing all packed scalars.
	pub U,
	/// Zero-sized marker recording the scalar field type.
	pub PhantomData<Scalar>,
);
37
38impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
39	pub const WIDTH: usize = {
40		assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
41
42		U::BITS / Scalar::N_BITS
43	};
44
45	pub const LOG_WIDTH: usize = {
46		let result = Self::WIDTH.ilog2();
47
48		assert!(2usize.pow(result) == Self::WIDTH);
49
50		result as usize
51	};
52
53	#[inline]
54	pub const fn from_underlier(val: U) -> Self {
55		Self(val, PhantomData)
56	}
57
58	#[inline]
59	pub const fn to_underlier(self) -> U {
60		self.0
61	}
62}
63
64impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
65	PackedPrimitiveType<U, Scalar>
66{
67	#[inline]
68	pub fn broadcast(scalar: Scalar) -> Self {
69		U::broadcast_subvalue(scalar.to_underlier()).into()
70	}
71}
72
73unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
74	for PackedPrimitiveType<U, Scalar>
75{
76	type Underlier = U;
77
78	#[inline(always)]
79	fn to_underlier(self) -> Self::Underlier {
80		TransparentWrapper::peel(self)
81	}
82
83	#[inline(always)]
84	fn to_underlier_ref(&self) -> &Self::Underlier {
85		TransparentWrapper::peel_ref(self)
86	}
87
88	#[inline(always)]
89	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
90		TransparentWrapper::peel_mut(self)
91	}
92
93	#[inline(always)]
94	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
95		TransparentWrapper::peel_slice(val)
96	}
97
98	#[inline(always)]
99	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
100		TransparentWrapper::peel_slice_mut(val)
101	}
102
103	#[inline(always)]
104	fn from_underlier(val: Self::Underlier) -> Self {
105		TransparentWrapper::wrap(val)
106	}
107
108	#[inline(always)]
109	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
110		TransparentWrapper::wrap_ref(val)
111	}
112
113	#[inline(always)]
114	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
115		TransparentWrapper::wrap_mut(val)
116	}
117
118	#[inline(always)]
119	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
120		TransparentWrapper::wrap_slice(val)
121	}
122
123	#[inline(always)]
124	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
125		TransparentWrapper::wrap_slice_mut(val)
126	}
127}
128
129impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
130where
131	Self: PackedField,
132{
133	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134		let width = checked_int_div(U::BITS, Scalar::N_BITS);
135		let values_str = self
136			.iter()
137			.map(|value| format!("{value}"))
138			.collect::<Vec<_>>()
139			.join(",");
140
141		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
142	}
143}
144
145impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
146	#[inline]
147	fn from(val: U) -> Self {
148		Self(val, PhantomData)
149	}
150}
151
152impl<U: UnderlierWithBitOps, Scalar: BinaryField> Neg for PackedPrimitiveType<U, Scalar> {
153	type Output = Self;
154
155	#[inline]
156	fn neg(self) -> Self::Output {
157		self
158	}
159}
160
161impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
162	type Output = Self;
163
164	#[inline]
165	#[allow(clippy::suspicious_arithmetic_impl)]
166	fn add(self, rhs: Self) -> Self::Output {
167		(self.0 ^ rhs.0).into()
168	}
169}
170
171impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add<&Self> for PackedPrimitiveType<U, Scalar> {
172	type Output = Self;
173
174	#[inline]
175	#[allow(clippy::suspicious_arithmetic_impl)]
176	fn add(self, rhs: &Self) -> Self::Output {
177		(self.0 ^ rhs.0).into()
178	}
179}
180
181impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
182	type Output = Self;
183
184	#[inline]
185	#[allow(clippy::suspicious_arithmetic_impl)]
186	fn sub(self, rhs: Self) -> Self::Output {
187		(self.0 ^ rhs.0).into()
188	}
189}
190
191impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub<&Self> for PackedPrimitiveType<U, Scalar> {
192	type Output = Self;
193
194	#[inline]
195	#[allow(clippy::suspicious_arithmetic_impl)]
196	fn sub(self, rhs: &Self) -> Self::Output {
197		(self.0 ^ rhs.0).into()
198	}
199}
200
201impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
202where
203	Self: Add<Output = Self>,
204{
205	fn add_assign(&mut self, rhs: Self) {
206		*self = *self + rhs;
207	}
208}
209
210impl<U: UnderlierType, Scalar: BinaryField> AddAssign<&Self> for PackedPrimitiveType<U, Scalar>
211where
212	Self: for<'a> Add<&'a Self, Output = Self>,
213{
214	fn add_assign(&mut self, rhs: &Self) {
215		*self = *self + rhs;
216	}
217}
218
219impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
220where
221	Self: Sub<Output = Self>,
222{
223	fn sub_assign(&mut self, rhs: Self) {
224		*self = *self - rhs;
225	}
226}
227
228impl<U: UnderlierType, Scalar: BinaryField> SubAssign<&Self> for PackedPrimitiveType<U, Scalar>
229where
230	Self: for<'a> Sub<&'a Self, Output = Self>,
231{
232	fn sub_assign(&mut self, rhs: &Self) {
233		*self = *self - rhs;
234	}
235}
236
237impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
238where
239	Self: Mul<Output = Self>,
240{
241	fn mul_assign(&mut self, rhs: Self) {
242		*self = *self * rhs;
243	}
244}
245
246impl<U: UnderlierType, Scalar: BinaryField> MulAssign<&Self> for PackedPrimitiveType<U, Scalar>
247where
248	Self: for<'a> Mul<&'a Self, Output = Self>,
249{
250	fn mul_assign(&mut self, rhs: &Self) {
251		*self = *self * rhs;
252	}
253}
254
255impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
256	for PackedPrimitiveType<U, Scalar>
257{
258	type Output = Self;
259
260	fn add(self, rhs: Scalar) -> Self::Output {
261		self + Self::broadcast(rhs)
262	}
263}
264
265impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
266	for PackedPrimitiveType<U, Scalar>
267{
268	type Output = Self;
269
270	fn sub(self, rhs: Scalar) -> Self::Output {
271		self - Self::broadcast(rhs)
272	}
273}
274
275impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
276	for PackedPrimitiveType<U, Scalar>
277where
278	Self: Mul<Output = Self>,
279{
280	type Output = Self;
281
282	fn mul(self, rhs: Scalar) -> Self::Output {
283		self * Self::broadcast(rhs)
284	}
285}
286
287impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
288	for PackedPrimitiveType<U, Scalar>
289{
290	fn add_assign(&mut self, rhs: Scalar) {
291		*self += Self::broadcast(rhs);
292	}
293}
294
295impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
296	for PackedPrimitiveType<U, Scalar>
297{
298	fn sub_assign(&mut self, rhs: Scalar) {
299		*self -= Self::broadcast(rhs);
300	}
301}
302
303impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
304	for PackedPrimitiveType<U, Scalar>
305where
306	Self: MulAssign<Self>,
307{
308	fn mul_assign(&mut self, rhs: Scalar) {
309		*self *= Self::broadcast(rhs);
310	}
311}
312
313impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
314where
315	Self: Add<Output = Self>,
316{
317	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
318		iter.fold(Self::from(U::default()), |result, next| result + next)
319	}
320}
321
322impl<'a, U: UnderlierType, Scalar: BinaryField> Sum<&'a Self> for PackedPrimitiveType<U, Scalar>
323where
324	Self: Add<Output = Self>,
325{
326	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
327		iter.fold(Self::from(U::default()), |result, next| result + *next)
328	}
329}
330
331impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
332	for PackedPrimitiveType<U, Scalar>
333where
334	Self: Mul<Output = Self>,
335{
336	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
337		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
338	}
339}
340
341impl<'a, U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
342	Product<&'a Self> for PackedPrimitiveType<U, Scalar>
343where
344	Self: Mul<Output = Self>,
345{
346	fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
347		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * *next)
348	}
349}
350
351impl<U: UnderlierType, Scalar: BinaryField> Mul<&Self> for PackedPrimitiveType<U, Scalar>
352where
353	Self: Mul<Output = Self>,
354{
355	type Output = Self;
356
357	#[inline]
358	fn mul(self, rhs: &Self) -> Self::Output {
359		self * *rhs
360	}
361}
362
363unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
364	for PackedPrimitiveType<U, Scalar>
365{
366}
367
368unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
369
370impl<U, Scalar> FieldOps<Scalar> for PackedPrimitiveType<U, Scalar>
371where
372	Self: Square + InvertOrZero + Mul<Output = Self>,
373	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
374	Scalar: BinaryField,
375{
376	#[inline]
377	fn zero() -> Self {
378		Self::from_underlier(U::ZERO)
379	}
380
381	#[inline]
382	fn one() -> Self {
383		Self::broadcast(Scalar::ONE)
384	}
385}
386
387impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
388where
389	Self: Square + InvertOrZero + Mul<Output = Self>,
390	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
391	Scalar: BinaryField,
392{
393	type Scalar = Scalar;
394
395	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;
396
397	#[inline]
398	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
399		Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
400	}
401
402	#[inline]
403	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
404		unsafe {
405			self.0.set_subvalue(i, scalar.to_underlier());
406		}
407	}
408
409	#[inline]
410	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
411		Divisible::<Scalar::Underlier>::ref_iter(&self.0)
412			.map(|underlier| Scalar::from_underlier(underlier))
413	}
414
415	#[inline]
416	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
417		Divisible::<Scalar::Underlier>::value_iter(self.0)
418			.map(|underlier| Scalar::from_underlier(underlier))
419	}
420
421	#[inline]
422	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
423		Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
424			.map_skippable(|underlier| Scalar::from_underlier(underlier))
425	}
426
427	#[inline]
428	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
429		assert!(log_block_len < Self::LOG_WIDTH);
430		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
431		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
432		(c.into(), d.into())
433	}
434
435	#[inline]
436	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
437		assert!(log_block_len < Self::LOG_WIDTH);
438		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
439		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
440		(c.into(), d.into())
441	}
442
443	#[inline]
444	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
445		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
446		debug_assert!(
447			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
448			"{} < {}",
449			block_idx,
450			1 << (Self::LOG_WIDTH - log_block_len)
451		);
452
453		unsafe {
454			self.0
455				.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
456				.into()
457		}
458	}
459
460	#[inline]
461	fn broadcast(scalar: Self::Scalar) -> Self {
462		Self::broadcast(scalar)
463	}
464
465	#[inline]
466	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
467		U::from_fn(move |i| f(i).to_underlier()).into()
468	}
469}
470
471impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
472	for StandardUniform
473{
474	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
475		PackedPrimitiveType::from_underlier(U::random(rng))
476	}
477}
478
479/// Multiply `PT1` values by upcasting to wider `PT2` type with the same scalar.
480/// This is useful for the cases when SIMD multiplication is faster.
481#[allow(dead_code)]
482pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
483where
484	PT1: PackedField + WithUnderlier,
485	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
486	PT2::Underlier: From<PT1::Underlier>,
487	PT1::Underlier: NumCast<PT2::Underlier>,
488{
489	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
490	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
491
492	let bigger_result = bigger_lhs * bigger_rhs;
493
494	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
495}
496
497/// Square `PT1` values by upcasting to wider `PT2` type with the same scalar.
498/// This is useful for the cases when SIMD square is faster.
499#[allow(dead_code)]
500pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
501where
502	PT1: PackedField + WithUnderlier,
503	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
504	PT2::Underlier: From<PT1::Underlier>,
505	PT1::Underlier: NumCast<PT2::Underlier>,
506{
507	let bigger_val = PT2::from_underlier(val.to_underlier().into());
508
509	let bigger_result = bigger_val.square();
510
511	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
512}
513
514/// Invert `PT1` values by upcasting to wider `PT2` type with the same scalar.
515/// This is useful for the cases when SIMD invert is faster.
516#[allow(dead_code)]
517pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
518where
519	PT1: PackedField + WithUnderlier,
520	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
521	PT2::Underlier: From<PT1::Underlier>,
522	PT1::Underlier: NumCast<PT2::Underlier>,
523{
524	let bigger_val = PT2::from_underlier(val.to_underlier().into());
525
526	let bigger_result = bigger_val.invert_or_zero();
527
528	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
529}
530
531/// Multiply by alpha `PT1` values by upcasting to wider `PT2` type with the same scalar.
532/// This is useful for the cases when SIMD multiply by alpha is faster.
533#[allow(dead_code)]
534pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
535where
536	PT1: PackedField + WithUnderlier,
537	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
538	PT2::Underlier: From<PT1::Underlier>,
539	PT1::Underlier: NumCast<PT2::Underlier>,
540{
541	let bigger_val = PT2::from_underlier(val.to_underlier().into());
542
543	let bigger_result = bigger_val.mul_alpha();
544
545	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
546}