binius_field/arch/portable/packed.rs

// Copyright 2024-2025 Irreducible Inc.

// The module-level `allow` below is needed because derive(bytemuck::TransparentWrapper) adds
// type constraints to PackedPrimitiveType on top of the constraints we define ourselves. Even
// more annoyingly, the allow attribute has to be added to the module; adding it to the struct
// definition does not work.
#![allow(clippy::multiple_bound_locations)]

use std::{
	fmt::Debug,
	iter::{Product, Sum},
	marker::PhantomData,
	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
use bytemuck::{Pod, TransparentWrapper, Zeroable};
use rand::RngCore;
use subtle::{Choice, ConstantTimeEq};

use super::packed_arithmetic::UnderlierWithBitConstants;
use crate::{
	arithmetic_traits::{Broadcast, InvertOrZero, MulAlpha, Square},
	underlier::{
		IterationMethods, IterationStrategy, NumCast, UnderlierType, UnderlierWithBitOps,
		WithUnderlier, U1, U2, U4,
	},
	BinaryField, PackedField,
};

#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	pub U,
	pub PhantomData<Scalar>,
);

impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
	pub const WIDTH: usize = {
		assert!(U::BITS % Scalar::N_BITS == 0);

		U::BITS / Scalar::N_BITS
	};

	pub const LOG_WIDTH: usize = {
		let result = Self::WIDTH.ilog2();

		assert!(2usize.pow(result) == Self::WIDTH);

		result as usize
	};
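
	// As a worked example (purely illustrative; the concrete widths are an assumption):
	// a 128-bit underlier packing 32-bit scalars gives WIDTH = 128 / 32 = 4 and
	// LOG_WIDTH = 2. The compile-time asserts above guarantee both that the scalar width
	// divides the underlier width and that the resulting width is a power of two.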

	#[inline]
	pub const fn from_underlier(val: U) -> Self {
		Self(val, PhantomData)
	}

	#[inline]
	pub const fn to_underlier(self) -> U {
		self.0
	}
}

unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
	for PackedPrimitiveType<U, Scalar>
{
	type Underlier = U;

	#[inline(always)]
	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	#[inline(always)]
	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	#[inline(always)]
	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	#[inline(always)]
	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	#[inline(always)]
	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	#[inline(always)]
	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	#[inline(always)]
	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	#[inline(always)]
	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	#[inline(always)]
	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	#[inline(always)]
	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}

impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
where
	Self: PackedField,
{
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let width = checked_int_div(U::BITS, Scalar::N_BITS);
		let values_str = self
			.iter()
			.map(|value| format!("{}", value))
			.collect::<Vec<_>>()
			.join(",");

		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
	}
}

impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
	#[inline]
	fn from(val: U) -> Self {
		Self(val, PhantomData)
	}
}

impl<U: UnderlierType, Scalar: BinaryField> ConstantTimeEq for PackedPrimitiveType<U, Scalar> {
	fn ct_eq(&self, other: &Self) -> Choice {
		self.0.ct_eq(&other.0)
	}
}

impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
	type Output = Self;

	#[inline]
	#[allow(clippy::suspicious_arithmetic_impl)]
	fn add(self, rhs: Self) -> Self::Output {
		(self.0 ^ rhs.0).into()
	}
}

impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
	type Output = Self;

	#[inline]
	#[allow(clippy::suspicious_arithmetic_impl)]
	fn sub(self, rhs: Self) -> Self::Output {
		(self.0 ^ rhs.0).into()
	}
}
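
// Note: binary fields have characteristic 2, so every element is its own additive inverse.
// Addition and subtraction therefore coincide, and both reduce to a bitwise XOR of the
// packed underliers, which is why the `Add` and `Sub` impls above share the same body.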

impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
where
	Self: Add<Output = Self>,
{
	fn add_assign(&mut self, rhs: Self) {
		*self = *self + rhs;
	}
}

impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
where
	Self: Sub<Output = Self>,
{
	fn sub_assign(&mut self, rhs: Self) {
		*self = *self - rhs;
	}
}

impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
where
	Self: Mul<Output = Self>,
{
	fn mul_assign(&mut self, rhs: Self) {
		*self = *self * rhs;
	}
}

impl<U: UnderlierType, Scalar: BinaryField> Add<Scalar> for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + Add<Output = Self>,
{
	type Output = Self;

	fn add(self, rhs: Scalar) -> Self::Output {
		self + Self::broadcast(rhs)
	}
}

impl<U: UnderlierType, Scalar: BinaryField> Sub<Scalar> for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + Sub<Output = Self>,
{
	type Output = Self;

	fn sub(self, rhs: Scalar) -> Self::Output {
		self - Self::broadcast(rhs)
	}
}

impl<U: UnderlierType, Scalar: BinaryField> Mul<Scalar> for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + Mul<Output = Self>,
{
	type Output = Self;

	fn mul(self, rhs: Scalar) -> Self::Output {
		self * Self::broadcast(rhs)
	}
}

impl<U: UnderlierType, Scalar: BinaryField> AddAssign<Scalar> for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + AddAssign<Self>,
{
	fn add_assign(&mut self, rhs: Scalar) {
		*self += Self::broadcast(rhs);
	}
}

impl<U: UnderlierType, Scalar: BinaryField> SubAssign<Scalar> for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + SubAssign<Self>,
{
	fn sub_assign(&mut self, rhs: Scalar) {
		*self -= Self::broadcast(rhs);
	}
}

impl<U: UnderlierType, Scalar: BinaryField> MulAssign<Scalar> for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + MulAssign<Self>,
{
	fn mul_assign(&mut self, rhs: Scalar) {
		*self *= Self::broadcast(rhs);
	}
}

impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
where
	Self: Add<Output = Self>,
{
	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
		iter.fold(Self::from(U::default()), |result, next| result + next)
	}
}

impl<U: UnderlierType, Scalar: BinaryField> Product for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + Mul<Output = Self>,
{
	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
	}
}

unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
	for PackedPrimitiveType<U, Scalar>
{
}

unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}

impl<U: UnderlierWithBitOps, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
where
	Self: Broadcast<Scalar> + Square + InvertOrZero + Mul<Output = Self>,
	U: UnderlierWithBitConstants + From<Scalar::Underlier> + Send + Sync + 'static,
	Scalar: BinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
	Scalar::Underlier: NumCast<U>,
	IterationMethods<Scalar::Underlier, U>: IterationStrategy<Scalar::Underlier, U>,
{
	type Scalar = Scalar;

	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;

	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		Scalar::from_underlier(self.0.get_subvalue(i))
	}

	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
		self.0.set_subvalue(i, scalar.to_underlier());
	}

	#[inline]
	fn zero() -> Self {
		Self::from_underlier(U::ZERO)
	}

	fn random(rng: impl RngCore) -> Self {
		U::random(rng).into()
	}

	#[inline]
	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		IterationMethods::<Scalar::Underlier, U>::ref_iter(&self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	#[inline]
	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
		IterationMethods::<Scalar::Underlier, U>::value_iter(self.0)
			.map(|underlier| Scalar::from_underlier(underlier))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		IterationMethods::<Scalar::Underlier, U>::slice_iter(Self::to_underliers_ref(slice))
			.map_skippable(|underlier| Scalar::from_underlier(underlier))
	}

	#[inline]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	#[inline]
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		assert!(log_block_len < Self::LOG_WIDTH);
		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
		(c.into(), d.into())
	}

	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
		debug_assert!(
			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
			"{} < {}",
			block_idx,
			1 << (Self::LOG_WIDTH - log_block_len)
		);

		self.0
			.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
			.into()
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		<Self as Broadcast<Self::Scalar>>::broadcast(scalar)
	}

	#[inline]
	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		U::from_fn(move |i| f(i).to_underlier()).into()
	}

	#[inline]
	fn square(self) -> Self {
		<Self as Square>::square(self)
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		<Self as InvertOrZero>::invert_or_zero(self)
	}
}

macro_rules! impl_broadcast {
	($name:ty, BinaryField1b) => {
		impl $crate::arithmetic_traits::Broadcast<BinaryField1b>
			for PackedPrimitiveType<$name, BinaryField1b>
		{
			#[inline]
			fn broadcast(scalar: BinaryField1b) -> Self {
				use $crate::underlier::{UnderlierWithBitOps, WithUnderlier};

				<Self as WithUnderlier>::Underlier::fill_with_bit(scalar.0.into()).into()
			}
		}
	};
	($name:ty, $scalar_type:ty) => {
		impl $crate::arithmetic_traits::Broadcast<$scalar_type>
			for PackedPrimitiveType<$name, $scalar_type>
		{
			#[inline]
			fn broadcast(scalar: $scalar_type) -> Self {
				let mut value = <$name>::from(scalar.0);
				// For PackedBinaryField1x128b the scalar and underlier log bit widths are both 7,
				// so this is an empty range. That is safe behavior.
				#[allow(clippy::reversed_empty_ranges)]
				for i in <$scalar_type as $crate::binary_field::BinaryField>::N_BITS.ilog2()
					..<$name>::BITS.ilog2()
				{
					value = value << (1 << i) | value;
				}

				value.into()
			}
		}
	};
}
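
// A worked sketch of the shift-and-or doubling in the generic arm above (the concrete
// widths here are illustrative assumptions, not types defined in this file): broadcasting
// a hypothetical 8-bit scalar 0xAB into a 32-bit underlier runs the loop for i in 3..5:
//   i = 3: value = (value << 8)  | value  ->  0x0000_ABAB
//   i = 4: value = (value << 16) | value  ->  0xABAB_ABAB
// Each iteration doubles the populated prefix until the scalar fills the whole underlier.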

pub(crate) use impl_broadcast;

macro_rules! impl_ops_for_zero_height {
	($name:ty) => {
		impl std::ops::Mul for $name {
			type Output = Self;

			#[allow(clippy::suspicious_arithmetic_impl)]
			#[inline]
			fn mul(self, b: Self) -> Self {
				crate::tracing::trace_multiplication!($name);

				(self.to_underlier() & b.to_underlier()).into()
			}
		}

		impl $crate::arithmetic_traits::MulAlpha for $name {
			#[inline]
			fn mul_alpha(self) -> Self {
				self
			}
		}

		impl $crate::arithmetic_traits::Square for $name {
			#[inline]
			fn square(self) -> Self {
				self
			}
		}

		impl $crate::arithmetic_traits::InvertOrZero for $name {
			#[inline]
			fn invert_or_zero(self) -> Self {
				self
			}
		}
	};
}

pub(crate) use impl_ops_for_zero_height;
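
// The shortcuts in `impl_ops_for_zero_height` are specific to the 1-bit field: its only
// elements are 0 and 1, so multiplication of the packed bits is a bitwise AND, while
// squaring, inversion (with 0 mapped to 0), and multiplication by alpha all reduce to the
// identity map, as the impls above spell out.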

/// Multiply `PT1` values by upcasting to the wider `PT2` type with the same scalar.
/// This is useful in cases where the SIMD multiplication of the wider type is faster.
#[allow(dead_code)]
pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
where
	PT1: PackedField + WithUnderlier,
	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
	PT2::Underlier: From<PT1::Underlier>,
	PT1::Underlier: NumCast<PT2::Underlier>,
{
	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());

	let bigger_result = bigger_lhs * bigger_rhs;

	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
}
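
// A minimal usage sketch (the type names `Narrow` and `Wide` are hypothetical stand-ins,
// not types defined here): given two packed types over the same scalar whose underliers
// satisfy the conversion bounds above, a backend could forward the portable multiplication
// to the wider, SIMD-friendly type:
//
//     fn mul_via_wide(lhs: Narrow, rhs: Narrow) -> Narrow {
//         mul_as_bigger_type::<Narrow, Wide>(lhs, rhs)
//     }
//
// The square, invert, and mul-alpha helpers below follow the same widen-compute-truncate
// pattern.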

/// Square `PT1` values by upcasting to the wider `PT2` type with the same scalar.
/// This is useful in cases where the SIMD squaring of the wider type is faster.
#[allow(dead_code)]
pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
where
	PT1: PackedField + WithUnderlier,
	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
	PT2::Underlier: From<PT1::Underlier>,
	PT1::Underlier: NumCast<PT2::Underlier>,
{
	let bigger_val = PT2::from_underlier(val.to_underlier().into());

	let bigger_result = bigger_val.square();

	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
}

/// Invert `PT1` values by upcasting to the wider `PT2` type with the same scalar.
/// This is useful in cases where the SIMD inversion of the wider type is faster.
#[allow(dead_code)]
pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
where
	PT1: PackedField + WithUnderlier,
	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
	PT2::Underlier: From<PT1::Underlier>,
	PT1::Underlier: NumCast<PT2::Underlier>,
{
	let bigger_val = PT2::from_underlier(val.to_underlier().into());

	let bigger_result = bigger_val.invert_or_zero();

	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
}

/// Multiply `PT1` values by alpha by upcasting to the wider `PT2` type with the same scalar.
/// This is useful in cases where the SIMD multiplication by alpha of the wider type is faster.
#[allow(dead_code)]
pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
where
	PT1: PackedField + WithUnderlier,
	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
	PT2::Underlier: From<PT1::Underlier>,
	PT1::Underlier: NumCast<PT2::Underlier>,
{
	let bigger_val = PT2::from_underlier(val.to_underlier().into());

	let bigger_result = bigger_val.mul_alpha();

	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
}

macro_rules! impl_pack_scalar {
	($underlier:ty) => {
		impl<F> $crate::as_packed_field::PackScalar<F> for $underlier
		where
			F: BinaryField,
			PackedPrimitiveType<$underlier, F>:
				$crate::packed::PackedField<Scalar = F> + WithUnderlier<Underlier = $underlier>,
		{
			type Packed = PackedPrimitiveType<$underlier, F>;
		}
	};
}

pub(crate) use impl_pack_scalar;

impl_pack_scalar!(U1);
impl_pack_scalar!(U2);
impl_pack_scalar!(U4);
impl_pack_scalar!(u8);
impl_pack_scalar!(u16);
impl_pack_scalar!(u32);
impl_pack_scalar!(u64);
impl_pack_scalar!(u128);