1#![allow(clippy::multiple_bound_locations)]
9
10use std::{
11 fmt::Debug,
12 iter::{Product, Sum},
13 marker::PhantomData,
14 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
15};
16
17use binius_utils::{checked_arithmetics::checked_int_div, iter::IterExtensions};
18use bytemuck::{Pod, TransparentWrapper, Zeroable};
19use rand::{
20 Rng,
21 distr::{Distribution, StandardUniform},
22};
23
24use crate::{
25 BinaryField, Divisible, ExtensionField, Field, PackedField, WideMul,
26 arithmetic_traits::{InvertOrZero, Square},
27 field::FieldOps,
28 underlier::{NumCast, UnderlierType, UnderlierWithBitOps, WithUnderlier},
29};
30
/// A packed field element: `WIDTH = U::BITS / Scalar::N_BITS` binary-field scalars stored
/// together in a single underlier word `U`.
///
/// The scalar type exists only at the type level (via [`PhantomData`]). The runtime
/// representation is exactly `U` (`#[repr(transparent)]`). That is what makes the
/// `TransparentWrapper`-based conversions and the `Zeroable`/`Pod` impls in this file sound.
#[derive(PartialEq, Eq, Clone, Copy, Default, bytemuck::TransparentWrapper)]
#[repr(transparent)]
#[transparent(U)]
pub struct PackedPrimitiveType<U: UnderlierType, Scalar: BinaryField>(
    /// Raw underlier word holding every packed lane.
    pub U,
    /// Zero-sized marker recording which scalar field the lanes belong to.
    pub PhantomData<Scalar>,
);
38
impl<U: UnderlierType, Scalar: BinaryField> PackedPrimitiveType<U, Scalar> {
    /// Number of `Scalar` lanes packed into one `U`.
    ///
    /// Referencing this constant for a `U`/`Scalar` pair where `U::BITS` is not a multiple of
    /// `Scalar::N_BITS` fails const evaluation (a compile-time error).
    pub const WIDTH: usize = {
        assert!(U::BITS.is_multiple_of(Scalar::N_BITS));

        U::BITS / Scalar::N_BITS
    };

    /// Base-2 logarithm of [`Self::WIDTH`].
    ///
    /// Const evaluation fails unless `WIDTH` is an exact power of two.
    pub const LOG_WIDTH: usize = {
        let result = Self::WIDTH.ilog2();

        // `ilog2` rounds down, so verify the logarithm is exact.
        assert!(2usize.pow(result) == Self::WIDTH);

        result as usize
    };

    /// Wraps a raw underlier without changing any bits. Usable in `const` contexts.
    #[inline]
    pub const fn from_underlier(val: U) -> Self {
        Self(val, PhantomData)
    }

