1use std::{
4 array,
5 iter::{Product, Sum},
6 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 checked_arithmetics::checked_log_2,
13};
14use bytemuck::{Pod, TransparentWrapper, Zeroable};
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform},
18};
19
20use crate::{
21 Field, PackedField,
22 arithmetic_traits::MulAlpha,
23 as_packed_field::PackScalar,
24 linear_transformation::{
25 FieldLinearTransformation, PackedTransformationFactory, Transformation,
26 },
27 packed::PackedBinaryField,
28 underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
29};
30
/// Packed field that emulates a wider packed field by laying `N` values of the
/// narrower packed field `PT` side by side.
///
/// `#[repr(transparent)]` over `[PT; N]`, so the whole value can be
/// reinterpreted as its underlier array (see the `TransparentWrapper` impl at
/// the bottom of this file).
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
37
38impl<PT, const N: usize> ScaledPackedField<PT, N> {
39 pub const WIDTH_IN_PT: usize = N;
40
41 pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
44 Self(std::array::from_fn(f))
45 }
46
47 #[inline]
50 pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
51 where
52 PT: PackedField,
53 {
54 let log_n = checked_log_2(N);
55 let values = if log_block_len >= PT::LOG_WIDTH {
56 let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
57 let log_packed_block = log_block_len - PT::LOG_WIDTH;
58 let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
59 let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
60 array::from_fn(|i| unsafe {
61 self.0
62 .get_unchecked(offset + (i >> (log_n - log_packed_block)))
63 .spread_unchecked(
64 log_smaller_block,
65 (i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
66 )
67 })
68 } else {
69 let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
70 let log_inner_block_len = log_block_len.saturating_sub(log_n);
71 let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
72 let block_offset = block_offset << (log_block_len - log_inner_block_len);
73
74 array::from_fn(|i| unsafe {
75 self.0.get_unchecked(value_index).spread_unchecked(
76 log_inner_block_len,
77 block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
78 )
79 })
80 };
81
82 Self(values)
83 }
84}
85
86impl<PT, const N: usize> SerializeBytes for ScaledPackedField<PT, N>
87where
88 PT: SerializeBytes,
89{
90 fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
91 for elem in &self.0 {
92 elem.serialize(&mut write_buf)?;
93 }
94 Ok(())
95 }
96}
97
98impl<PT, const N: usize> DeserializeBytes for ScaledPackedField<PT, N>
99where
100 PT: DeserializeBytes,
101{
102 fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError> {
103 let mut result = Vec::with_capacity(N);
104 for _ in 0..N {
105 result.push(PT::deserialize(&mut read_buf)?);
106 }
107
108 match result.try_into() {
109 Ok(arr) => Ok(Self(arr)),
110 Err(_) => Err(SerializationError::InvalidConstruction {
111 name: "ScaledPackedField",
112 }),
113 }
114 }
115}
116
117impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
118where
119 [PT; N]: Default,
120{
121 fn default() -> Self {
122 Self(Default::default())
123 }
124}
125
126impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
127where
128 PT: From<U>,
129{
130 fn from(value: [U; N]) -> Self {
131 Self(value.map(Into::into))
132 }
133}
134
135impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
136where
137 U: From<PT>,
138{
139 fn from(value: ScaledPackedField<PT, N>) -> Self {
140 value.0.map(Into::into)
141 }
142}
143
// SAFETY: `ScaledPackedField` is `#[repr(transparent)]` over `[PT; N]`, which
// is zeroable whenever `PT` is.
unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

