1#![allow(clippy::multiple_bound_locations)]
9
10use std::{
11 fmt::Debug,
12 iter::{Product, Sum},
13 marker::PhantomData,
14 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
15};
16
17use binius_utils::{
18 DeserializeBytes, FixedSizeSerializeBytes, SerializationError, SerializeBytes,
19 bytes::{Buf, BufMut},
20 checked_arithmetics::checked_int_div,
21 iter::IterExtensions,
22};
23use bytemuck::{Pod, TransparentWrapper, Zeroable};
24use rand::{
25 Rng,
26 distr::{Distribution, StandardUniform},
27};
28
29use crate::{
30 BinaryField, Divisible, ExtensionField, Field, Maskable, PackedField, WideMul,
31 arithmetic_traits::{InvertOrZero, Square},
32 field::FieldOps,
33 underlier::{NumCast, UnderlierType, WithUnderlier},
34};
35
36#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
37#[repr(transparent)]
38#[transparent(U)]
39pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
40 pub U,
41 pub PhantomData<Scalar>,
42);
43
44impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
45 pub const WIDTH: usize = {
46 assert!(U::BITS.is_multiple_of(Scalar::N_BITS));
47
48 U::BITS / Scalar::N_BITS
49 };
50
51 pub const LOG_WIDTH: usize = {
52 let result = Self::WIDTH.ilog2();
53
54 assert!(2usize.pow(result) == Self::WIDTH);
55
56 result as usize
57 };
58
59 #[inline]
60 pub const fn from_underlier(val: U) -> Self {
61 Self(val, PhantomData)
62 }
63
64 #[inline]
65 pub const fn to_underlier(self) -> U {
66 self.0
67 }
68}
69
70impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField>
71 PackedPrimitiveType<U, Scalar>
72{
73 #[inline]
74 pub fn broadcast(scalar: Scalar) -> Self {
75 U::broadcast_subvalue(scalar.to_underlier()).into()
76 }
77}
78
79unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
80 for PackedPrimitiveType<U, Scalar>
81{
82 type Underlier = U;
83
84 #[inline(always)]
85 fn to_underlier(self) -> Self::Underlier {
86 TransparentWrapper::peel(self)
87 }
88
89 #[inline(always)]
90 fn to_underlier_ref(&self) -> &Self::Underlier {
91 TransparentWrapper::peel_ref(self)
92 }
93
94 #[inline(always)]
95 fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
96 TransparentWrapper::peel_mut(self)
97 }
98
99 #[inline(always)]
100 fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
101 TransparentWrapper::peel_slice(val)
102 }
103
104 #[inline(always)]
105 fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
106 TransparentWrapper::peel_slice_mut(val)
107 }
108
109 #[inline(always)]
110 fn from_underlier(val: Self::Underlier) -> Self {
111 TransparentWrapper::wrap(val)
112 }
113
114 #[inline(always)]
115 fn from_underlier_ref(val: &Self::Underlier) -> &Self {
116 TransparentWrapper::wrap_ref(val)
117 }
118
119 #[inline(always)]
120 fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
121 TransparentWrapper::wrap_mut(val)
122 }
123
124 #[inline(always)]
125 fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
126 TransparentWrapper::wrap_slice(val)
127 }
128
129 #[inline(always)]
130 fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
131 TransparentWrapper::wrap_slice_mut(val)
132 }
133}
134
135impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> Debug
140 for PackedPrimitiveType<U, Scalar>
141{
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 let width = checked_int_div(U::BITS, Scalar::N_BITS);
144 let values_str = (0..width)
145 .map(|i| {
147 Scalar::from_underlier(unsafe {
148 Divisible::<Scalar::Underlier>::get_unchecked(&self.0, i)
149 })
150 })
151 .map(|value| format!("{value}"))
152 .collect::<Vec<_>>()
153 .join(",");
154
155 write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
156 }
157}
158
159impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
160 #[inline]
161 fn from(val: U) -> Self {
162 Self(val, PhantomData)
163 }
164}
165
166impl<U: UnderlierType + SerializeBytes, Scalar: BinaryField> SerializeBytes
170 for PackedPrimitiveType<U, Scalar>
171{
172 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
173 self.0.serialize(write_buf)
174 }
175}
176
177impl<U: UnderlierType + DeserializeBytes, Scalar: BinaryField> DeserializeBytes
178 for PackedPrimitiveType<U, Scalar>
179{
180 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError> {
181 Ok(Self(U::deserialize(read_buf)?, PhantomData))
182 }
183}
184
185impl<U: UnderlierType + FixedSizeSerializeBytes, Scalar: BinaryField> FixedSizeSerializeBytes
186 for PackedPrimitiveType<U, Scalar>
187{
188 const BYTE_SIZE: usize = U::BYTE_SIZE;
189}
190
191impl<U: UnderlierType, Scalar: BinaryField> Neg for PackedPrimitiveType<U, Scalar> {
192 type Output = Self;
193
194 #[inline]
195 fn neg(self) -> Self::Output {
196 self
197 }
198}
199
200impl<U: UnderlierType, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
201 type Output = Self;
202
203 #[inline]
204 #[allow(clippy::suspicious_arithmetic_impl)]
205 fn add(self, rhs: Self) -> Self::Output {
206 (self.0 ^ rhs.0).into()
207 }
208}
209
210impl<U: UnderlierType, Scalar: BinaryField> Add<&Self> for PackedPrimitiveType<U, Scalar> {
211 type Output = Self;
212
213 #[inline]
214 #[allow(clippy::suspicious_arithmetic_impl)]
215 fn add(self, rhs: &Self) -> Self::Output {
216 (self.0 ^ rhs.0).into()
217 }
218}
219
220impl<U: UnderlierType, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
221 type Output = Self;
222
223 #[inline]
224 #[allow(clippy::suspicious_arithmetic_impl)]
225 fn sub(self, rhs: Self) -> Self::Output {
226 (self.0 ^ rhs.0).into()
227 }
228}
229
230impl<U: UnderlierType, Scalar: BinaryField> Sub<&Self> for PackedPrimitiveType<U, Scalar> {
231 type Output = Self;
232
233 #[inline]
234 #[allow(clippy::suspicious_arithmetic_impl)]
235 fn sub(self, rhs: &Self) -> Self::Output {
236 (self.0 ^ rhs.0).into()
237 }
238}
239
240impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
241where
242 Self: Add<Output = Self>,
243{
244 fn add_assign(&mut self, rhs: Self) {
245 *self = *self + rhs;
246 }
247}
248
249impl<U: UnderlierType, Scalar: BinaryField> AddAssign<&Self> for PackedPrimitiveType<U, Scalar>
250where
251 Self: for<'a> Add<&'a Self, Output = Self>,
252{
253 fn add_assign(&mut self, rhs: &Self) {
254 *self = *self + rhs;
255 }
256}
257
258impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
259where
260 Self: Sub<Output = Self>,
261{
262 fn sub_assign(&mut self, rhs: Self) {
263 *self = *self - rhs;
264 }
265}
266
267impl<U: UnderlierType, Scalar: BinaryField> SubAssign<&Self> for PackedPrimitiveType<U, Scalar>
268where
269 Self: for<'a> Sub<&'a Self, Output = Self>,
270{
271 fn sub_assign(&mut self, rhs: &Self) {
272 *self = *self - rhs;
273 }
274}
275
276impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
277where
278 Self: Mul<Output = Self>,
279{
280 fn mul_assign(&mut self, rhs: Self) {
281 *self = *self * rhs;
282 }
283}
284
285impl<U: UnderlierType, Scalar: BinaryField> MulAssign<&Self> for PackedPrimitiveType<U, Scalar>
286where
287 Self: for<'a> Mul<&'a Self, Output = Self>,
288{
289 fn mul_assign(&mut self, rhs: &Self) {
290 *self = *self * rhs;
291 }
292}
293
294impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
295 for PackedPrimitiveType<U, Scalar>
296{
297 type Output = Self;
298
299 fn add(self, rhs: Scalar) -> Self::Output {
300 self + Self::broadcast(rhs)
301 }
302}
303
304impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
305 for PackedPrimitiveType<U, Scalar>
306{
307 type Output = Self;
308
309 fn sub(self, rhs: Scalar) -> Self::Output {
310 self - Self::broadcast(rhs)
311 }
312}
313
314impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
315 for PackedPrimitiveType<U, Scalar>
316where
317 Self: Mul<Output = Self>,
318{
319 type Output = Self;
320
321 fn mul(self, rhs: Scalar) -> Self::Output {
322 self * Self::broadcast(rhs)
323 }
324}
325
326impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
327 for PackedPrimitiveType<U, Scalar>
328{
329 fn add_assign(&mut self, rhs: Scalar) {
330 *self += Self::broadcast(rhs);
331 }
332}
333
334impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
335 for PackedPrimitiveType<U, Scalar>
336{
337 fn sub_assign(&mut self, rhs: Scalar) {
338 *self -= Self::broadcast(rhs);
339 }
340}
341
342impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
343 for PackedPrimitiveType<U, Scalar>
344where
345 Self: MulAssign<Self>,
346{
347 fn mul_assign(&mut self, rhs: Scalar) {
348 *self *= Self::broadcast(rhs);
349 }
350}
351
352impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
353where
354 Self: Add<Output = Self>,
355{
356 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
357 iter.fold(Self::from(U::default()), |result, next| result + next)
358 }
359}
360
361impl<'a, U: UnderlierType, Scalar: BinaryField> Sum<&'a Self> for PackedPrimitiveType<U, Scalar>
362where
363 Self: Add<Output = Self>,
364{
365 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
366 iter.fold(Self::from(U::default()), |result, next| result + *next)
367 }
368}
369
370impl<U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
371 for PackedPrimitiveType<U, Scalar>
372where
373 Self: Mul<Output = Self>,
374{
375 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
376 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
377 }
378}
379
380impl<'a, U: UnderlierType + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product<&'a Self>
381 for PackedPrimitiveType<U, Scalar>
382where
383 Self: Mul<Output = Self>,
384{
385 fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
386 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * *next)
387 }
388}
389
390impl<U: UnderlierType, Scalar: BinaryField> Mul<&Self> for PackedPrimitiveType<U, Scalar>
391where
392 Self: Mul<Output = Self>,
393{
394 type Output = Self;
395
396 #[inline]
397 fn mul(self, rhs: &Self) -> Self::Output {
398 self * *rhs
399 }
400}
401
402unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
403 for PackedPrimitiveType<U, Scalar>
404{
405}
406
407unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
408
409impl<U, Scalar> FieldOps for PackedPrimitiveType<U, Scalar>
410where
411 Self: Square + InvertOrZero + Mul<Output = Self>,
412 U: UnderlierType + Divisible<Scalar::Underlier>,
413 Scalar: BinaryField,
414{
415 type Scalar = Scalar;
416
417 #[inline]
418 fn zero() -> Self {
419 Self::from_underlier(U::ZERO)
420 }
421
422 #[inline]
423 fn one() -> Self {
424 Self::broadcast(Scalar::ONE)
425 }
426
427 fn square_transpose<FSub: Field>(elems: &mut [Self])
428 where
429 Scalar: ExtensionField<FSub>,
430 {
431 let log_degree = <Scalar as ExtensionField<FSub>>::LOG_DEGREE;
432 let degree = <Scalar as ExtensionField<FSub>>::DEGREE;
433 assert_eq!(elems.len(), degree);
434
435 let log_sub_bits = Scalar::N_BITS.ilog2() as usize - log_degree;
436
437 for i in 0..log_degree {
439 for j in 0..1 << (log_degree - i - 1) {
440 for k in 0..1 << i {
441 let idx0 = (j << (i + 1)) | k;
442 let idx1 = idx0 | (1 << i);
443 let (u0, u1) = elems[idx0].0.interleave(elems[idx1].0, i + log_sub_bits);
444 elems[idx0] = u0.into();
445 elems[idx1] = u1.into();
446 }
447 }
448 }
449 }
450}
451
452impl<U, Scalar> Divisible<Scalar> for PackedPrimitiveType<U, Scalar>
455where
456 U: UnderlierType + Divisible<Scalar::Underlier>,
457 Scalar: BinaryField,
458{
459 const LOG_N: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;
460
461 #[inline]
462 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = Scalar> + Send + Clone {
463 Divisible::<Scalar::Underlier>::value_iter(value.0).map(Scalar::from_underlier)
464 }
465
466 #[inline]
467 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = Scalar> + Send + Clone + '_ {
468 Divisible::<Scalar::Underlier>::ref_iter(&value.0).map(Scalar::from_underlier)
469 }
470
471 #[inline]
472 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = Scalar> + Send + Clone + '_ {
473 Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
474 .map(Scalar::from_underlier)
475 }
476
477 #[inline]
478 unsafe fn get_unchecked(&self, index: usize) -> Scalar {
479 Scalar::from_underlier(unsafe {
481 Divisible::<Scalar::Underlier>::get_unchecked(&self.0, index)
482 })
483 }
484
485 #[inline]
486 unsafe fn set_unchecked(&mut self, index: usize, val: Scalar) {
487 unsafe {
489 <U as Divisible<Scalar::Underlier>>::set_unchecked(
490 &mut self.0,
491 index,
492 val.to_underlier(),
493 )
494 };
495 }
496
497 #[inline]
498 fn broadcast(val: Scalar) -> Self {
499 <U as Divisible<Scalar::Underlier>>::broadcast(val.to_underlier()).into()
500 }
501
502 #[inline]
503 fn from_iter(iter: impl Iterator<Item = Scalar>) -> Self {
504 <U as Divisible<Scalar::Underlier>>::from_iter(iter.map(Scalar::to_underlier)).into()
505 }
506}
507
508impl<U, Scalar> Maskable<Scalar> for PackedPrimitiveType<U, Scalar>
511where
512 U: UnderlierType + Divisible<Scalar::Underlier>,
513 Scalar: BinaryField,
514{
515 type Mask = U;
516
517 #[inline]
518 fn make_mask(selectors: impl Iterator<Item = bool>) -> U {
519 <U as Divisible<Scalar::Underlier>>::from_iter(
522 selectors
523 .take(<Self as Divisible<Scalar>>::N)
524 .map(|selected| {
525 <Scalar::Underlier as UnderlierType>::fill_with_bit(u8::from(selected))
526 }),
527 )
528 }
529
530 #[inline]
531 fn select(&self, mask: &U) -> Self {
532 Self::from_underlier(self.to_underlier() & *mask)
533 }
534}
535
536impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
537where
538 Self:
539 Square + InvertOrZero + Mul<Output = Self> + WideMul<Output: Debug + Send + Sync + 'static>,
540 U: UnderlierType + Divisible<Scalar::Underlier>,
541 Scalar: BinaryField,
542{
543 #[inline]
549 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
550 Divisible::<Scalar::Underlier>::ref_iter(&self.0)
551 .map(|underlier| Scalar::from_underlier(underlier))
552 }
553
554 #[inline]
555 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
556 Divisible::<Scalar::Underlier>::value_iter(self.0)
557 .map(|underlier| Scalar::from_underlier(underlier))
558 }
559
560 #[inline]
561 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
562 Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
563 .map_skippable(|underlier| Scalar::from_underlier(underlier))
564 }
565
566 #[inline]
567 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
568 assert!(log_block_len < Self::LOG_WIDTH);
569 let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
570 let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
571 (c.into(), d.into())
572 }
573
574 #[inline]
575 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
576 assert!(log_block_len < Self::LOG_WIDTH);
577 let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
578 let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
579 (c.into(), d.into())
580 }
581
582 #[inline]
583 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
584 debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
585 debug_assert!(
586 block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
587 "{} < {}",
588 block_idx,
589 1 << (Self::LOG_WIDTH - log_block_len)
590 );
591
592 unsafe {
593 self.0
594 .spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
595 .into()
596 }
597 }
598
599 #[inline]
600 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
601 U::from_fn(move |i| f(i).to_underlier()).into()
602 }
603}
604
605impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
606 for StandardUniform
607{
608 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
609 PackedPrimitiveType::from_underlier(U::random(rng))
610 }
611}
612
613#[allow(dead_code)]
616pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
617where
618 PT1: PackedField + WithUnderlier,
619 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
620 PT2::Underlier: From<PT1::Underlier>,
621 PT1::Underlier: NumCast<PT2::Underlier>,
622{
623 let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
624 let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
625
626 let bigger_result = bigger_lhs * bigger_rhs;
627
628 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
629}
630
631#[allow(dead_code)]
634pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
635where
636 PT1: PackedField + WithUnderlier,
637 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
638 PT2::Underlier: From<PT1::Underlier>,
639 PT1::Underlier: NumCast<PT2::Underlier>,
640{
641 let bigger_val = PT2::from_underlier(val.to_underlier().into());
642
643 let bigger_result = bigger_val.square();
644
645 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
646}
647
648#[allow(dead_code)]
651pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
652where
653 PT1: PackedField + WithUnderlier,
654 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
655 PT2::Underlier: From<PT1::Underlier>,
656 PT1::Underlier: NumCast<PT2::Underlier>,
657{
658 let bigger_val = PT2::from_underlier(val.to_underlier().into());
659
660 let bigger_result = bigger_val.invert_or_zero();
661
662 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
663}