    /// Returns the raw underlier word. Usable in `const` contexts.
    #[inline]
    pub const fn to_underlier(self) -> U {
        self.0
    }
}
64
65impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
66 PackedPrimitiveType<U, Scalar>
67{
68 #[inline]
69 pub fn broadcast(scalar: Scalar) -> Self {
70 U::broadcast_subvalue(scalar.to_underlier()).into()
71 }
72}
73
74unsafe impl<U: UnderlierType, Scalar: BinaryField> WithUnderlier
75 for PackedPrimitiveType<U, Scalar>
76{
77 type Underlier = U;
78
79 #[inline(always)]
80 fn to_underlier(self) -> Self::Underlier {
81 TransparentWrapper::peel(self)
82 }
83
84 #[inline(always)]
85 fn to_underlier_ref(&self) -> &Self::Underlier {
86 TransparentWrapper::peel_ref(self)
87 }
88
89 #[inline(always)]
90 fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
91 TransparentWrapper::peel_mut(self)
92 }
93
94 #[inline(always)]
95 fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
96 TransparentWrapper::peel_slice(val)
97 }
98
99 #[inline(always)]
100 fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
101 TransparentWrapper::peel_slice_mut(val)
102 }
103
104 #[inline(always)]
105 fn from_underlier(val: Self::Underlier) -> Self {
106 TransparentWrapper::wrap(val)
107 }
108
109 #[inline(always)]
110 fn from_underlier_ref(val: &Self::Underlier) -> &Self {
111 TransparentWrapper::wrap_ref(val)
112 }
113
114 #[inline(always)]
115 fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
116 TransparentWrapper::wrap_mut(val)
117 }
118
119 #[inline(always)]
120 fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
121 TransparentWrapper::wrap_slice(val)
122 }
123
124 #[inline(always)]
125 fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
126 TransparentWrapper::wrap_slice_mut(val)
127 }
128}
129
130impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Debug
135 for PackedPrimitiveType<U, Scalar>
136{
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 let width = checked_int_div(U::BITS, Scalar::N_BITS);
139 let values_str = (0..width)
140 .map(|i| Scalar::from_underlier(unsafe { self.0.get_subvalue(i) }))
142 .map(|value| format!("{value}"))
143 .collect::<Vec<_>>()
144 .join(",");
145
146 write!(f, "Packed{}x{}([{}])", width, Scalar::N_BITS, values_str)
147 }
148}
149
150impl<U: UnderlierType, Scalar: BinaryField> From<U> for PackedPrimitiveType<U, Scalar> {
151 #[inline]
152 fn from(val: U) -> Self {
153 Self(val, PhantomData)
154 }
155}
156
157impl<U: UnderlierWithBitOps, Scalar: BinaryField> Neg for PackedPrimitiveType<U, Scalar> {
158 type Output = Self;
159
160 #[inline]
161 fn neg(self) -> Self::Output {
162 self
163 }
164}
165
166impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add for PackedPrimitiveType<U, Scalar> {
167 type Output = Self;
168
169 #[inline]
170 #[allow(clippy::suspicious_arithmetic_impl)]
171 fn add(self, rhs: Self) -> Self::Output {
172 (self.0 ^ rhs.0).into()
173 }
174}
175
176impl<U: UnderlierWithBitOps, Scalar: BinaryField> Add<&Self> for PackedPrimitiveType<U, Scalar> {
177 type Output = Self;
178
179 #[inline]
180 #[allow(clippy::suspicious_arithmetic_impl)]
181 fn add(self, rhs: &Self) -> Self::Output {
182 (self.0 ^ rhs.0).into()
183 }
184}
185
186impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub for PackedPrimitiveType<U, Scalar> {
187 type Output = Self;
188
189 #[inline]
190 #[allow(clippy::suspicious_arithmetic_impl)]
191 fn sub(self, rhs: Self) -> Self::Output {
192 (self.0 ^ rhs.0).into()
193 }
194}
195
196impl<U: UnderlierWithBitOps, Scalar: BinaryField> Sub<&Self> for PackedPrimitiveType<U, Scalar> {
197 type Output = Self;
198
199 #[inline]
200 #[allow(clippy::suspicious_arithmetic_impl)]
201 fn sub(self, rhs: &Self) -> Self::Output {
202 (self.0 ^ rhs.0).into()
203 }
204}
205
206impl<U: UnderlierType, Scalar: BinaryField> AddAssign for PackedPrimitiveType<U, Scalar>
207where
208 Self: Add<Output = Self>,
209{
210 fn add_assign(&mut self, rhs: Self) {
211 *self = *self + rhs;
212 }
213}
214
215impl<U: UnderlierType, Scalar: BinaryField> AddAssign<&Self> for PackedPrimitiveType<U, Scalar>
216where
217 Self: for<'a> Add<&'a Self, Output = Self>,
218{
219 fn add_assign(&mut self, rhs: &Self) {
220 *self = *self + rhs;
221 }
222}
223
224impl<U: UnderlierType, Scalar: BinaryField> SubAssign for PackedPrimitiveType<U, Scalar>
225where
226 Self: Sub<Output = Self>,
227{
228 fn sub_assign(&mut self, rhs: Self) {
229 *self = *self - rhs;
230 }
231}
232
233impl<U: UnderlierType, Scalar: BinaryField> SubAssign<&Self> for PackedPrimitiveType<U, Scalar>
234where
235 Self: for<'a> Sub<&'a Self, Output = Self>,
236{
237 fn sub_assign(&mut self, rhs: &Self) {
238 *self = *self - rhs;
239 }
240}
241
242impl<U: UnderlierType, Scalar: BinaryField> MulAssign for PackedPrimitiveType<U, Scalar>
243where
244 Self: Mul<Output = Self>,
245{
246 fn mul_assign(&mut self, rhs: Self) {
247 *self = *self * rhs;
248 }
249}
250
251impl<U: UnderlierType, Scalar: BinaryField> MulAssign<&Self> for PackedPrimitiveType<U, Scalar>
252where
253 Self: for<'a> Mul<&'a Self, Output = Self>,
254{
255 fn mul_assign(&mut self, rhs: &Self) {
256 *self = *self * rhs;
257 }
258}
259
260impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Add<Scalar>
261 for PackedPrimitiveType<U, Scalar>
262{
263 type Output = Self;
264
265 fn add(self, rhs: Scalar) -> Self::Output {
266 self + Self::broadcast(rhs)
267 }
268}
269
270impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Sub<Scalar>
271 for PackedPrimitiveType<U, Scalar>
272{
273 type Output = Self;
274
275 fn sub(self, rhs: Scalar) -> Self::Output {
276 self - Self::broadcast(rhs)
277 }
278}
279
280impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Mul<Scalar>
281 for PackedPrimitiveType<U, Scalar>
282where
283 Self: Mul<Output = Self>,
284{
285 type Output = Self;
286
287 fn mul(self, rhs: Scalar) -> Self::Output {
288 self * Self::broadcast(rhs)
289 }
290}
291
292impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> AddAssign<Scalar>
293 for PackedPrimitiveType<U, Scalar>
294{
295 fn add_assign(&mut self, rhs: Scalar) {
296 *self += Self::broadcast(rhs);
297 }
298}
299
300impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> SubAssign<Scalar>
301 for PackedPrimitiveType<U, Scalar>
302{
303 fn sub_assign(&mut self, rhs: Scalar) {
304 *self -= Self::broadcast(rhs);
305 }
306}
307
308impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> MulAssign<Scalar>
309 for PackedPrimitiveType<U, Scalar>
310where
311 Self: MulAssign<Self>,
312{
313 fn mul_assign(&mut self, rhs: Scalar) {
314 *self *= Self::broadcast(rhs);
315 }
316}
317
318impl<U: UnderlierType, Scalar: BinaryField> Sum for PackedPrimitiveType<U, Scalar>
319where
320 Self: Add<Output = Self>,
321{
322 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
323 iter.fold(Self::from(U::default()), |result, next| result + next)
324 }
325}
326
327impl<'a, U: UnderlierType, Scalar: BinaryField> Sum<&'a Self> for PackedPrimitiveType<U, Scalar>
328where
329 Self: Add<Output = Self>,
330{
331 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
332 iter.fold(Self::from(U::default()), |result, next| result + *next)
333 }
334}
335
336impl<U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField> Product
337 for PackedPrimitiveType<U, Scalar>
338where
339 Self: Mul<Output = Self>,
340{
341 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
342 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * next)
343 }
344}
345
346impl<'a, U: UnderlierWithBitOps + Divisible<Scalar::Underlier>, Scalar: BinaryField>
347 Product<&'a Self> for PackedPrimitiveType<U, Scalar>
348where
349 Self: Mul<Output = Self>,
350{
351 fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
352 iter.fold(Self::broadcast(Scalar::ONE), |result, next| result * *next)
353 }
354}
355
356impl<U: UnderlierType, Scalar: BinaryField> Mul<&Self> for PackedPrimitiveType<U, Scalar>
357where
358 Self: Mul<Output = Self>,
359{
360 type Output = Self;
361
362 #[inline]
363 fn mul(self, rhs: &Self) -> Self::Output {
364 self * *rhs
365 }
366}
367
// SAFETY: `PackedPrimitiveType` is `#[repr(transparent)]` over `U`, and its only other field is
// a zero-sized `PhantomData`. An all-zero bit pattern of the wrapper is therefore an all-zero
// `U`, which `U: Zeroable` guarantees is valid.
unsafe impl<U: UnderlierType + Zeroable, Scalar: BinaryField> Zeroable
    for PackedPrimitiveType<U, Scalar>
{
}

