1#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10 fmt::Debug,
11 iter::{Product, Sum},
12 marker::PhantomData,
13 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19 Rng,
20 distr::{Distribution, StandardUniform},
21};
22
23use super::packed_arithmetic::UnderlierWithBitConstants;
24use crate::{
25 BinaryField, PackedField,
26 arithmetic_traits::{Broadcast, InvertOrZero, MulAlpha, Square},
27 underlier::{
28 IterationMethods, IterationStrategy, NumCast, U1, U2, U4, UnderlierType,
29 UnderlierWithBitOps, WithUnderlier,
30 },
31};
32
/// A vector of `Scalar` binary-field elements packed side by side into the bits of a single
/// underlier value of type `U`.
///
/// The number of lanes is `U::BITS / Scalar::N_BITS` (see the inherent `WIDTH` constant).
/// The struct is `#[repr(transparent)]` over `U`; the second field is a zero-sized
/// `PhantomData`. This layout guarantee is what the derived `bytemuck::TransparentWrapper<U>`
/// relies on to reinterpret values, references and slices of this type as `U`.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
    // Raw bits holding every lane.
    pub U,
    // Type-level marker for the scalar field; carries no data.
    pub PhantomData<Scalar>,
);
40
41impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
42 pub const WIDTH: usize = {
43 assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
44
45 U::BITS / Scalar::N_BITS
46 };
47
48 pub const LOG_WIDTH: usize = {
49 let result = Self::WIDTH.ilog2();
50
51 assert!(2usize.pow(result) == Self::WIDTH);
52
53 result as usize
54 };
55
56 #[inline]
57 pub const fn from_underlier(val: U) -> Self {
58 Self(val, PhantomData)
59 }
60
61 #[inline]
62 pub const fn to_underlier(self) -> U {
63 self.0
64 }
65}
66
67unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
68 for PackedPrimitiveType<U, Scalar>
69{
70 type Underlier = U;
71
72 #[inline(always)]
73 fn to_underlier(self) -> Self::Underlier {
74 TransparentWrapper::peel(self)
75 }
76
77 #[inline(always)]
78 fn to_underlier_ref(&self) -> &Self::Underlier {
79 TransparentWrapper::peel_ref(self)
80 }
81
82 #[inline(always)]
83 fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
84 TransparentWrapper::peel_mut(self)
85 }
86
87 #[inline(always)]
88 fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
89 TransparentWrapper::peel_slice(val)
90 }
91
92 #[inline(always)]
93 fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
94 TransparentWrapper::peel_slice_mut(val)
95 }
96
97 #[inline(always)]
98 fn from_underlier(val: Self::Underlier) -> Self {
99 TransparentWrapper::wrap(val)
100 }
101
102 #[inline(always)]
103 fn from_underlier_ref(val: &Self::Underlier) -> &Self {
104 TransparentWrapper::wrap_ref(val)
105 }
106
107 #[inline(always)]
108 fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
109 TransparentWrapper::wrap_mut(val)
110 }
111
112 #[inline(always)]
113 fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
114 TransparentWrapper::wrap_slice(val)
115 }
116
117 #[inline(always)]
118 fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
119 TransparentWrapper::wrap_slice_mut(val)
120 }
121}
122
123impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
124where
125 Self: PackedField,
126{
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 let width = checked_int_div(U::BITS, Scalar::N_BITS);
129 let values_str = self
130 .iter()
131 .map(|value| format!("{value}"))
132 .collect::<Vec<_>>()
133 .join(",");
134
135 write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
136 }
137}
138
139impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
140 #[inline]
141 fn from(val: U) -> Self {
142 Self(val, PhantomData)
143 }
144}
145
146impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
147 type Output = Self;
148
149 #[inline]
150 #[allow(clippy::suspicious_arithmetic_impl)]
151 fn add(self, rhs: Self) -> Self::Output {
152 (self.0 ^ rhs.0).into()
153 }
154}
155
156impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
157 type Output = Self;
158
159 #[inline]
160 #[allow(clippy::suspicious_arithmetic_impl)]
161 fn sub(self, rhs: Self) -> Self::Output {
162 (self.0 ^ rhs.0).into()
163 }
164}
165
166impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
167where
168 Self: Add<Output = Self>,
169{
170 fn add_assign(&mut self, rhs: Self) {
171 *self = *self + rhs;
172 }
173}
174
175impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
176where
177 Self: Sub<Output = Self>,
178{
179 fn sub_assign(&mut self, rhs: Self) {
180 *self = *self - rhs;
181 }
182}
183
184impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
185where
186 Self: Mul<Output = Self>,
187{
188 fn mul_assign(&mut self, rhs: Self) {
189 *self = *self * rhs;
190 }
191}
192
193impl<U: UnderlierType, Scalar: BinaryField> Add<Scalar> for PackedPrimitiveType<U, Scalar>
194where
195 Self: Broadcast<Scalar> + Add<Output = Self>,
196{
197 type Output = Self;
198
199 fn add(self, rhs: Scalar) -> Self::Output {
200 self + Self::broadcast(rhs)
201 }
202}
203
204impl<U: UnderlierType, Scalar: BinaryField> Sub<Scalar> for PackedPrimitiveType<U, Scalar>
205where
206 Self: Broadcast<Scalar> + Sub<Output = Self>,
207{
208 type Output = Self;
209
210 fn sub(self, rhs: Scalar) -> Self::Output {
211 self - Self::broadcast(rhs)
212 }
213}
214
215impl<U: UnderlierType, Scalar: BinaryField> Mul<Scalar> for PackedPrimitiveType<U, Scalar>
216where
217 Self: Broadcast<Scalar> + Mul<Output = Self>,
218{
219 type Output = Self;
220
221 fn mul(self, rhs: Scalar) -> Self::Output {
222 self * Self::broadcast(rhs)
223 }
224}
225
226impl<U: UnderlierType, Scalar: BinaryField> AddAssign<Scalar> for PackedPrimitiveType<U, Scalar>
227where
228 Self: Broadcast<Scalar> + AddAssign<Self>,
229{
230 fn add_assign(&mut self, rhs: Scalar) {
231 *self += Self::broadcast(rhs);
232 }
233}
234
235impl<U: UnderlierType, Scalar: BinaryField> SubAssign<Scalar> for PackedPrimitiveType<U, Scalar>
236where
237 Self: Broadcast<Scalar> + SubAssign<Self>,
238{
239 fn sub_assign(&mut self, rhs: Scalar) {
240 *self -= Self::broadcast(rhs);
241 }
242}
243
244impl<U: UnderlierType, Scalar: BinaryField> MulAssign<Scalar> for PackedPrimitiveType<U, Scalar>
245where
246 Self: Broadcast<Scalar> + MulAssign<Self>,
247{
248 fn mul_assign(&mut self, rhs: Scalar) {
249 *self *= Self::broadcast(rhs);
250 }
251}
252
253impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
254where
255 Self: Add<Output = Self>,
256{
257 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
258 iter.fold(Self::from(U::default()), |result, next| result + next)
259 }
260}
261
262impl<U: UnderlierType, Scalar: BinaryField> Product for PackedPrimitiveType<U, Scalar>
263where
264 Self: Broadcast<Scalar> + Mul<Output = Self>,
265{
266 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
267 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
268 }
269}
270
// SAFETY: `PackedPrimitiveType` is `#[repr(transparent)]` over `U`, and its only other field is
// a zero-sized `PhantomData<Scalar>`. An all-zero bit pattern is therefore valid for it exactly
// when it is valid for `U`, which the `U: Zeroable` bound guarantees.
unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
    for PackedPrimitiveType<U, Scalar>
{
}
275
// SAFETY: the type is `#[repr(transparent)]` over `U`, with only a zero-sized `PhantomData`
// besides it. It therefore has exactly `U`'s layout, adds no padding, and accepts every bit
// pattern that `U` accepts (guaranteed by `U: Pod`). The other `Pod` supertraits come from the
// `Zeroable` impl and the `Copy` derive on the struct.
unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
277
/// `PackedField` implementation backed directly by a primitive underlier.
///
/// - Lane reads and writes go through the underlier's `get_subvalue` / `set_subvalue`.
/// - Element iteration is delegated to `IterationMethods`.
/// - Arithmetic that is not plain XOR (broadcast, square, inversion) is forwarded to the
///   `Broadcast` / `Square` / `InvertOrZero` impls required by the `where` clause.
impl<U: UnderlierWithBitOps, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
where
    Self: Broadcast<Scalar> + Square + InvertOrZero + Mul<Output = Self>,
    U: UnderlierWithBitConstants + From<Scalar::Underlier> + Send + Sync + 'static,
    Scalar: BinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
    Scalar::Underlier: NumCast<U>,
    IterationMethods<Scalar::Underlier, U>: IterationStrategy<Scalar::Underlier, U>,
{
    type Scalar = Scalar;

    // log2 of the lane count `U::BITS / Scalar::N_BITS`.
    // NOTE(review): unlike the inherent `LOG_WIDTH`, this neither asserts that `Scalar::N_BITS`
    // divides `U::BITS` nor that the lane count is a power of two; `ilog2` would silently round
    // a non-power-of-two width down. Consider delegating to the checked inherent constant.
    const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;

    /// Reads lane `i` straight out of the underlier bits.
    ///
    /// # Safety
    /// `i` is not bounds-checked here; presumably the caller must guarantee `i` is a valid lane
    /// index (`i < WIDTH`), per the `PackedField::get_unchecked` contract — confirm.
    #[inline]
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
        // SAFETY: the lane-index precondition is forwarded unchanged from this fn's caller.
        Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
    }

    /// Overwrites lane `i` with `scalar`'s underlier bits.
    ///
    /// # Safety
    /// Same unchecked lane-index precondition as `get_unchecked`.
    #[inline]
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
        // SAFETY: the lane-index precondition is forwarded unchanged from this fn's caller.
        unsafe {
            self.0.set_subvalue(i, scalar.to_underlier());
        }
    }

    /// The packed value whose underlier is `U::ZERO`.
    #[inline]
    fn zero() -> Self {
        Self::from_underlier(U::ZERO)
    }

    /// Iterates over the lanes of `self` by reference.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        IterationMethods::<Scalar::Underlier, U>::ref_iter(&self.0)
            .map(|underlier| Scalar::from_underlier(underlier))
    }

    /// Iterates over the lanes of `self` by value.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        IterationMethods::<Scalar::Underlier, U>::value_iter(self.0)
            .map(|underlier| Scalar::from_underlier(underlier))
    }

    /// Iterates over every lane of every packed element in `slice`, in order.
    ///
    /// The slice is reinterpreted as raw underliers (via `WithUnderlier`) before iterating.
    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        IterationMethods::<Scalar::Underlier, U>::slice_iter(Self::to_underliers_ref(slice))
            .map_skippable(|underlier| Scalar::from_underlier(underlier))
    }

    /// Interleaves blocks of `2^log_block_len` lanes between `self` and `other`.
    ///
    /// The block size is converted from lanes to bits by adding `log2(Scalar::N_BITS)` and then
    /// handed to the underlier's bit-level `interleave`.
    ///
    /// # Panics
    /// If `log_block_len >= LOG_WIDTH`.
    #[inline]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        assert!(log_block_len < Self::LOG_WIDTH);
        let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
        let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
        (c.into(), d.into())
    }

    /// Unzips `self` and `other` at a granularity of `2^log_block_len` lanes.
    ///
    /// Like `interleave`, the block size is converted to bits; the work is then done by the
    /// underlier's `transpose`.
    ///
    /// # Panics
    /// If `log_block_len >= LOG_WIDTH`.
    #[inline]
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
        assert!(log_block_len < Self::LOG_WIDTH);
        let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
        let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
        (c.into(), d.into())
    }

    /// Spreads block `block_idx` (of `2^log_block_len` lanes) across the whole packed value,
    /// delegating to the underlier's `spread` at scalar-underlier granularity.
    ///
    /// # Safety
    /// The caller must ensure `log_block_len <= LOG_WIDTH` and
    /// `block_idx < 2^(LOG_WIDTH - log_block_len)`. These are only checked in debug builds.
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
        debug_assert!(
            block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
            "{} < {}",
            block_idx,
            1 << (Self::LOG_WIDTH - log_block_len)
        );

        // SAFETY: the block-range preconditions documented above are the caller's obligation.
        unsafe {
            self.0
                .spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
                .into()
        }
    }

    /// Copies `scalar` into every lane (forwards to the `Broadcast` impl).
    #[inline]
    fn broadcast(scalar: Self::Scalar) -> Self {
        <Self as Broadcast<Self::Scalar>>::broadcast(scalar)
    }

    /// Builds a packed value whose lane `i` is `f(i)`.
    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        U::from_fn(move |i| f(i).to_underlier()).into()
    }

    /// Lane-wise square (forwards to the `Square` impl).
    #[inline]
    fn square(self) -> Self {
        <Self as Square>::square(self)
    }

    /// Lane-wise inverse, mapping zero lanes to zero (forwards to the `InvertOrZero` impl).
    #[inline]
    fn invert_or_zero(self) -> Self {
        <Self as InvertOrZero>::invert_or_zero(self)
    }
}
378
379impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
380 for StandardUniform
381{
382 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
383 PackedPrimitiveType::from_underlier(U::random(rng))
384 }
385}
386
387#[allow(dead_code)]
390pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
391where
392 PT1: PackedField + WithUnderlier,
393 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
394 PT2::Underlier: From<PT1::Underlier>,
395 PT1::Underlier: NumCast<PT2::Underlier>,
396{
397 let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
398 let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
399
400 let bigger_result = bigger_lhs * bigger_rhs;
401
402 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
403}
404
405#[allow(dead_code)]
408pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
409where
410 PT1: PackedField + WithUnderlier,
411 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
412 PT2::Underlier: From<PT1::Underlier>,
413 PT1::Underlier: NumCast<PT2::Underlier>,
414{
415 let bigger_val = PT2::from_underlier(val.to_underlier().into());
416
417 let bigger_result = bigger_val.square();
418
419 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
420}
421
422#[allow(dead_code)]
425pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
426where
427 PT1: PackedField + WithUnderlier,
428 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
429 PT2::Underlier: From<PT1::Underlier>,
430 PT1::Underlier: NumCast<PT2::Underlier>,
431{
432 let bigger_val = PT2::from_underlier(val.to_underlier().into());
433
434 let bigger_result = bigger_val.invert_or_zero();
435
436 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
437}
438
439#[allow(dead_code)]
442pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
443where
444 PT1: PackedField + WithUnderlier,
445 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
446 PT2::Underlier: From<PT1::Underlier>,
447 PT1::Underlier: NumCast<PT2::Underlier>,
448{
449 let bigger_val = PT2::from_underlier(val.to_underlier().into());
450
451 let bigger_result = bigger_val.mul_alpha();
452
453 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
454}
455
/// Implements `PackScalar<F>` for the given underlier type.
///
/// For every binary field `F` for which `PackedPrimitiveType<$underlier, F>` is a `PackedField`
/// over `F` with underlier `$underlier`, that type is chosen as the packed representation.
///
/// `BinaryField` and `WithUnderlier` are resolved through `$crate` paths, so modules that invoke
/// the macro do not need to import them. `PackedPrimitiveType` is still referenced unqualified
/// and must be in scope at the call site.
macro_rules! impl_pack_scalar {
    ($underlier:ty) => {
        impl<F> $crate::as_packed_field::PackScalar<F> for $underlier
        where
            F: $crate::BinaryField,
            PackedPrimitiveType<$underlier, F>: $crate::packed::PackedField<Scalar = F>
                + $crate::underlier::WithUnderlier<Underlier = $underlier>,
        {
            type Packed = PackedPrimitiveType<$underlier, F>;
        }
    };
}
468
// Re-export the macro by path so other modules in this crate can invoke it.
pub(crate) use impl_pack_scalar;

// Give every primitive underlier a `PackScalar` impl that packs via `PackedPrimitiveType`.
// The list runs from the sub-byte `U1`/`U2`/`U4` underliers (presumably 1-, 2- and 4-bit
// wide — confirm) up to `u128`.
impl_pack_scalar!(U1);
impl_pack_scalar!(U2);
impl_pack_scalar!(U4);
impl_pack_scalar!(u8);
impl_pack_scalar!(u16);
impl_pack_scalar!(u32);
impl_pack_scalar!(u64);
impl_pack_scalar!(u128);