1use std::{
4 array,
5 iter::{Product, Sum},
6 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 checked_arithmetics::checked_log_2,
13};
14use bytemuck::{Pod, TransparentWrapper, Zeroable};
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform},
18};
19
20use crate::{
21 Field, PackedField,
22 arithmetic_traits::MulAlpha,
23 as_packed_field::PackScalar,
24 linear_transformation::Transformation,
25 underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
26};
27
/// Packed field built as a fixed-size array of `N` smaller packed field elements `PT`.
///
/// This lets a wider packed field be emulated by several narrower ones; the
/// [`PackedField`] impl below combines the widths (`LOG_WIDTH = PT::LOG_WIDTH + log2(N)`).
/// The type is `repr(transparent)` over `[PT; N]`, which the `Pod`/`TransparentWrapper`
/// impls and the pointer cast in `iter_slice` rely on.
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
34
impl<PT, const N: usize> ScaledPackedField<PT, N> {
    /// Number of inner `PT` elements making up this scaled packed field.
    pub const WIDTH_IN_PT: usize = N;

    /// Constructs a value by evaluating `f` once per inner packed element,
    /// in index order `0..N`.
    pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
        Self(std::array::from_fn(f))
    }

    /// Unchecked "spread": selects the block of `2^log_block_len` scalars at
    /// `block_idx` and spreads it across the full width by delegating to
    /// `PT::spread_unchecked` on the inner elements.
    ///
    /// # Safety
    /// The caller must guarantee that `log_block_len` and `block_idx` address a
    /// valid block of this value (`log_block_len <= PT::LOG_WIDTH + log2(N)` and
    /// `block_idx` within range for that block size): the indices computed below
    /// are used with `get_unchecked`, so out-of-range arguments are UB. `N` must
    /// be a power of two (presumably enforced by `checked_log_2` — confirm).
    #[inline]
    pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
    where
        PT: PackedField,
    {
        let log_n = checked_log_2(N);
        let values = if log_block_len >= PT::LOG_WIDTH {
            // The selected block covers one or more whole inner `PT` elements,
            // starting at element index `offset`.
            let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
            let log_packed_block = log_block_len - PT::LOG_WIDTH;
            let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
            let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
            array::from_fn(|i| unsafe {
                // SAFETY: under the caller's contract `offset + (i >> ...)` stays
                // below `N`.
                self.0
                    .get_unchecked(offset + (i >> (log_n - log_packed_block)))
                    .spread_unchecked(
                        log_smaller_block,
                        (i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
                    )
            })
        } else {
            // The selected block lies entirely inside a single inner `PT`
            // element (`value_index`).
            let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
            let log_inner_block_len = log_block_len.saturating_sub(log_n);
            let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
            let block_offset = block_offset << (log_block_len - log_inner_block_len);

            array::from_fn(|i| unsafe {
                // SAFETY: `value_index < N` because the caller guarantees
                // `block_idx` is in range for the given `log_block_len`.
                self.0.get_unchecked(value_index).spread_unchecked(
                    log_inner_block_len,
                    block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
                )
            })
        };

        Self(values)
    }
}
82
83impl<PT, const N: usize> SerializeBytes for ScaledPackedField<PT, N>
84where
85 PT: SerializeBytes,
86{
87 fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
88 for elem in &self.0 {
89 elem.serialize(&mut write_buf)?;
90 }
91 Ok(())
92 }
93}
94
95impl<PT, const N: usize> DeserializeBytes for ScaledPackedField<PT, N>
96where
97 PT: DeserializeBytes,
98{
99 fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError> {
100 let mut result = Vec::with_capacity(N);
101 for _ in 0..N {
102 result.push(PT::deserialize(&mut read_buf)?);
103 }
104
105 match result.try_into() {
106 Ok(arr) => Ok(Self(arr)),
107 Err(_) => Err(SerializationError::InvalidConstruction {
108 name: "ScaledPackedField",
109 }),
110 }
111 }
112}
113
114impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
115where
116 [PT; N]: Default,
117{
118 fn default() -> Self {
119 Self(Default::default())
120 }
121}
122
123impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
124where
125 PT: From<U>,
126{
127 fn from(value: [U; N]) -> Self {
128 Self(value.map(Into::into))
129 }
130}
131
132impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
133where
134 U: From<PT>,
135{
136 fn from(value: ScaledPackedField<PT, N>) -> Self {
137 value.0.map(Into::into)
138 }
139}
140
// SAFETY: the struct is `repr(transparent)` over `[PT; N]`, and an array of
// `Zeroable` elements is itself zeroable (the all-zeros bit pattern is valid).
unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

