1#![allow(clippy::multiple_bound_locations)]
8
9use std::{
10 fmt::Debug,
11 iter::{Product, Sum},
12 marker::PhantomData,
13 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
14};
15
16use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
17use bytemuck::{Pod, TransparentWrapper, Zeroable};
18use rand::RngCore;
19use subtle::{Choice, ConstantTimeEq};
20
21use super::packed_arithmetic::UnderlierWithBitConstants;
22use crate::{
23 arithmetic_traits::{Broadcast, InvertOrZero, MulAlpha, Square},
24 underlier::{
25 IterationMethods, IterationStrategy, NumCast, UnderlierType, UnderlierWithBitOps,
26 WithUnderlier, U1, U2, U4,
27 },
28 BinaryField, PackedField,
29};
30
/// A packed field element that stores `U::BITS / Scalar::N_BITS` scalars of type `Scalar`
/// inside a single primitive underlier value of type `U`.
///
/// The struct is `#[repr(transparent)]` over `U` (the second field is a zero-sized
/// `PhantomData`). That layout is what the `bytemuck::TransparentWrapper` derive relies on, and
/// the `WithUnderlier` implementation uses it for its by-value, reference and slice casts.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
	/// The raw underlier holding all packed scalars.
	pub U,
	/// Zero-sized marker recording which scalar field the lanes belong to.
	pub PhantomData<Scalar>,
);
38
39impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
40 pub const WIDTH: usize = {
41 assert!(U::BITS % Scalar::N_BITS == 0);
42
43 U::BITS / Scalar::N_BITS
44 };
45
46 pub const LOG_WIDTH: usize = {
47 let result = Self::WIDTH.ilog2();
48
49 assert!(2usize.pow(result) == Self::WIDTH);
50
51 result as usize
52 };
53
54 #[inline]
55 pub const fn from_underlier(val: U) -> Self {
56 Self(val, PhantomData)
57 }
58
59 #[inline]
60 pub const fn to_underlier(self) -> U {
61 self.0
62 }
63}
64
65unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
66 for PackedPrimitiveType<U, Scalar>
67{
68 type Underlier = U;
69
70 #[inline(always)]
71 fn to_underlier(self) -> Self::Underlier {
72 TransparentWrapper::peel(self)
73 }
74
75 #[inline(always)]
76 fn to_underlier_ref(&self) -> &Self::Underlier {
77 TransparentWrapper::peel_ref(self)
78 }
79
80 #[inline(always)]
81 fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
82 TransparentWrapper::peel_mut(self)
83 }
84
85 #[inline(always)]
86 fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
87 TransparentWrapper::peel_slice(val)
88 }
89
90 #[inline(always)]
91 fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
92 TransparentWrapper::peel_slice_mut(val)
93 }
94
95 #[inline(always)]
96 fn from_underlier(val: Self::Underlier) -> Self {
97 TransparentWrapper::wrap(val)
98 }
99
100 #[inline(always)]
101 fn from_underlier_ref(val: &Self::Underlier) -> &Self {
102 TransparentWrapper::wrap_ref(val)
103 }
104
105 #[inline(always)]
106 fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
107 TransparentWrapper::wrap_mut(val)
108 }
109
110 #[inline(always)]
111 fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
112 TransparentWrapper::wrap_slice(val)
113 }
114
115 #[inline(always)]
116 fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
117 TransparentWrapper::wrap_slice_mut(val)
118 }
119}
120
121impl<U: UnderlierWithBitOps, Scalar: BinaryField> Debug for PackedPrimitiveType<U, Scalar>
122where
123 Self: PackedField,
124{
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 let width = checked_int_div(U::BITS, Scalar::N_BITS);
127 let values_str = self
128 .iter()
129 .map(|value| format!("{}", value))
130 .collect::<Vec<_>>()
131 .join(",");
132
133 write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
134 }
135}
136
137impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
138 #[inline]
139 fn from(val: U) -> Self {
140 Self(val, PhantomData)
141 }
142}
143
144impl<U: UnderlierType, Scalar: BinaryField> ConstantTimeEq for PackedPrimitiveType<U, Scalar> {
145 fn ct_eq(&self, other: &Self) -> Choice {
146 self.0.ct_eq(&other.0)
147 }
148}
149
150impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
151 type Output = Self;
152
153 #[inline]
154 #[allow(clippy::suspicious_arithmetic_impl)]
155 fn add(self, rhs: Self) -> Self::Output {
156 (self.0 ^ rhs.0).into()
157 }
158}
159
160impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
161 type Output = Self;
162
163 #[inline]
164 #[allow(clippy::suspicious_arithmetic_impl)]
165 fn sub(self, rhs: Self) -> Self::Output {
166 (self.0 ^ rhs.0).into()
167 }
168}
169
170impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
171where
172 Self: Add<Output = Self>,
173{
174 fn add_assign(&mut self, rhs: Self) {
175 *self = *self + rhs;
176 }
177}
178
179impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
180where
181 Self: Sub<Output = Self>,
182{
183 fn sub_assign(&mut self, rhs: Self) {
184 *self = *self - rhs;
185 }
186}
187
188impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
189where
190 Self: Mul<Output = Self>,
191{
192 fn mul_assign(&mut self, rhs: Self) {
193 *self = *self * rhs;
194 }
195}
196
197impl<U: UnderlierType, Scalar: BinaryField> Add<Scalar> for PackedPrimitiveType<U, Scalar>
198where
199 Self: Broadcast<Scalar> + Add<Output = Self>,
200{
201 type Output = Self;
202
203 fn add(self, rhs: Scalar) -> Self::Output {
204 self + Self::broadcast(rhs)
205 }
206}
207
208impl<U: UnderlierType, Scalar: BinaryField> Sub<Scalar> for PackedPrimitiveType<U, Scalar>
209where
210 Self: Broadcast<Scalar> + Sub<Output = Self>,
211{
212 type Output = Self;
213
214 fn sub(self, rhs: Scalar) -> Self::Output {
215 self - Self::broadcast(rhs)
216 }
217}
218
219impl<U: UnderlierType, Scalar: BinaryField> Mul<Scalar> for PackedPrimitiveType<U, Scalar>
220where
221 Self: Broadcast<Scalar> + Mul<Output = Self>,
222{
223 type Output = Self;
224
225 fn mul(self, rhs: Scalar) -> Self::Output {
226 self * Self::broadcast(rhs)
227 }
228}
229
230impl<U: UnderlierType, Scalar: BinaryField> AddAssign<Scalar> for PackedPrimitiveType<U, Scalar>
231where
232 Self: Broadcast<Scalar> + AddAssign<Self>,
233{
234 fn add_assign(&mut self, rhs: Scalar) {
235 *self += Self::broadcast(rhs);
236 }
237}
238
239impl<U: UnderlierType, Scalar: BinaryField> SubAssign<Scalar> for PackedPrimitiveType<U, Scalar>
240where
241 Self: Broadcast<Scalar> + SubAssign<Self>,
242{
243 fn sub_assign(&mut self, rhs: Scalar) {
244 *self -= Self::broadcast(rhs);
245 }
246}
247
248impl<U: UnderlierType, Scalar: BinaryField> MulAssign<Scalar> for PackedPrimitiveType<U, Scalar>
249where
250 Self: Broadcast<Scalar> + MulAssign<Self>,
251{
252 fn mul_assign(&mut self, rhs: Scalar) {
253 *self *= Self::broadcast(rhs);
254 }
255}
256
257impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
258where
259 Self: Add<Output = Self>,
260{
261 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
262 iter.fold(Self::from(U::default()), |result, next| result + next)
263 }
264}
265
266impl<U: UnderlierType, Scalar: BinaryField> Product for PackedPrimitiveType<U, Scalar>
267where
268 Self: Broadcast<Scalar> + Mul<Output = Self>,
269{
270 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
271 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
272 }
273}
274
275unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
276 for PackedPrimitiveType<U, Scalar>
277{
278}
279
280unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
281
282impl<U: UnderlierWithBitOps, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
283where
284 Self: Broadcast<Scalar> + Square + InvertOrZero + Mul<Output = Self>,
285 U: UnderlierWithBitConstants + From<Scalar::Underlier> + Send + Sync + 'static,
286 Scalar: BinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
287 Scalar::Underlier: NumCast<U>,
288 IterationMethods<Scalar::Underlier, U>: IterationStrategy<Scalar::Underlier, U>,
289{
290 type Scalar = Scalar;
291
292 const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;
293
294 #[inline]
295 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
296 Scalar::from_underlier(self.0.get_subvalue(i))
297 }
298
299 #[inline]
300 unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
301 self.0.set_subvalue(i, scalar.to_underlier());
302 }
303
304 #[inline]
305 fn zero() -> Self {
306 Self::from_underlier(U::ZERO)
307 }
308
309 fn random(rng: impl RngCore) -> Self {
310 U::random(rng).into()
311 }
312
313 #[inline]
314 fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
315 IterationMethods::<Scalar::Underlier, U>::ref_iter(&self.0)
316 .map(|underlier| Scalar::from_underlier(underlier))
317 }
318
319 #[inline]
320 fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
321 IterationMethods::<Scalar::Underlier, U>::value_iter(self.0)
322 .map(|underlier| Scalar::from_underlier(underlier))
323 }
324
325 #[inline]
326 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
327 IterationMethods::<Scalar::Underlier, U>::slice_iter(Self::to_underliers_ref(slice))
328 .map_skippable(|underlier| Scalar::from_underlier(underlier))
329 }
330
331 #[inline]
332 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
333 assert!(log_block_len < Self::LOG_WIDTH);
334 let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
335 let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
336 (c.into(), d.into())
337 }
338
339 #[inline]
340 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
341 assert!(log_block_len < Self::LOG_WIDTH);
342 let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
343 let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
344 (c.into(), d.into())
345 }
346
347 #[inline]
348 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
349 debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
350 debug_assert!(
351 block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
352 "{} < {}",
353 block_idx,
354 1 << (Self::LOG_WIDTH - log_block_len)
355 );
356
357 self.0
358 .spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
359 .into()
360 }
361
362 #[inline]
363 fn broadcast(scalar: Self::Scalar) -> Self {
364 <Self as Broadcast<Self::Scalar>>::broadcast(scalar)
365 }
366
367 #[inline]
368 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
369 U::from_fn(move |i| f(i).to_underlier()).into()
370 }
371
372 #[inline]
373 fn square(self) -> Self {
374 <Self as Square>::square(self)
375 }
376
377 #[inline]
378 fn invert_or_zero(self) -> Self {
379 <Self as InvertOrZero>::invert_or_zero(self)
380 }
381}
382
/// Implements `Broadcast<$scalar_type>` for `PackedPrimitiveType<$name, $scalar_type>`. The
/// result is a packed value in which every lane holds the same scalar.
///
/// The literal token `BinaryField1b` selects a dedicated arm that uses the underlier's
/// `fill_with_bit`. Any other scalar type goes through the generic doubling arm. Call sites must
/// have `PackedPrimitiveType` in scope, and also `BinaryField1b` when the first arm is used.
macro_rules! impl_broadcast {
	($name:ty, BinaryField1b) => {
		impl $crate::arithmetic_traits::Broadcast<BinaryField1b>
			for PackedPrimitiveType<$name, BinaryField1b>
		{
			#[inline]
			fn broadcast(scalar: BinaryField1b) -> Self {
				use $crate::underlier::{UnderlierWithBitOps, WithUnderlier};

				// Presumably replicates the scalar's single bit into every bit position of the
				// underlier — confirm against `fill_with_bit`.
				<Self as WithUnderlier>::Underlier::fill_with_bit(scalar.0.into()).into()
			}
		}
	};
	($name:ty, $scalar_type:ty) => {
		impl $crate::arithmetic_traits::Broadcast<$scalar_type>
			for PackedPrimitiveType<$name, $scalar_type>
		{
			#[inline]
			fn broadcast(scalar: $scalar_type) -> Self {
				// Start from the scalar's underlier converted into `$name`. This presumably
				// zero-extends it into the low bits; confirm for each underlier type.
				let mut value = <$name>::from(scalar.0);
				// Iteration `i` ORs `value` with itself shifted left by `2^i` bits, doubling the
				// replicated low-order pattern until it spans `<$name>::BITS` bits. When the
				// scalar is as wide as the underlier, the range is empty, which is presumably why
				// the clippy lint is allowed here.
				#[allow(clippy::reversed_empty_ranges)]
				for i in <$scalar_type as $crate::binary_field::BinaryField>::N_BITS.ilog2()
					..<$name>::BITS.ilog2()
				{
					value = value << (1 << i) | value;
				}

				value.into()
			}
		}
	};
}