// SAFETY: by the same `#[repr(transparent)]` layout argument, the wrapper has exactly the size,
// alignment and bit validity of `U`, so it inherits `U: Pod`'s guarantees (no padding, every
// bit pattern valid). The `Copy + 'static` supertrait requirements are checked by the compiler.
unsafe impl<U: UnderlierType + Pod, Scalar: BinaryField> Pod for PackedPrimitiveType<U, Scalar> {}
374
/// Lane-wise [`FieldOps`] for packed primitive types.
impl<U, Scalar> FieldOps for PackedPrimitiveType<U, Scalar>
where
    Self: Square + InvertOrZero + Mul<Output = Self>,
    U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
    Scalar: BinaryField,
{
    type Scalar = Scalar;

    /// The all-zero underlier word.
    #[inline]
    fn zero() -> Self {
        Self::from_underlier(U::ZERO)
    }

    /// `Scalar::ONE` replicated into every lane.
    #[inline]
    fn one() -> Self {
        Self::broadcast(Scalar::ONE)
    }

    /// In-place transform of `elems`, which must hold exactly `DEGREE` packed values.
    ///
    /// Runs `LOG_DEGREE` stages. At stage `i`, every pair of elements whose indices differ only
    /// in bit `i` is combined with `U::interleave` at bit-level `i + log_sub_bits`. Here
    /// `log_sub_bits = log2(Scalar::N_BITS) - LOG_DEGREE`, which is log2 of the `FSub` element
    /// width assuming `Scalar::N_BITS == FSub width * DEGREE`.
    ///
    /// NOTE(review): presumably this transposes the `DEGREE x DEGREE` matrix of `FSub`
    /// coordinates within each lane, as the method name and the `FieldOps` contract (defined
    /// elsewhere) suggest. Confirm against the trait docs.
    ///
    /// # Panics
    /// If `elems.len() != DEGREE`.
    fn square_transpose<FSub: Field>(elems: &mut [Self])
    where
        Scalar: ExtensionField<FSub>,
    {
        let log_degree = <Scalar as ExtensionField<FSub>>::LOG_DEGREE;
        let degree = <Scalar as ExtensionField<FSub>>::DEGREE;
        assert_eq!(elems.len(), degree);

        // log2 of the bit width of one sub-field coordinate inside a `Scalar`.
        let log_sub_bits = Scalar::N_BITS.ilog2() as usize - log_degree;

        for i in 0..log_degree {
            // `j` selects a group of `2^(i+1)` consecutive indices; `k` walks its lower half.
            // `idx1` is `idx0` with bit `i` set, i.e. its partner in the upper half.
            for j in 0..1 << (log_degree - i - 1) {
                for k in 0..1 << i {
                    let idx0 = (j << (i + 1)) | k;
                    let idx1 = idx0 | (1 << i);
                    let (u0, u1) = elems[idx0].0.interleave(elems[idx1].0, i + log_sub_bits);
                    elems[idx0] = u0.into();
                    elems[idx1] = u1.into();
                }
            }
        }
    }
}
417
/// [`PackedField`] implementation. Each lane is a `Scalar::N_BITS`-bit sub-value of the
/// underlier `U`. All lane access and shuffling is delegated to `U`'s bit operations
/// (`get_subvalue`, `set_subvalue`, `interleave`, `transpose`, `spread`, `from_fn`).
impl<U, Scalar> PackedField for PackedPrimitiveType<U, Scalar>
where
    Self:
        Square + InvertOrZero + Mul<Output = Self> + WideMul<Output: Debug + Send + Sync + 'static>,
    U: UnderlierWithBitOps + Divisible<Scalar::Underlier>,
    Scalar: BinaryField,
{
    // Computed directly from the bit widths. NOTE(review): unlike the inherent
    // `PackedPrimitiveType::LOG_WIDTH`, this does not assert that the lane count is an exact
    // power of two (`ilog2` rounds down).
    const LOG_WIDTH: usize = (U::BITS / Scalar::N_BITS).ilog2() as usize;

    /// Reads lane `i` without bounds checking.
    #[inline]
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
        // SAFETY: forwarded from the caller, who must uphold `PackedField::get_unchecked`'s
        // contract (presumably `i < WIDTH`), which is what `get_subvalue` requires.
        Scalar::from_underlier(unsafe { self.0.get_subvalue(i) })
    }

    /// Overwrites lane `i` with `scalar` without bounds checking.
    #[inline]
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Scalar) {
        // SAFETY: forwarded from the caller's contract (presumably `i < WIDTH`).
        unsafe {
            self.0.set_subvalue(i, scalar.to_underlier());
        }
    }

    /// Iterates over the lanes of a borrowed value.
    #[inline]
    fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        Divisible::<Scalar::Underlier>::ref_iter(&self.0)
            .map(|underlier| Scalar::from_underlier(underlier))
    }

    /// Iterates over the lanes of an owned value.
    #[inline]
    fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
        Divisible::<Scalar::Underlier>::value_iter(self.0)
            .map(|underlier| Scalar::from_underlier(underlier))
    }

    /// Iterates over every lane of every element of `slice`, working directly on the
    /// underlier slice.
    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        Divisible::<Scalar::Underlier>::slice_iter(Self::to_underliers_ref(slice))
            .map_skippable(|underlier| Scalar::from_underlier(underlier))
    }

    /// Interleaves blocks of `2^log_block_len` lanes between `self` and `other`.
    ///
    /// # Panics
    /// If `log_block_len >= LOG_WIDTH`.
    #[inline]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        assert!(log_block_len < Self::LOG_WIDTH);
        // Convert the block length from lanes to bits: each lane is `N_BITS` wide.
        let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
        let (c, d) = self.0.interleave(other.0, log_block_len + log_bit_len);
        (c.into(), d.into())
    }

    /// Inverse-style shuffle of [`interleave`](Self::interleave), implemented with `U::transpose`
    /// at the same bit-level.
    ///
    /// # Panics
    /// If `log_block_len >= LOG_WIDTH`.
    #[inline]
    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
        assert!(log_block_len < Self::LOG_WIDTH);
        // Same lane-to-bit conversion as in `interleave`.
        let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize;
        let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len);
        (c.into(), d.into())
    }

    /// Delegates to `U::spread`, using the scalar's underlier type as the lane type.
    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        // The caller's preconditions, checked in debug builds only.
        debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH);
        debug_assert!(
            block_idx < 1 << (Self::LOG_WIDTH - log_block_len),
            "{} < {}",
            block_idx,
            1 << (Self::LOG_WIDTH - log_block_len)
        );

        // SAFETY: forwarded from the caller's contract, asserted above in debug builds.
        unsafe {
            self.0
                .spread::<<Self::Scalar as WithUnderlier>::Underlier>(log_block_len, block_idx)
                .into()
        }
    }

    /// Replicates `scalar` into every lane.
    #[inline]
    fn broadcast(scalar: Self::Scalar) -> Self {
        // Resolves to the inherent `PackedPrimitiveType::broadcast`, because inherent
        // associated functions take precedence over trait ones. This is a delegation, not
        // recursion.
        Self::broadcast(scalar)
    }

    /// Builds a packed value whose lane `i` is `f(i)`.
    #[inline]
    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        U::from_fn(move |i| f(i).to_underlier()).into()
    }
}
500
501impl<U: UnderlierType, Scalar: BinaryField> Distribution<PackedPrimitiveType<U, Scalar>>
502 for StandardUniform
503{
504 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PackedPrimitiveType<U, Scalar> {
505 PackedPrimitiveType::from_underlier(U::random(rng))
506 }
507}
508
509#[allow(dead_code)]
512pub fn mul_as_bigger_type<PT1, PT2>(lhs: PT1, rhs: PT1) -> PT1
513where
514 PT1: PackedField + WithUnderlier,
515 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
516 PT2::Underlier: From<PT1::Underlier>,
517 PT1::Underlier: NumCast<PT2::Underlier>,
518{
519 let bigger_lhs = PT2::from_underlier(lhs.to_underlier().into());
520 let bigger_rhs = PT2::from_underlier(rhs.to_underlier().into());
521
522 let bigger_result = bigger_lhs * bigger_rhs;
523
524 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
525}
526
527#[allow(dead_code)]
530pub fn square_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
531where
532 PT1: PackedField + WithUnderlier,
533 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
534 PT2::Underlier: From<PT1::Underlier>,
535 PT1::Underlier: NumCast<PT2::Underlier>,
536{
537 let bigger_val = PT2::from_underlier(val.to_underlier().into());
538
539 let bigger_result = bigger_val.square();
540
541 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
542}
543
544#[allow(dead_code)]
547pub fn invert_as_bigger_type<PT1, PT2>(val: PT1) -> PT1
548where
549 PT1: PackedField + WithUnderlier,
550 PT2: PackedField<Scalar = PT1::Scalar> + WithUnderlier,
551 PT2::Underlier: From<PT1::Underlier>,
552 PT1::Underlier: NumCast<PT2::Underlier>,
553{
554 let bigger_val = PT2::from_underlier(val.to_underlier().into());
555
556 let bigger_result = bigger_val.invert_or_zero();
557
558 PT1::from_underlier(PT1::Underlier::num_cast_from(bigger_result.to_underlier()))
559}