// SAFETY: likewise, `[PT; N]` is `Pod` whenever `PT` is; the transparent
// wrapper adds no other state or padding.
unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
147
148impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
149where
150 Self: Default,
151{
152 type Output = Self;
153
154 fn add(self, rhs: Self) -> Self {
155 Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
156 }
157}
158
159impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
160where
161 Self: Default,
162{
163 fn add_assign(&mut self, rhs: Self) {
164 for i in 0..N {
165 self.0[i] += rhs.0[i];
166 }
167 }
168}
169
170impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
171where
172 Self: Default,
173{
174 type Output = Self;
175
176 fn sub(self, rhs: Self) -> Self {
177 Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
178 }
179}
180
181impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
182where
183 Self: Default,
184{
185 fn sub_assign(&mut self, rhs: Self) {
186 for i in 0..N {
187 self.0[i] -= rhs.0[i];
188 }
189 }
190}
191
192impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
193where
194 Self: Default,
195{
196 type Output = Self;
197
198 fn mul(self, rhs: Self) -> Self {
199 Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
200 }
201}
202
203impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
204where
205 Self: Default,
206{
207 fn mul_assign(&mut self, rhs: Self) {
208 for i in 0..N {
209 self.0[i] *= rhs.0[i];
210 }
211 }
212}
213
/// Helper alias-trait bundling the six element-wise arithmetic ops against a
/// given right-hand side. Used below to state the `PackedField` impl's bound
/// on scalar arithmetic in one place.
trait ArithmeticOps<Rhs>:
    Add<Rhs, Output = Self>
    + AddAssign<Rhs>
    + Sub<Rhs, Output = Self>
    + SubAssign<Rhs>
    + Mul<Rhs, Output = Self>
    + MulAssign<Rhs>
{
}
225
// Blanket impl: any type providing all six ops automatically satisfies the
// alias, so no type ever implements `ArithmeticOps` directly.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
    T: Add<Rhs, Output = Self>
        + AddAssign<Rhs>
        + Sub<Rhs, Output = Self>
        + SubAssign<Rhs>
        + Mul<Rhs, Output = Self>
        + MulAssign<Rhs>
{
}
235
236impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
237where
238 Self: Default,
239{
240 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
241 iter.fold(Self::default(), |l, r| l + r)
242 }
243}
244
245impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
246where
247 [PT; N]: Default,
248{
249 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
250 let one = Self([PT::one(); N]);
251 iter.fold(one, |l, r| l * r)
252 }
253}
254
impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
    [PT; N]: Default,
    Self: ArithmeticOps<PT::Scalar>,
{
    type Scalar = PT::Scalar;

    // Total log-width is the inner log-width plus log2(N); `checked_log_2`
    // enforces (per its name — defined in binius_utils) that N is a power of
    // two.
    const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

    #[inline]
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
        // Split the flat scalar index into (inner PT element, scalar within it).
        let outer_i = i / PT::WIDTH;
        let inner_i = i % PT::WIDTH;
        // SAFETY: the caller guarantees `i < Self::WIDTH = N * PT::WIDTH`,
        // hence `outer_i < N` and `inner_i < PT::WIDTH`.
        unsafe { self.0.get_unchecked(outer_i).get_unchecked(inner_i) }
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
        // Same index split as `get_unchecked`.
        let outer_i = i / PT::WIDTH;
        let inner_i = i % PT::WIDTH;
        // SAFETY: the caller guarantees `i < Self::WIDTH`, so both indices
        // are in range.
        unsafe {
            self.0
                .get_unchecked_mut(outer_i)
                .set_unchecked(inner_i, scalar);
        }
    }

    #[inline]
    fn zero() -> Self {
        Self(array::from_fn(|_| PT::zero()))
    }

    #[inline]
    fn broadcast(scalar: Self::Scalar) -> Self {
        // Broadcast into every inner element.
        Self(array::from_fn(|_| PT::broadcast(scalar)))
    }

    #[inline]
    fn square(self) -> Self {
        // Element-wise delegation to the inner packed field.
        Self(self.0.map(|v| v.square()))
    }

    #[inline]
    fn invert_or_zero(self) -> Self {
        // Element-wise delegation to the inner packed field.
        Self(self.0.map(|v| v.invert_or_zero()))
    }

    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        let mut first = [Default::default(); N];
        let mut second = [Default::default(); N];

        if log_block_len >= PT::LOG_WIDTH {
            // Blocks span whole `PT` elements: interleave by copying runs of
            // `block_in_pts` inner elements alternately from self and other.
            let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
            for i in (0..N).step_by(block_in_pts * 2) {
                first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
                first[i + block_in_pts..i + 2 * block_in_pts]
                    .copy_from_slice(&other.0[i..i + block_in_pts]);

                second[i..i + block_in_pts]
                    .copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
                second[i + block_in_pts..i + 2 * block_in_pts]
                    .copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
            }
        } else {
            // Blocks are smaller than one `PT` element: interleave pairwise
            // within each corresponding inner element.
            for i in 0..N {
                (first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
            }
        }

        (Self(first), Self(second))
    }

    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
        let mut first = [Default::default(); N];
        let mut second = [Default::default(); N];

        if log_block_len >= PT::LOG_WIDTH {
            // Blocks span whole `PT` elements: even-indexed blocks of `self`
            // go to the first half of `first`, odd-indexed to `second`.
            let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
            for i in (0..N / 2).step_by(block_in_pts) {
                first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

                second[i..i + block_in_pts]
                    .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
            }

            // Blocks drawn from `other` fill the second halves of the outputs.
            for i in (0..N / 2).step_by(block_in_pts) {
                first[i + N / 2..i + N / 2 + block_in_pts]
                    .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

                second[i + N / 2..i + N / 2 + block_in_pts]
                    .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
            }
        } else {
            // Blocks smaller than one `PT` element: unzip adjacent inner
            // element pairs, self feeding the first halves of the outputs...
            for i in 0..N / 2 {
                (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
            }

            // ...and other feeding the second halves.
            for i in 0..N / 2 {
                (first[i + N / 2], second[i + N / 2]) =
                    other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
            }
        }

        (Self(first), Self(second))
    }

    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        // SAFETY: forwarded to the inherent `spread_unchecked`, which has the
        // same contract (`log_block_len <= Self::LOG_WIDTH` and
        // `block_idx < 2^(Self::LOG_WIDTH - log_block_len)`).
        unsafe { Self::spread_unchecked(self, log_block_len, block_idx) }
    }

    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        // Flat scalar index `i * PT::WIDTH + j` maps to scalar `j` of inner
        // element `i`, consistent with `get_unchecked`/`set_unchecked`.
        Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        // SAFETY: `Self` is `#[repr(transparent)]` over `[PT; N]`, so a
        // `&[Self]` may be reinterpreted as a `&[[PT; N]]` of equal length.
        let cast_slice =
            unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

        // Flatten to a `&[PT]` and delegate to the inner iterator.
        PT::iter_slice(cast_slice.as_flattened())
    }
}
379
380impl<PT: PackedField, const N: usize> Distribution<ScaledPackedField<PT, N>> for StandardUniform
381where
382 [PT; N]: Default,
383{
384 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledPackedField<PT, N> {
385 ScaledPackedField(array::from_fn(|_| PT::random(&mut rng)))
386 }
387}
388
389impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
390where
391 [PT; N]: Default,
392{
393 #[inline]
394 fn mul_alpha(self) -> Self {
395 Self(self.0.map(|v| v.mul_alpha()))
396 }
397}
398
/// Lifts a transformation on the inner packed field to the scaled packed
/// field by applying it to each inner element (see the `Transformation` impl
/// below).
pub struct ScaledTransformation<I> {
    inner: I,
}
403
impl<I> ScaledTransformation<I> {
    /// Wraps an inner per-element transformation.
    const fn new(inner: I) -> Self {
        Self { inner }
    }
}
409
410impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
411 for ScaledTransformation<I>
412where
413 I: Transformation<IP, OP>,
414{
415 fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
416 ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
417 }
418}
419
// A scaled packed field can build packed transformations whenever its inner
// packed field can: the factory simply wraps the inner factory's output in
// `ScaledTransformation`, which applies it element-wise.
impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
    for ScaledPackedField<IP, N>