pub(crate) use impl_broadcast;
419
/// Implements `Mul`, `MulAlpha`, `Square` and `InvertOrZero` for the packed type `$name`. The
/// name suggests tower height 0, i.e. 1-bit lanes; that is an assumption, not checked here.
///
/// Multiplication is a lane-wise bitwise AND of the underliers. Squaring, inversion-or-zero and
/// multiplication by alpha all return `self` unchanged. `$name` must provide a
/// `to_underlier()` method and implement `From` for its underlier.
macro_rules! impl_ops_for_zero_height {
	($name:ty) => {
		impl std::ops::Mul for $name {
			type Output = Self;

			#[allow(clippy::suspicious_arithmetic_impl)]
			#[inline]
			fn mul(self, b: Self) -> Self {
				// Use `$crate` rather than `crate` so the path stays hygienic and consistent
				// with the other `$crate::` paths in this macro.
				$crate::tracing::trace_multiplication!($name);

				(self.to_underlier() & b.to_underlier()).into()
			}
		}

		impl $crate::arithmetic_traits::MulAlpha for $name {
			#[inline]
			fn mul_alpha(self) -> Self {
				self
			}
		}

		impl $crate::arithmetic_traits::Square for $name {
			#[inline]
			fn square(self) -> Self {
				self
			}
		}

		impl $crate::arithmetic_traits::InvertOrZero for $name {
			#[inline]
			fn invert_or_zero(self) -> Self {
				self
			}
		}
	};
}