// SAFETY: `[PT; N]` is `Pod` when `PT: Pod`, and the `repr(transparent)` wrapper
// adds no padding or extra state, so every bit pattern remains valid.
unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
144
145impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
146where
147 Self: Default,
148{
149 type Output = Self;
150
151 fn add(self, rhs: Self) -> Self {
152 Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
153 }
154}
155
156impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
157where
158 Self: Default,
159{
160 fn add_assign(&mut self, rhs: Self) {
161 for i in 0..N {
162 self.0[i] += rhs.0[i];
163 }
164 }
165}
166
167impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
168where
169 Self: Default,
170{
171 type Output = Self;
172
173 fn sub(self, rhs: Self) -> Self {
174 Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
175 }
176}
177
178impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
179where
180 Self: Default,
181{
182 fn sub_assign(&mut self, rhs: Self) {
183 for i in 0..N {
184 self.0[i] -= rhs.0[i];
185 }
186 }
187}
188
189impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
190where
191 Self: Default,
192{
193 type Output = Self;
194
195 fn mul(self, rhs: Self) -> Self {
196 Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
197 }
198}
199
200impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
201where
202 Self: Default,
203{
204 fn mul_assign(&mut self, rhs: Self) {
205 for i in 0..N {
206 self.0[i] *= rhs.0[i];
207 }
208 }
209}
210
/// Convenience bound bundling all arithmetic operations (`+`, `-`, `*` and their
/// assigning forms) against a right-hand side of type `Rhs`. Used by the
/// `PackedField` impl in this file to require scalar arithmetic on `Self`.
trait ArithmeticOps<Rhs>:
    Add<Rhs, Output = Self>
    + AddAssign<Rhs>
    + Sub<Rhs, Output = Self>
    + SubAssign<Rhs>
    + Mul<Rhs, Output = Self>
    + MulAssign<Rhs>
{
}