where
    Self: PackedBinaryField,
    ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
    OP: PackedBinaryField,
    IP: PackedTransformationFactory<OP>,
{
    type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
        ScaledTransformation<IP::PackedTransformation<Data>>;

    fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
        transformation: FieldLinearTransformation<
            <ScaledPackedField<OP, N> as PackedField>::Scalar,
            Data,
        >,
    ) -> Self::PackedTransformation<Data> {
        // Delegate construction to the inner factory; the scalar type is
        // shared between the inner and scaled packed fields.
        ScaledTransformation::new(IP::make_packed_transformation(transformation))
    }
}
440
/// Declares a type alias `$name` for `ScaledPackedField<$inner, $size>` and
/// implements the by-scalar arithmetic operators (`+`, `-`, `*` and their
/// assignment forms) by broadcasting the scalar across `$inner` and applying
/// it to every inner element.
macro_rules! packed_scaled_field {
    ($name:ident = [$inner:ty;$size:literal]) => {
        pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

        impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                // Broadcast once, then apply to every inner element.
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v += broadcast;
                }

                self
            }
        }

        impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v += broadcast;
                }
            }
        }

        impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v -= broadcast;
                }

                self
            }
        }

        impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v -= broadcast;
                }
            }
        }

        impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v *= broadcast;
                }

                self
            }
        }

        impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v *= broadcast;
                }
            }
        }
    };
}

pub(crate) use packed_scaled_field;
524
// SAFETY: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`,
// and the `TransparentWrapper<ScaledUnderlier<PT::Underlier, N>>` impl at the
// bottom of this file provides the layout-compatibility casts used by every
// method below; each method is a pure reinterpretation with no other logic.
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
    PT: WithUnderlier<Underlier: Pod>,
{
    type Underlier = ScaledUnderlier<PT::Underlier, N>;

    fn to_underlier(self) -> Self::Underlier {
        TransparentWrapper::peel(self)
    }

    fn to_underlier_ref(&self) -> &Self::Underlier {
        TransparentWrapper::peel_ref(self)
    }

    fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
        TransparentWrapper::peel_mut(self)
    }

    fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
        TransparentWrapper::peel_slice(val)
    }

    fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
        TransparentWrapper::peel_slice_mut(val)
    }

    fn from_underlier(val: Self::Underlier) -> Self {
        TransparentWrapper::wrap(val)
    }

    fn from_underlier_ref(val: &Self::Underlier) -> &Self {
        TransparentWrapper::wrap_ref(val)
    }

    fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
        TransparentWrapper::wrap_mut(val)
    }

    fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
        TransparentWrapper::wrap_slice(val)
    }

    fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
        TransparentWrapper::wrap_slice_mut(val)
    }
}
571
// The packed-field view of a scaled underlier is the scaled packed field over
// the element underlier's own packed type, provided the layouts line up (the
// `WithUnderlier<Underlier = Self>` bound enforces this).
impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
    U: PackScalar<F> + UnderlierType + Pod,
    F: Field,
    ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
    type Packed = ScaledPackedField<U::Packed, N>;
}
580
// SAFETY: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`
// (see the struct definition above). NOTE(review): this additionally relies on
// `ScaledUnderlier<U, N>` having the same layout as `[U; N]` and on `PT` being
// layout-identical to its `Underlier` `U` — both are defined in other modules;
// confirm there.
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
    for ScaledPackedField<PT, N>
where
    PT: WithUnderlier<Underlier = U>,
{
}