use std::{
	array,
	iter::{Product, Sum},
	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::{
	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes,
	bytes::{Buf, BufMut},
	checked_arithmetics::checked_log_2,
};
use bytemuck::{Pod, TransparentWrapper, Zeroable};
use rand::RngCore;
use subtle::ConstantTimeEq;

use crate::{
	Field, PackedField,
	arithmetic_traits::MulAlpha,
	as_packed_field::PackScalar,
	linear_transformation::{
		FieldLinearTransformation, PackedTransformationFactory, Transformation,
	},
	packed::PackedBinaryField,
	underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
};

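/// Packed field that stores a smaller packed field `PT` `N` times and implements all
/// operations by applying them element-wise to the inner array. The portable backends
/// use this to build wider packed types out of narrower ones.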
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);

impl<PT, const N: usize> ScaledPackedField<PT, N> {
	pub const WIDTH_IN_PT: usize = N;

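	/// Creates a value by initializing each of the `N` inner packed elements with `f`.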
	pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
		Self(std::array::from_fn(f))
	}

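	/// Spreads the scalars of the block with index `block_idx` of length
	/// `2^log_block_len` across the whole value.
	///
	/// # Safety
	/// The caller must ensure that `log_block_len <= Self::LOG_WIDTH` and
	/// `block_idx < 2^(Self::LOG_WIDTH - log_block_len)`.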
	#[inline]
	pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
	where
		PT: PackedField,
	{
		let log_n = checked_log_2(N);
		let values = if log_block_len >= PT::LOG_WIDTH {
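			// The source block covers 2^(log_block_len - PT::LOG_WIDTH) whole inner
			// elements starting at `offset`; each output element spreads a sub-block
			// taken from one of them.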
			let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
			let log_packed_block = log_block_len - PT::LOG_WIDTH;
			let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
			let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
			array::from_fn(|i| unsafe {
				self.0
					.get_unchecked(offset + (i >> (log_n - log_packed_block)))
					.spread_unchecked(
						log_smaller_block,
						(i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
					)
			})
		} else {
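			// The source block lies inside a single inner element; every output
			// element spreads a slice of that one element.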
			let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
			let log_inner_block_len = log_block_len.saturating_sub(log_n);
			let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
			let block_offset = block_offset << (log_block_len - log_inner_block_len);

			array::from_fn(|i| unsafe {
				self.0.get_unchecked(value_index).spread_unchecked(
					log_inner_block_len,
					block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
				)
			})
		};

		Self(values)
	}
}

impl<PT, const N: usize> SerializeBytes for ScaledPackedField<PT, N>
where
	PT: SerializeBytes,
{
	fn serialize(
		&self,
		mut write_buf: impl BufMut,
		mode: SerializationMode,
	) -> Result<(), SerializationError> {
		for elem in &self.0 {
			elem.serialize(&mut write_buf, mode)?;
		}
		Ok(())
	}
}

impl<PT, const N: usize> DeserializeBytes for ScaledPackedField<PT, N>
where
	PT: DeserializeBytes,
{
	fn deserialize(
		mut read_buf: impl Buf,
		mode: SerializationMode,
	) -> Result<Self, SerializationError> {
		let mut result = Vec::with_capacity(N);
		for _ in 0..N {
			result.push(PT::deserialize(&mut read_buf, mode)?);
		}

		match result.try_into() {
			Ok(arr) => Ok(Self(arr)),
			Err(_) => Err(SerializationError::InvalidConstruction {
				name: "ScaledPackedField",
			}),
		}
	}
}

impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
{
	fn default() -> Self {
		Self(Default::default())
	}
}

impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
where
	PT: From<U>,
{
	fn from(value: [U; N]) -> Self {
		Self(value.map(Into::into))
	}
}

impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
where
	U: From<PT>,
{
	fn from(value: ScaledPackedField<PT, N>) -> Self {
		value.0.map(Into::into)
	}
}

unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}

impl<PT: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledPackedField<PT, N> {
	fn ct_eq(&self, other: &Self) -> subtle::Choice {
		self.0.ct_eq(&other.0)
	}
}

impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
where
	Self: Default,
{
	type Output = Self;

	fn add(self, rhs: Self) -> Self {
		Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
	}
}

impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn add_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] += rhs.0[i];
		}
	}
}

impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
where
	Self: Default,
{
	type Output = Self;

	fn sub(self, rhs: Self) -> Self {
		Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
	}
}

impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn sub_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] -= rhs.0[i];
		}
	}
}

impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
where
	Self: Default,
{
	type Output = Self;

	fn mul(self, rhs: Self) -> Self {
		Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
	}
}

impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn mul_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] *= rhs.0[i];
		}
	}
}

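/// Convenience bound that bundles the by-value and in-place arithmetic operations
/// required by the `PackedField` implementation below.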
trait ArithmeticOps<Rhs>:
	Add<Rhs, Output = Self>
	+ AddAssign<Rhs>
	+ Sub<Rhs, Output = Self>
	+ SubAssign<Rhs>
	+ Mul<Rhs, Output = Self>
	+ MulAssign<Rhs>
{
}

impl<T, Rhs> ArithmeticOps<Rhs> for T where
	T: Add<Rhs, Output = Self>
		+ AddAssign<Rhs>
		+ Sub<Rhs, Output = Self>
		+ SubAssign<Rhs>
		+ Mul<Rhs, Output = Self>
		+ MulAssign<Rhs>
{
}

impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
		iter.fold(Self::default(), |l, r| l + r)
	}
}

impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
{
	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
		let one = Self([PT::one(); N]);
		iter.fold(one, |l, r| l * r)
	}
}

impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
	Self: ArithmeticOps<PT::Scalar>,
{
	type Scalar = PT::Scalar;

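	// The total width is `N * PT::WIDTH` scalars, so the log widths add.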
	const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		unsafe { self.0.get_unchecked(outer_i).get_unchecked(inner_i) }
	}

	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		unsafe {
			self.0
				.get_unchecked_mut(outer_i)
				.set_unchecked(inner_i, scalar);
		}
	}

	#[inline]
	fn zero() -> Self {
		Self(array::from_fn(|_| PT::zero()))
	}

	fn random(mut rng: impl RngCore) -> Self {
		Self(array::from_fn(|_| PT::random(&mut rng)))
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		Self(array::from_fn(|_| PT::broadcast(scalar)))
	}

	#[inline]
	fn square(self) -> Self {
		Self(self.0.map(|v| v.square()))
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		Self(self.0.map(|v| v.invert_or_zero()))
	}

	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

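		// Blocks at least one whole `PT` wide are interleaved by swapping `PT`-sized
		// chunks; smaller blocks are delegated to the inner packed type.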
		if log_block_len >= PT::LOG_WIDTH {
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N).step_by(block_in_pts * 2) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
				first[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i..i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
				second[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
			}
		} else {
			for i in 0..N {
				(first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

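		// Same split as `interleave`: whole-`PT` blocks are moved with chunked copies,
		// smaller blocks are handled by the inner packed type.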
		if log_block_len >= PT::LOG_WIDTH {
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}

			for i in (0..N / 2).step_by(block_in_pts) {
				first[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

				second[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}
		} else {
			for i in 0..N / 2 {
				(first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
			}

			for i in 0..N / 2 {
				(first[i + N / 2], second[i + N / 2]) =
					other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
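		// Delegates to the inherent `spread_unchecked` above; inherent methods take
		// precedence over trait methods, so this is not a recursive call.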
		unsafe { Self::spread_unchecked(self, log_block_len, block_idx) }
	}

	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
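		// Safety: `Self` is `repr(transparent)` over `[PT; N]`, so the pointer cast
		// preserves layout and the resulting slice covers the same memory.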
		let cast_slice =
			unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

		PT::iter_slice(cast_slice.as_flattened())
	}
}

impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
{
	#[inline]
	fn mul_alpha(self) -> Self {
		Self(self.0.map(|v| v.mul_alpha()))
	}
}

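/// Transformation that applies an inner transformation to each of the `N` elements of a
/// [`ScaledPackedField`].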
pub struct ScaledTransformation<I> {
	inner: I,
}

impl<I> ScaledTransformation<I> {
	const fn new(inner: I) -> Self {
		Self { inner }
	}
}

impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
	for ScaledTransformation<I>
where
	I: Transformation<IP, OP>,
{
	fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
		ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
	}
}

impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
	for ScaledPackedField<IP, N>
where
	Self: PackedBinaryField,
	ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
	OP: PackedBinaryField,
	IP: PackedTransformationFactory<OP>,
{
	type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
		ScaledTransformation<IP::PackedTransformation<Data>>;

	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
		transformation: FieldLinearTransformation<
			<ScaledPackedField<OP, N> as PackedField>::Scalar,
			Data,
		>,
	) -> Self::PackedTransformation<Data> {
		ScaledTransformation::new(IP::make_packed_transformation(transformation))
	}
}

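/// Declares a type alias for [`ScaledPackedField`] over `$inner` repeated `$size` times
/// and implements the arithmetic operations with a scalar right-hand side by
/// broadcasting the scalar to the inner packed type.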
macro_rules! packed_scaled_field {
	($name:ident = [$inner:ty;$size:literal]) => {
		pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

		impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}

				self
			}
		}

		impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}
			}
		}

		impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}

				self
			}
		}

		impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}
			}
		}

		impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}

				self
			}
		}

		impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}
			}
		}
	};
}

pub(crate) use packed_scaled_field;

unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier: Pod>,
{
	type Underlier = ScaledUnderlier<PT::Underlier, N>;

	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}

impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
	U: PackScalar<F> + UnderlierType + Pod,
	F: Field,
	ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
	type Packed = ScaledPackedField<U::Packed, N>;
}

unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
	for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier = U>,
{
}

#[cfg(test)]
mod tests {
	use binius_utils::{SerializationMode, SerializeBytes, bytes::BytesMut};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::ScaledPackedField;
	use crate::{PackedBinaryField2x4b, PackedBinaryField4x4b};

	#[test]
	fn test_equivalent_serialization_between_packed_representations() {
		let mode = SerializationMode::Native;

		let mut rng = StdRng::seed_from_u64(0);

		let byte_low: u8 = rng.r#gen();
		let byte_high: u8 = rng.r#gen();

		let combined_underlier = ((byte_high as u16) << 8) | (byte_low as u16);

		let packed = PackedBinaryField4x4b::from_underlier(combined_underlier);
		let packed_equivalent =
			ScaledPackedField::<PackedBinaryField2x4b, 2>::from([byte_low, byte_high]);

		let mut buffer_packed = BytesMut::new();
		let mut buffer_equivalent = BytesMut::new();

		packed.serialize(&mut buffer_packed, mode).unwrap();
		packed_equivalent
			.serialize(&mut buffer_equivalent, mode)
			.unwrap();

		assert_eq!(buffer_packed, buffer_equivalent);
	}
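
	// A minimal round-trip sketch (an addition, not part of the original suite):
	// serializing a `ScaledPackedField` and deserializing the bytes back should
	// reproduce the original value, since both directions walk the inner array in
	// the same order.
	#[test]
	fn test_serialization_roundtrip() {
		use binius_utils::DeserializeBytes;

		let mode = SerializationMode::Native;
		let mut rng = StdRng::seed_from_u64(0);

		let original = ScaledPackedField::<PackedBinaryField2x4b, 2>::from([
			rng.r#gen::<u8>(),
			rng.r#gen::<u8>(),
		]);

		let mut buffer = BytesMut::new();
		original.serialize(&mut buffer, mode).unwrap();

		let deserialized =
			ScaledPackedField::<PackedBinaryField2x4b, 2>::deserialize(buffer.freeze(), mode)
				.unwrap();
		assert_eq!(deserialized, original);
	}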
}