// Blanket impl: any type providing the component ops gets `ArithmeticOps` for free.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
    T: Add<Rhs, Output = Self>
        + AddAssign<Rhs>
        + Sub<Rhs, Output = Self>
        + SubAssign<Rhs>
        + Mul<Rhs, Output = Self>
        + MulAssign<Rhs>
{
}
232
233impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
234where
235 Self: Default,
236{
237 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
238 iter.fold(Self::default(), |l, r| l + r)
239 }
240}
241
242impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
243where
244 [PT; N]: Default,
245{
246 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
247 let one = Self([PT::one(); N]);
248 iter.fold(one, |l, r| l * r)
249 }
250}
251
impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
    [PT; N]: Default,
    Self: ArithmeticOps<PT::Scalar>,
{
    type Scalar = PT::Scalar;

    // Total log-width: inner packed width times the (power-of-two) array length.
    const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

    #[inline]
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
        // Split the flat scalar index into (inner element, scalar within element).
        let outer_i = i / PT::WIDTH;
        let inner_i = i % PT::WIDTH;
        // SAFETY: the caller guarantees `i < Self::WIDTH = N * PT::WIDTH`, hence
        // `outer_i < N` and `inner_i < PT::WIDTH`.
        unsafe { self.0.get_unchecked(outer_i).get_unchecked(inner_i) }
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
        // Same index decomposition as `get_unchecked`.
        let outer_i = i / PT::WIDTH;
        let inner_i = i % PT::WIDTH;
        // SAFETY: in range by the caller's contract (`i < Self::WIDTH`).
        unsafe {
            self.0
                .get_unchecked_mut(outer_i)
                .set_unchecked(inner_i, scalar);
        }
    }

    #[inline]
    fn zero() -> Self {
        Self(array::from_fn(|_| PT::zero()))
    }

    #[inline]
    fn broadcast(scalar: Self::Scalar) -> Self {
        // Broadcasting into every inner element broadcasts across the full width.
        Self(array::from_fn(|_| PT::broadcast(scalar)))
    }

    #[inline]
    fn square(self) -> Self {
        Self(self.0.map(|v| v.square()))
    }

    #[inline]
    fn invert_or_zero(self) -> Self {
        Self(self.0.map(|v| v.invert_or_zero()))
    }

    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        let mut first = [Default::default(); N];
        let mut second = [Default::default(); N];

        if log_block_len >= PT::LOG_WIDTH {
            // Blocks span whole inner elements: interleave by copying runs of
            // `block_in_pts` elements alternately from `self` and `other`.
            let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
            for i in (0..N).step_by(block_in_pts * 2) {
                first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
                first[i + block_in_pts..i + 2 * block_in_pts]
                    .copy_from_slice(&other.0[i..i + block_in_pts]);

                second[i..i + block_in_pts]
                    .copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
                second[i + block_in_pts..i + 2 * block_in_pts]
                    .copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
            }
        } else {
            // Blocks are smaller than one inner element: delegate pairwise to the
            // inner packed type's interleave.
            for i in 0..N {
                (first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
            }
        }

        (Self(first), Self(second))
    }

    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
        let mut first = [Default::default(); N];
        let mut second = [Default::default(); N];

        if log_block_len >= PT::LOG_WIDTH {
            // Blocks span whole inner elements: even-indexed blocks of `self` fill
            // the low half of `first`, odd-indexed fill the low half of `second`;
            // the corresponding blocks of `other` fill the high halves.
            let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
            for i in (0..N / 2).step_by(block_in_pts) {
                first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

                second[i..i + block_in_pts]
                    .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
            }

            for i in (0..N / 2).step_by(block_in_pts) {
                first[i + N / 2..i + N / 2 + block_in_pts]
                    .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

                second[i + N / 2..i + N / 2 + block_in_pts]
                    .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
            }
        } else {
            // Sub-element blocks: unzip adjacent element pairs, results from `self`
            // land in the low halves and results from `other` in the high halves.
            for i in 0..N / 2 {
                (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
            }

            for i in 0..N / 2 {
                (first[i + N / 2], second[i + N / 2]) =
                    other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
            }
        }

        (Self(first), Self(second))
    }

    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        // Delegates to the inherent `spread_unchecked` above (inherent methods take
        // precedence over trait methods in `Self::` resolution).
        // SAFETY: forwarded under the same caller contract.
        unsafe { Self::spread_unchecked(self, log_block_len, block_idx) }
    }

    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        // Inner element `i` receives the scalars with flat indices
        // `i * PT::WIDTH .. (i + 1) * PT::WIDTH`.
        Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        // SAFETY: `Self` is `repr(transparent)` over `[PT; N]`, so a `&[Self]` of
        // length `len` has exactly the layout of a `&[[PT; N]]` of the same length.
        let cast_slice =
            unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

        PT::iter_slice(cast_slice.as_flattened())
    }
}
376
377impl<PT: PackedField, const N: usize> Distribution<ScaledPackedField<PT, N>> for StandardUniform
378where
379 [PT; N]: Default,
380{
381 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledPackedField<PT, N> {
382 ScaledPackedField(array::from_fn(|_| PT::random(&mut rng)))
383 }
384}
385
386impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
387where
388 [PT; N]: Default,
389{
390 #[inline]
391 fn mul_alpha(self) -> Self {
392 Self(self.0.map(|v| v.mul_alpha()))
393 }
394}
395
/// Transformation over [`ScaledPackedField`] values that applies an inner
/// transformation `I` independently to each of the `N` inner elements.
pub struct ScaledTransformation<I> {
    inner: I,
}
400
impl<I> ScaledTransformation<I> {
    /// Wraps `inner` so it can be applied element-wise to scaled packed fields.
    const fn new(inner: I) -> Self {
        Self { inner }
    }
}
406
407impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
408 for ScaledTransformation<I>
409where
410 I: Transformation<IP, OP>,
411{
412 fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
413 ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
414 }
415}
416
/// Declares `$name` as a type alias for `ScaledPackedField<$inner, $size>` and
/// implements the scalar arithmetic operators (`Add`/`Sub`/`Mul` and their
/// assigning variants) against the inner packed field's scalar type, by
/// broadcasting the scalar into a packed value and applying it element-wise.
macro_rules! packed_scaled_field {
    ($name:ident = [$inner:ty;$size:literal]) => {
        pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

        // Scalar + packed: broadcast the scalar, then add it to every element.
        impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v += broadcast;
                }

                self
            }
        }

        impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v += broadcast;
                }
            }
        }

        // Scalar subtraction, same broadcast-then-apply pattern.
        impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v -= broadcast;
                }

                self
            }
        }

        impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v -= broadcast;
                }
            }
        }

        // Scalar multiplication, same broadcast-then-apply pattern.
        impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v *= broadcast;
                }

                self
            }
        }

        impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v *= broadcast;
                }
            }
        }
    };
}
498
499pub(crate) use packed_scaled_field;
500
// SAFETY: `ScaledPackedField<PT, N>` is `repr(transparent)` over `[PT; N]`, and each
// `PT` wraps `PT::Underlier`, so the value is reinterpretable as
// `ScaledUnderlier<PT::Underlier, N>` (see the `TransparentWrapper` impl below).
// NOTE(review): this relies on `ScaledUnderlier<U, N>` having the same layout as
// `[U; N]` — confirm against its definition in the `underlier` module.
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
    PT: WithUnderlier<Underlier: Pod>,
{
    type Underlier = ScaledUnderlier<PT::Underlier, N>;
}
507
/// Lets `ScaledUnderlier<U, N>` serve as the underlier for packing scalars of field
/// `F`, with `ScaledPackedField<U::Packed, N>` as the packed representation.
impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
    U: PackScalar<F> + UnderlierType + Pod,
    F: Field,
    ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
    type Packed = ScaledPackedField<U::Packed, N>;
}
516
// SAFETY: `ScaledPackedField<PT, N>` is `repr(transparent)` over `[PT; N]`, and
// `PT: WithUnderlier<Underlier = U>` wraps `U`, so the whole value can be viewed as a
// `ScaledUnderlier<U, N>`. NOTE(review): soundness also requires `ScaledUnderlier<U, N>`
// to share the layout of `[U; N]` — confirm against its definition.
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
    for ScaledPackedField<PT, N>
where
    PT: WithUnderlier<Underlier = U>,
{
}