1#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10 fmt::Debug,
11 iter::{Product, Sum},
12 marker::PhantomData,
13 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::{
19 Rng,
20 distr::{Distribution, StandardUniform},
21};
22
23use crate::{
24 BinaryField, Divisible, PackedField,
25 arithmetic_traits::{InvertOrZero, MulAlpha, Square},
26 field::FieldOps,
27 underlier::{NumCast, UnderlierType, UnderlierWithBitOps, WithUnderlier},
28};
29
30#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
31#[repr(transparent)]
32#[transparent(U)]
33pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
34 pub U,
35 pub PhantomData<Scalar>,
36);
37
38impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
39 pub const WIDTH: usize = {
40 assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
41
42 U::BITS / Scalar::N_BITS
43 };
44
45 pub const LOG_WIDTH: usize = {
46 let result = Self::WIDTH.ilog2();
47
48 assert!(2usize.pow(result) == Self::WIDTH);
49
50 result as usize
51 };
52
53 #[inline]
54 pub const fn from_underlier(val: U) -> Self {
55 Self(val, PhantomData)
56 }
57
58 #[inline]
59 pub const fn to_underlier(self) -> U {
60 self.0
61 }
62}
63
64impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
65 PackedPrimitiveType<U, Scalar>
66{
67 #[inline]
68 pub fn broadcast(scalar: Scalar) -> Self {
69 U::broadcast_subvalue(scalar.to_underlier()).into()
70 }
71}
72
73unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
74 for PackedPrimitiveType<U, Scalar>
75{
76 type Underlier = U;
77
78 #[inline(always)]
79 fn to_underlier(self) -> Self::Underlier {
80 TransparentWrapper::peel(self)
81 }
82
83 #[inline(always)]
84 fn to_underlier_ref(&self) -> &Self::Underlier {
85 TransparentWrapper::peel_ref(self)
86 }
87
88 #[inline(always)]
89 fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
90 TransparentWrapper::peel_mut(self)
91 }
92
93 #[inline(always)]
94 fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
95 TransparentWrapper::peel_slice(val)
96 }
97
98 #[inline(always)]
99 fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
100 TransparentWrapper::peel_slice_mut(val)
101 }
102
103 #[inline(always)]
104 fn from_underlier(val: Self::Underlier) -> Self {
105 TransparentWrapper::wrap(val)
106 }
107
108 #[inline(always)]
109 fn from_underlier_ref(val: &Self::Underlier) -> &Self {
110 TransparentWrapper::wrap_ref(val)
111 }
112
113 #[inline(always)]
114 fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
115 TransparentWrapper::wrap_mut(val)
116 }
117
118 #[inline(always)]
119 fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
120 TransparentWrapper::wrap_slice(val)
121 }
122
123 #[inline(always)]
124 fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
125 TransparentWrapper::wrap_slice_mut(val)
126 }
127}
128
129impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
130where
131 Self: PackedField,
132{
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 let width = checked_int_div(U::BITS, Scalar::N_BITS);
135 let values_str = self
136 .iter()
137 .map(|value| format!("{value}"))
138 .collect::<Vec<_>>()
139 .join(",");
140
141 write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
142 }
143}
144
145impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
146 #[inline]
147 fn from(val: U) -> Self {
148 Self(val, PhantomData)
149 }
150}
151
152impl<U: UnderlierWithBitOps, Scalar: BinaryField> Neg for PackedPrimitiveType<U, Scalar> {
153 type Output = Self;
154
155 #[inline]
156 fn neg(self) -> Self::Output {
157 self
158 }
159}
160
161impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
162 type Output = Self;
163
164 #[inline]
165 #[allow(clippy::suspicious_arithmetic_impl)]
166 fn add(self, rhs: Self) -> Self::Output {
167 (self.0 ^ rhs.0).into()
168 }
169}
170
171impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add<&Self> for PackedPrimitiveType<U, Scalar> {
172 type Output = Self;
173
174 #[inline]
175 #[allow(clippy::suspicious_arithmetic_impl)]
176 fn add(self, rhs: &Self) -> Self::Output {
177 (self.0 ^ rhs.0).into()
178 }
179}
180
181impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
182 type Output = Self;
183
184 #[inline]
185 #[allow(clippy::suspicious_arithmetic_impl)]
186 fn sub(self, rhs: Self) -> Self::Output {
187 (self.0 ^ rhs.0).into()
188 }
189}
190
191impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub<&Self> for PackedPrimitiveType<U, Scalar> {
192 type Output = Self;
193
194 #[inline]
195 #[allow(clippy::suspicious_arithmetic_impl)]
196 fn sub(self, rhs: &Self) -> Self::Output {
197 (self.0 ^ rhs.0).into()
198 }
199}
200
201impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
202where
203 Self: Add<Output = Self>,
204{
205 fn add_assign(&mut self, rhs: Self) {
206 *self = *self + rhs;
207 }
208}
209
210impl<U: UnderlierType, Scalar: BinaryField> AddAssign<&Self> for PackedPrimitiveType<U, Scalar>
211where
212 Self: for<'a> Add<&'a Self, Output = Self>,
213{
214 fn add_assign(&mut self, rhs: &Self) {
215 *self = *self + rhs;
216 }
217}
218
219impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
220where
221 Self: Sub<Output = Self>,
222{
223 fn sub_assign(&mut self, rhs: Self) {
224 *self = *self - rhs;
225 }
226}
227
228impl<U: UnderlierType, Scalar: BinaryField> SubAssign<&Self> for PackedPrimitiveType<U, Scalar>
229where
230 Self: for<'a> Sub<&'a Self, Output = Self>,
231{
232 fn sub_assign(&mut self, rhs: &Self) {
233 *self = *self - rhs;
234 }
235}
236
237impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
238where
239 Self: Mul<Output = Self>,
240{
241 fn mul_assign(&mut self, rhs: Self) {
242 *self = *self * rhs;
243 }
244}
245
246impl<U: UnderlierType, Scalar: BinaryField> MulAssign<&Self> for PackedPrimitiveType<U, Scalar>
247where
248 Self: for<'a> Mul<&'a Self, Output = Self>,
249{
250 fn mul_assign(&mut self, rhs: &Self) {
251 *self = *self * rhs;
252 }
253}
254
255impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
256 for PackedPrimitiveType<U, Scalar>
257{
258 type Output = Self;
259
260 fn add(self, rhs: Scalar) -> Self::Output {
261 self + Self::broadcast(rhs)
262 }
263}
264
265impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
266 for PackedPrimitiveType<U, Scalar>
267{
268 type Output = Self;
269
270 fn sub(self, rhs: Scalar) -> Self::Output {
271 self - Self::broadcast(rhs)
272 }
273}
274
275impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
276 for PackedPrimitiveType<U, Scalar>
277where
278 Self: Mul<Output = Self>,
279{
280 type Output = Self;
281
282 fn mul(self, rhs: Scalar) -> Self::Output {
283 self * Self::broadcast(rhs)
284 }
285}
286
287impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
288 for PackedPrimitiveType<U, Scalar>
289{
290 fn add_assign(&mut self, rhs: Scalar) {
291 *self += Self::broadcast(rhs);
292 }
293}
294
295impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
296 for PackedPrimitiveType<U, Scalar>
297{
298 fn sub_assign(&mut self, rhs: Scalar) {
299 *self -= Self::broadcast(rhs);
300 }
301}
302
303impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
304 for PackedPrimitiveType<U, Scalar>
305where
306 Self: MulAssign<Self>,
307{
308 fn mul_assign(&mut self, rhs: Scalar) {
309 *self *= Self::broadcast(rhs);
310 }
311}
312
313impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
314where
315 Self: Add<Output = Self>,
316{
317 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
318 iter.fold(Self::from(U::default()), |result, next| result + next)
319 }
320}
321
322impl<'a, U: UnderlierType, Scalar: BinaryField> Sum<&'a Self> for PackedPrimitiveType<U, Scalar>
323where
324 Self: Add<Output = Self>,
325{
326 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
327 iter.fold(Self::from(U::default()), |result, next| result + *next)
328 }
329}
330
331impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
332 for PackedPrimitiveType<U, Scalar>
333where
334 Self: Mul<Output = Self>,
335{
336 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
337 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
338 }
339}
340
341impl<'a, U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
342 Product<&'a Self> for PackedPrimitiveType<U, Scalar>
343where
344 Self: Mul<Output = Self>,
345{
346 fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
347 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * *next)
348 }
349}
350
351impl<U: UnderlierType, Scalar: BinaryField> Mul<&Self> for PackedPrimitiveType<U, Scalar>
352where
353 Self: Mul<Output = Self>,
354{
355 type Output = Self;
356
357 #[inline]
358 fn mul(self, rhs: &Self) -> Self::Output {
359 self * *rhs
360 }
361}
362
363unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
364 for PackedPrimitiveType<U, Scalar>
365{
366}
367
368unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
369
370impl<U, Scalar> FieldOps<Scalar> for PackedPrimitiveType<U, Scalar>
371where
372 Self: Square + InvertOrZero + Mul<Output = Self>,
373 U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
374 Scalar: BinaryField,
375{
376 #[inline]
377 fn zero() -> Self {
378 Self::from_underlier(U::ZERO)
379 }
380
381 #[inline]
382 fn one() -> Self {
383 Self::broadcast(Scalar::ONE)
384 }
385}
386
387impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
388where
389 Self: Square + InvertOrZero + Mul<Output = Self>,
390 U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
391 Scalar: BinaryField,
392{
393 type Scalar = Scalar;
394
395 const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;
396
397 #[inline]
398 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
399 Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
400 }
401
402 #[inline]
403 unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
404 unsafe {
405 self.0.set_subvalue(i, scalar.to_underlier());
406 }
407 }
408
409 #[inline]
410 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
411 Divisible::<Scalar::Underlier>::ref_iter(&self.0)
412 .map(|underlier| Scalar::from_underlier(underlier))
413 }
414
415 #[inline]
416 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
417 Divisible::<Scalar::Underlier>::value_iter(self.0)
418 .map(|underlier| Scalar::from_underlier(underlier))
419 }
420
421 #[inline]
422 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
423 Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
424 .map_skippable(|underlier| Scalar::from_underlier(underlier))
425 }
426
427 #[inline]
428 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
429 assert!(log_block_len < Self::LOG_WIDTH);
430 let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
431 let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
432 (c.into(), d.into())
433 }
434
435 #[inline]
436 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
437 assert!(log_block_len < Self::LOG_WIDTH);
438 let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
439 let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
440 (c.into(), d.into())
441 }
442
443 #[inline]
444 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
445 debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
446 debug_assert!(
447 block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
448 "{} < {}",
449 block_idx,
450 1 << (Self::LOG_WIDTH - log_block_len)
451 );
452
453 unsafe {
454 self.0
455 .spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
456 .into()
457 }
458 }
459
460 #[inline]
461 fn broadcast(scalar: Self::Scalar) -> Self {
462 Self::broadcast(scalar)
463 }
464
465 #[inline]
466 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
467 U::from_fn(move |i| f(i).to_underlier()).into()
468 }
469}
470
471impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
472 for StandardUniform
473{
474 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
475 PackedPrimitiveType::from_underlier(U::random(rng))
476 }
477}
478
479#[allow(dead_code)]
482pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
483where
484 PT1: PackedField + WithUnderlier,
485 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
486 PT2::Underlier: From<PT1::Underlier>,
487 PT1::Underlier: NumCast<PT2::Underlier>,
488{
489 let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
490 let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
491
492 let bigger_result = bigger_lhs * bigger_rhs;
493
494 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
495}
496
497#[allow(dead_code)]
500pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
501where
502 PT1: PackedField + WithUnderlier,
503 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
504 PT2::Underlier: From<PT1::Underlier>,
505 PT1::Underlier: NumCast<PT2::Underlier>,
506{
507 let bigger_val = PT2::from_underlier(val.to_underlier().into());
508
509 let bigger_result = bigger_val.square();
510
511 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
512}
513
514#[allow(dead_code)]
517pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
518where
519 PT1: PackedField + WithUnderlier,
520 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
521 PT2::Underlier: From<PT1::Underlier>,
522 PT1::Underlier: NumCast<PT2::Underlier>,
523{
524 let bigger_val = PT2::from_underlier(val.to_underlier().into());
525
526 let bigger_result = bigger_val.invert_or_zero();
527
528 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
529}
530
531#[allow(dead_code)]
534pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
535where
536 PT1: PackedField + WithUnderlier,
537 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
538 PT2::Underlier: From<PT1::Underlier>,
539 PT1::Underlier: NumCast<PT2::Underlier>,
540{
541 let bigger_val = PT2::from_underlier(val.to_underlier().into());
542
543 let bigger_result = bigger_val.mul_alpha();
544
545 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
546}