binius_field/arch/portable/
packed.rs

1// Copyright 2024-2025 Irreducible Inc.
2
// The module-level lint allow below is needed because derive(bytemuck::TransparentWrapper) adds
// some type constraints to PackedPrimitiveType in addition to the type constraints we define.
// Even more annoyingly, the allow attribute has to be added to the module; it doesn't work to add
// it to the struct definition.
7#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10	fmt::Debug,
11	iter::{Product, Sum},
12	marker::PhantomData,
13	ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19	Rng,
20	distr::{Distribution, StandardUniform},
21};
22
23use crate::{
24	BinaryField, Divisible, ExtensionField, Field, PackedField,
25	arithmetic_traits::{InvertOrZero, MulAlpha, Square},
26	field::FieldOps,
27	underlier::{NumCast, UnderlierType, UnderlierWithBitOps, WithUnderlier},
28};
29
30#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
31#[repr(transparent)]
32#[transparent(U)]
33pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
34	pub U,
35	pub PhantomData<Scalar>,
36);
37
38impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
39	pub const WIDTH: usize = {
40		assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
41
42		U::BITS / Scalar::N_BITS
43	};
44
45	pub const LOG_WIDTH: usize = {
46		let result = Self::WIDTH.ilog2();
47
48		assert!(2usize.pow(result) == Self::WIDTH);
49
50		result as usize
51	};
52
53	#[inline]
54	pub const fn from_underlier(val: U) -> Self {
55		Self(val, PhantomData)
56	}
57
58	#[inline]
59	pub const fn to_underlier(self) -> U {
60		self.0
61	}
62}
63
64impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
65	PackedPrimitiveType<U, Scalar>
66{
67	#[inline]
68	pub fn broadcast(scalar: Scalar) -> Self {
69		U::broadcast_subvalue(scalar.to_underlier()).into()
70	}
71}
72
73unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
74	for PackedPrimitiveType<U, Scalar>
75{
76	type Underlier = U;
77
78	#[inline(always)]
79	fn to_underlier(self) -> Self::Underlier {
80		TransparentWrapper::peel(self)
81	}
82
83	#[inline(always)]
84	fn to_underlier_ref(&self) -> &Self::Underlier {
85		TransparentWrapper::peel_ref(self)
86	}
87
88	#[inline(always)]
89	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
90		TransparentWrapper::peel_mut(self)
91	}
92
93	#[inline(always)]
94	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
95		TransparentWrapper::peel_slice(val)
96	}
97
98	#[inline(always)]
99	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
100		TransparentWrapper::peel_slice_mut(val)
101	}
102
103	#[inline(always)]
104	fn from_underlier(val: Self::Underlier) -> Self {
105		TransparentWrapper::wrap(val)
106	}
107
108	#[inline(always)]
109	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
110		TransparentWrapper::wrap_ref(val)
111	}
112
113	#[inline(always)]
114	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
115		TransparentWrapper::wrap_mut(val)
116	}
117
118	#[inline(always)]
119	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
120		TransparentWrapper::wrap_slice(val)
121	}
122
123	#[inline(always)]
124	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
125		TransparentWrapper::wrap_slice_mut(val)
126	}
127}
128
129impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
130where
131	Self: PackedField,
132{
133	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134		let width = checked_int_div(U::BITS, Scalar::N_BITS);
135		let values_str = self
136			.iter()
137			.map(|value| format!("{value}"))
138			.collect::<Vec<_>>()
139			.join(",");
140
141		write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
142	}
143}
144
145impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
146	#[inline]
147	fn from(val: U) -> Self {
148		Self(val, PhantomData)
149	}
150}
151
152impl<U: UnderlierWithBitOps, Scalar: BinaryField> Neg for PackedPrimitiveType<U, Scalar> {
153	type Output = Self;
154
155	#[inline]
156	fn neg(self) -> Self::Output {
157		self
158	}
159}
160
161impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
162	type Output = Self;
163
164	#[inline]
165	#[allow(clippy::suspicious_arithmetic_impl)]
166	fn add(self, rhs: Self) -> Self::Output {
167		(self.0 ^ rhs.0).into()
168	}
169}
170
171impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add<&Self> for PackedPrimitiveType<U, Scalar> {
172	type Output = Self;
173
174	#[inline]
175	#[allow(clippy::suspicious_arithmetic_impl)]
176	fn add(self, rhs: &Self) -> Self::Output {
177		(self.0 ^ rhs.0).into()
178	}
179}
180
181impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
182	type Output = Self;
183
184	#[inline]
185	#[allow(clippy::suspicious_arithmetic_impl)]
186	fn sub(self, rhs: Self) -> Self::Output {
187		(self.0 ^ rhs.0).into()
188	}
189}
190
191impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub<&Self> for PackedPrimitiveType<U, Scalar> {
192	type Output = Self;
193
194	#[inline]
195	#[allow(clippy::suspicious_arithmetic_impl)]
196	fn sub(self, rhs: &Self) -> Self::Output {
197		(self.0 ^ rhs.0).into()
198	}
199}
200
201impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
202where
203	Self: Add<Output = Self>,
204{
205	fn add_assign(&mut self, rhs: Self) {
206		*self = *self + rhs;
207	}
208}
209
210impl<U: UnderlierType, Scalar: BinaryField> AddAssign<&Self> for PackedPrimitiveType<U, Scalar>
211where
212	Self: for<'a> Add<&'a Self, Output = Self>,
213{
214	fn add_assign(&mut self, rhs: &Self) {
215		*self = *self + rhs;
216	}
217}
218
219impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
220where
221	Self: Sub<Output = Self>,
222{
223	fn sub_assign(&mut self, rhs: Self) {
224		*self = *self - rhs;
225	}
226}
227
228impl<U: UnderlierType, Scalar: BinaryField> SubAssign<&Self> for PackedPrimitiveType<U, Scalar>
229where
230	Self: for<'a> Sub<&'a Self, Output = Self>,
231{
232	fn sub_assign(&mut self, rhs: &Self) {
233		*self = *self - rhs;
234	}
235}
236
237impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
238where
239	Self: Mul<Output = Self>,
240{
241	fn mul_assign(&mut self, rhs: Self) {
242		*self = *self * rhs;
243	}
244}
245
246impl<U: UnderlierType, Scalar: BinaryField> MulAssign<&Self> for PackedPrimitiveType<U, Scalar>
247where
248	Self: for<'a> Mul<&'a Self, Output = Self>,
249{
250	fn mul_assign(&mut self, rhs: &Self) {
251		*self = *self * rhs;
252	}
253}
254
255impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
256	for PackedPrimitiveType<U, Scalar>
257{
258	type Output = Self;
259
260	fn add(self, rhs: Scalar) -> Self::Output {
261		self + Self::broadcast(rhs)
262	}
263}
264
265impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
266	for PackedPrimitiveType<U, Scalar>
267{
268	type Output = Self;
269
270	fn sub(self, rhs: Scalar) -> Self::Output {
271		self - Self::broadcast(rhs)
272	}
273}
274
275impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
276	for PackedPrimitiveType<U, Scalar>
277where
278	Self: Mul<Output = Self>,
279{
280	type Output = Self;
281
282	fn mul(self, rhs: Scalar) -> Self::Output {
283		self * Self::broadcast(rhs)
284	}
285}
286
287impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
288	for PackedPrimitiveType<U, Scalar>
289{
290	fn add_assign(&mut self, rhs: Scalar) {
291		*self += Self::broadcast(rhs);
292	}
293}
294
295impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
296	for PackedPrimitiveType<U, Scalar>
297{
298	fn sub_assign(&mut self, rhs: Scalar) {
299		*self -= Self::broadcast(rhs);
300	}
301}
302
303impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
304	for PackedPrimitiveType<U, Scalar>
305where
306	Self: MulAssign<Self>,
307{
308	fn mul_assign(&mut self, rhs: Scalar) {
309		*self *= Self::broadcast(rhs);
310	}
311}
312
313impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
314where
315	Self: Add<Output = Self>,
316{
317	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
318		iter.fold(Self::from(U::default()), |result, next| result + next)
319	}
320}
321
322impl<'a, U: UnderlierType, Scalar: BinaryField> Sum<&'a Self> for PackedPrimitiveType<U, Scalar>
323where
324	Self: Add<Output = Self>,
325{
326	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
327		iter.fold(Self::from(U::default()), |result, next| result + *next)
328	}
329}
330
331impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
332	for PackedPrimitiveType<U, Scalar>
333where
334	Self: Mul<Output = Self>,
335{
336	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
337		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
338	}
339}
340
341impl<'a, U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
342	Product<&'a Self> for PackedPrimitiveType<U, Scalar>
343where
344	Self: Mul<Output = Self>,
345{
346	fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
347		iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * *next)
348	}
349}
350
351impl<U: UnderlierType, Scalar: BinaryField> Mul<&Self> for PackedPrimitiveType<U, Scalar>
352where
353	Self: Mul<Output = Self>,
354{
355	type Output = Self;
356
357	#[inline]
358	fn mul(self, rhs: &Self) -> Self::Output {
359		self * *rhs
360	}
361}
362
363unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
364	for PackedPrimitiveType<U, Scalar>
365{
366}
367
368unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
369
370impl<U, Scalar> FieldOps for PackedPrimitiveType<U, Scalar>
371where
372	Self: Square + InvertOrZero + Mul<Output = Self>,
373	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
374	Scalar: BinaryField,
375{
376	type Scalar = Scalar;
377
378	#[inline]
379	fn zero() -> Self {
380		Self::from_underlier(U::ZERO)
381	}
382
383	#[inline]
384	fn one() -> Self {
385		Self::broadcast(Scalar::ONE)
386	}
387
388	fn square_transpose<FSub: Field>(elems: &mut [Self])
389	where
390		Scalar: ExtensionField<FSub>,
391	{
392		let log_degree = <Scalar as ExtensionField<FSub>>::LOG_DEGREE;
393		let degree = <Scalar as ExtensionField<FSub>>::DEGREE;
394		assert_eq!(elems.len(), degree);
395
396		let log_sub_bits = Scalar::N_BITS.ilog2() as usize - log_degree;
397
398		// See Hacker's Delight, Section 7-3.
399		for i in 0..log_degree {
400			for j in 0..1 << (log_degree - i - 1) {
401				for k in 0..1 << i {
402					let idx0 = (j << (i + 1)) | k;
403					let idx1 = idx0 | (1 << i);
404					let (u0, u1) = elems[idx0].0.interleave(elems[idx1].0, i + log_sub_bits);
405					elems[idx0] = u0.into();
406					elems[idx1] = u1.into();
407				}
408			}
409		}
410	}
411}
412
413impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
414where
415	Self: Square + InvertOrZero + Mul<Output = Self>,
416	U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
417	Scalar: BinaryField,
418{
419	const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;
420
421	#[inline]
422	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
423		Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
424	}
425
426	#[inline]
427	unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
428		unsafe {
429			self.0.set_subvalue(i, scalar.to_underlier());
430		}
431	}
432
433	#[inline]
434	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
435		Divisible::<Scalar::Underlier>::ref_iter(&self.0)
436			.map(|underlier| Scalar::from_underlier(underlier))
437	}
438
439	#[inline]
440	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
441		Divisible::<Scalar::Underlier>::value_iter(self.0)
442			.map(|underlier| Scalar::from_underlier(underlier))
443	}
444
445	#[inline]
446	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
447		Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
448			.map_skippable(|underlier| Scalar::from_underlier(underlier))
449	}
450
451	#[inline]
452	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
453		assert!(log_block_len < Self::LOG_WIDTH);
454		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
455		let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
456		(c.into(), d.into())
457	}
458
459	#[inline]
460	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
461		assert!(log_block_len < Self::LOG_WIDTH);
462		let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
463		let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
464		(c.into(), d.into())
465	}
466
467	#[inline]
468	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
469		debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
470		debug_assert!(
471			block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
472			"{} < {}",
473			block_idx,
474			1 << (Self::LOG_WIDTH - log_block_len)
475		);
476
477		unsafe {
478			self.0
479				.spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
480				.into()
481		}
482	}
483
484	#[inline]
485	fn broadcast(scalar: Self::Scalar) -> Self {
486		Self::broadcast(scalar)
487	}
488
489	#[inline]
490	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
491		U::from_fn(move |i| f(i).to_underlier()).into()
492	}
493}
494
495impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
496	for StandardUniform
497{
498	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
499		PackedPrimitiveType::from_underlier(U::random(rng))
500	}
501}
502
503/// Multiply `PT1` values by upcasting to wider `PT2` type with the same scalar.
504/// This is useful for the cases when SIMD multiplication is faster.
505#[allow(dead_code)]
506pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
507where
508	PT1: PackedField + WithUnderlier,
509	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
510	PT2::Underlier: From<PT1::Underlier>,
511	PT1::Underlier: NumCast<PT2::Underlier>,
512{
513	let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
514	let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
515
516	let bigger_result = bigger_lhs * bigger_rhs;
517
518	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
519}
520
521/// Square `PT1` values by upcasting to wider `PT2` type with the same scalar.
522/// This is useful for the cases when SIMD square is faster.
523#[allow(dead_code)]
524pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
525where
526	PT1: PackedField + WithUnderlier,
527	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
528	PT2::Underlier: From<PT1::Underlier>,
529	PT1::Underlier: NumCast<PT2::Underlier>,
530{
531	let bigger_val = PT2::from_underlier(val.to_underlier().into());
532
533	let bigger_result = bigger_val.square();
534
535	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
536}
537
538/// Invert `PT1` values by upcasting to wider `PT2` type with the same scalar.
539/// This is useful for the cases when SIMD invert is faster.
540#[allow(dead_code)]
541pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
542where
543	PT1: PackedField + WithUnderlier,
544	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
545	PT2::Underlier: From<PT1::Underlier>,
546	PT1::Underlier: NumCast<PT2::Underlier>,
547{
548	let bigger_val = PT2::from_underlier(val.to_underlier().into());
549
550	let bigger_result = bigger_val.invert_or_zero();
551
552	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
553}
554
555/// Multiply by alpha `PT1` values by upcasting to wider `PT2` type with the same scalar.
556/// This is useful for the cases when SIMD multiply by alpha is faster.
557#[allow(dead_code)]
558pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
559where
560	PT1: PackedField + WithUnderlier,
561	PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
562	PT2::Underlier: From<PT1::Underlier>,
563	PT1::Underlier: NumCast<PT2::Underlier>,
564{
565	let bigger_val = PT2::from_underlier(val.to_underlier().into());
566
567	let bigger_result = bigger_val.mul_alpha();
568
569	PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
570}