pub(crate) use impl_ops_for_zero_height;
458
459#[allow(dead_code)]
462pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
463where
464 PT1: PackedField + WithUnderlier,
465 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
466 PT2::Underlier: From<PT1::Underlier>,
467 PT1::Underlier: NumCast<PT2::Underlier>,
468{
469 let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
470 let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
471
472 let bigger_result = bigger_lhs * bigger_rhs;
473
474 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
475}
476
477#[allow(dead_code)]
480pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
481where
482 PT1: PackedField + WithUnderlier,
483 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
484 PT2::Underlier: From<PT1::Underlier>,
485 PT1::Underlier: NumCast<PT2::Underlier>,
486{
487 let bigger_val = PT2::from_underlier(val.to_underlier().into());
488
489 let bigger_result = bigger_val.square();
490
491 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
492}
493
494#[allow(dead_code)]
497pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
498where
499 PT1: PackedField + WithUnderlier,
500 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
501 PT2::Underlier: From<PT1::Underlier>,
502 PT1::Underlier: NumCast<PT2::Underlier>,
503{
504 let bigger_val = PT2::from_underlier(val.to_underlier().into());
505
506 let bigger_result = bigger_val.invert_or_zero();
507
508 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
509}
510
511#[allow(dead_code)]
514pub fn mul_alpha_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
515where
516 PT1: PackedField + WithUnderlier,
517 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier + MulAlpha,
518 PT2::Underlier: From<PT1::Underlier>,
519 PT1::Underlier: NumCast<PT2::Underlier>,
520{
521 let bigger_val = PT2::from_underlier(val.to_underlier().into());
522
523 let bigger_result = bigger_val.mul_alpha();
524
525 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
526}
527
/// Implements `PackScalar<F>` for the underlier `$underlier`. Any binary field `F` is packed
/// into `PackedPrimitiveType<$underlier, F>`, provided that type is a `PackedField` over `F`
/// with underlier `$underlier`.
///
/// Trait paths are `$crate`-qualified, so call sites no longer need `BinaryField` or
/// `WithUnderlier` imported. They still need `PackedPrimitiveType` in scope.
macro_rules! impl_pack_scalar {
	($underlier:ty) => {
		impl<F> $crate::as_packed_field::PackScalar<F> for $underlier
		where
			F: $crate::BinaryField,
			PackedPrimitiveType<$underlier, F>: $crate::packed::PackedField<Scalar = F>
				+ $crate::underlier::WithUnderlier<Underlier = $underlier>,
		{
			type Packed = PackedPrimitiveType<$underlier, F>;
		}
	};
}

pub(crate) use impl_pack_scalar;
542
// Provide `PackScalar` for each primitive underlier type: `U1`, `U2`, `U4`, and `u8` through
// `u128`.
impl_pack_scalar!(U1);
impl_pack_scalar!(U2);
impl_pack_scalar!(U4);
impl_pack_scalar!(u8);
impl_pack_scalar!(u16);
impl_pack_scalar!(u32);
impl_pack_scalar!(u64);
impl_pack_scalar!(u128);