use std::{
    array,
    iter::{Product, Sum},
    ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{Pod, TransparentWrapper, Zeroable};
use rand::RngCore;
use subtle::ConstantTimeEq;

use crate::{
    arithmetic_traits::MulAlpha,
    as_packed_field::PackScalar,
    linear_transformation::{
        FieldLinearTransformation, PackedTransformationFactory, Transformation,
    },
    packed::PackedBinaryField,
    underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
    Field, PackedField,
};

/// Packed field that stores a smaller packed field `N` times and performs all operations
/// element-wise on the inner values. This is useful for building portable implementations
/// of wider packed field types out of narrower ones.
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);

impl<PT, const N: usize> ScaledPackedField<PT, N> {
    pub const WIDTH_IN_PT: usize = N;

    /// Creates a value by initializing each of the `N` inner packed elements from `f`,
    /// where `f` receives the index of the inner element.
    pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
        Self(std::array::from_fn(f))
    }

    /// Spreads the selected block of scalars across the whole value; this is the inner
    /// implementation used by [`PackedField::spread_unchecked`] below.
    ///
    /// # Safety
    /// `log_block_len` must not exceed `Self::LOG_WIDTH` (i.e. `PT::LOG_WIDTH + log2(N)`),
    /// and `block_idx` must be less than `1 << (Self::LOG_WIDTH - log_block_len)`.
    #[inline]
    pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
    where
        PT: PackedField,
    {
        let log_n = checked_log_2(N);
        let values = if log_block_len >= PT::LOG_WIDTH {
            // The spread block covers one or more whole inner packed elements.
            let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
            let log_packed_block = log_block_len - PT::LOG_WIDTH;
            let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
            let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
            array::from_fn(|i| {
                self.0
                    .get_unchecked(offset + (i >> (log_n - log_packed_block)))
                    .spread_unchecked(
                        log_smaller_block,
                        (i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
                    )
            })
        } else {
            // The spread block lies entirely within a single inner packed element.
            let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
            let log_inner_block_len = log_block_len.saturating_sub(log_n);
            let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
            let block_offset = block_offset << (log_block_len - log_inner_block_len);

            array::from_fn(|i| {
                self.0.get_unchecked(value_index).spread_unchecked(
                    log_inner_block_len,
                    block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
                )
            })
        };

        Self(values)
    }
}

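// A minimal sketch (not part of the original module) showing how `from_direct_packed_fn`
// is used: the closure receives the index of each of the `N` inner packed elements. The
// helper name `repeat_inner` is hypothetical and only illustrative.
#[allow(dead_code)]
fn repeat_inner<PT: Copy, const N: usize>(inner: PT) -> ScaledPackedField<PT, N> {
    // Fill every inner slot with the same packed element; a closure could equally well
    // select a different `PT` per index.
    ScaledPackedField::from_direct_packed_fn(|_| inner)
}
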
impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
where
    [PT; N]: Default,
{
    fn default() -> Self {
        Self(Default::default())
    }
}

impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
where
    PT: From<U>,
{
    fn from(value: [U; N]) -> Self {
        Self(value.map(Into::into))
    }
}

impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
where
    U: From<PT>,
{
    fn from(value: ScaledPackedField<PT, N>) -> Self {
        value.0.map(Into::into)
    }
}

unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}

impl<PT: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledPackedField<PT, N> {
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        self.0.ct_eq(&other.0)
    }
}

impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
where
    Self: Default,
{
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
    }
}

impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
where
    Self: Default,
{
    fn add_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] += rhs.0[i];
        }
    }
}

impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
where
    Self: Default,
{
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
    }
}

impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
where
    Self: Default,
{
    fn sub_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] -= rhs.0[i];
        }
    }
}

impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
where
    Self: Default,
{
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
    }
}

impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
where
    Self: Default,
{
    fn mul_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] *= rhs.0[i];
        }
    }
}

/// Convenience bound grouping the arithmetic operations with `Rhs` operands that the
/// `PackedField` implementation below requires against its scalar type.
trait ArithmeticOps<Rhs>:
    Add<Rhs, Output = Self>
    + AddAssign<Rhs>
    + Sub<Rhs, Output = Self>
    + SubAssign<Rhs>
    + Mul<Rhs, Output = Self>
    + MulAssign<Rhs>
{
}

impl<T, Rhs> ArithmeticOps<Rhs> for T where
    T: Add<Rhs, Output = Self>
        + AddAssign<Rhs>
        + Sub<Rhs, Output = Self>
        + SubAssign<Rhs>
        + Mul<Rhs, Output = Self>
        + MulAssign<Rhs>
{
}

impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
where
    Self: Default,
{
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::default(), |l, r| l + r)
    }
}

impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
where
    [PT; N]: Default,
{
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
        let one = Self([PT::one(); N]);
        iter.fold(one, |l, r| l * r)
    }
}

impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
    [PT; N]: Default,
    Self: ArithmeticOps<PT::Scalar>,
{
    type Scalar = PT::Scalar;

    const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

    #[inline]
    unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
        let outer_i = i / PT::WIDTH;
        let inner_i = i % PT::WIDTH;
        self.0.get_unchecked(outer_i).get_unchecked(inner_i)
    }

    #[inline]
    unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
        let outer_i = i / PT::WIDTH;
        let inner_i = i % PT::WIDTH;
        self.0
            .get_unchecked_mut(outer_i)
            .set_unchecked(inner_i, scalar);
    }

    #[inline]
    fn zero() -> Self {
        Self(array::from_fn(|_| PT::zero()))
    }

    fn random(mut rng: impl RngCore) -> Self {
        Self(array::from_fn(|_| PT::random(&mut rng)))
    }

    #[inline]
    fn broadcast(scalar: Self::Scalar) -> Self {
        Self(array::from_fn(|_| PT::broadcast(scalar)))
    }

    #[inline]
    fn square(self) -> Self {
        Self(self.0.map(|v| v.square()))
    }

    #[inline]
    fn invert_or_zero(self) -> Self {
        Self(self.0.map(|v| v.invert_or_zero()))
    }

    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        let mut first = [Default::default(); N];
        let mut second = [Default::default(); N];

        if log_block_len >= PT::LOG_WIDTH {
            // Blocks consist of whole inner packed elements, so swap slices of `PT`s.
            let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
            for i in (0..N).step_by(block_in_pts * 2) {
                first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
                first[i + block_in_pts..i + 2 * block_in_pts]
                    .copy_from_slice(&other.0[i..i + block_in_pts]);

                second[i..i + block_in_pts]
                    .copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
                second[i + block_in_pts..i + 2 * block_in_pts]
                    .copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
            }
        } else {
            // Blocks are smaller than an inner packed element, so interleave each pair of
            // corresponding inner elements.
            for i in 0..N {
                (first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
            }
        }

        (Self(first), Self(second))
    }

    fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
        let mut first = [Default::default(); N];
        let mut second = [Default::default(); N];

        if log_block_len >= PT::LOG_WIDTH {
            // Blocks consist of whole inner packed elements, so move slices of `PT`s.
            let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
            for i in (0..N / 2).step_by(block_in_pts) {
                first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

                second[i..i + block_in_pts]
                    .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
            }

            for i in (0..N / 2).step_by(block_in_pts) {
                first[i + N / 2..i + N / 2 + block_in_pts]
                    .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

                second[i + N / 2..i + N / 2 + block_in_pts]
                    .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
            }
        } else {
            // Blocks are smaller than an inner packed element, so unzip each adjacent pair
            // of inner elements.
            for i in 0..N / 2 {
                (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
            }

            for i in 0..N / 2 {
                (first[i + N / 2], second[i + N / 2]) =
                    other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
            }
        }

        (Self(first), Self(second))
    }

    #[inline]
    unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
        // Delegates to the inherent `spread_unchecked` defined above.
        Self::spread_unchecked(self, log_block_len, block_idx)
    }

    fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
        Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
    }

    #[inline]
    fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
        // SAFETY: `Self` is `#[repr(transparent)]` over `[PT; N]`, so the slice cast is valid.
        let cast_slice =
            unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

        PT::iter_slice(cast_slice.as_flattened())
    }
}

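// A minimal sketch (not part of the original module) of the scalar layout implied by
// `get_unchecked`/`set_unchecked` above: scalar `i` lives in inner element `i / PT::WIDTH`
// at position `i % PT::WIDTH`, so `WIDTH == PT::WIDTH * N`. The helper name
// `check_scalar_layout` is hypothetical and only illustrative.
#[cfg(test)]
#[allow(dead_code)]
fn check_scalar_layout<PT: PackedField, const N: usize>(value: &ScaledPackedField<PT, N>)
where
    ScaledPackedField<PT, N>: PackedField<Scalar = PT::Scalar>,
{
    assert_eq!(ScaledPackedField::<PT, N>::WIDTH, PT::WIDTH * N);
    for i in 0..ScaledPackedField::<PT, N>::WIDTH {
        // `PackedField::get` performs the bounds check; `i` is always within the width here.
        assert_eq!(value.get(i), value.0[i / PT::WIDTH].get(i % PT::WIDTH));
    }
}
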
impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
where
    [PT; N]: Default,
{
    #[inline]
    fn mul_alpha(self) -> Self {
        Self(self.0.map(|v| v.mul_alpha()))
    }
}

/// Wraps a transformation on the inner packed type and applies it to each of the `N`
/// inner elements of a [`ScaledPackedField`].
pub struct ScaledTransformation<I> {
    inner: I,
}

impl<I> ScaledTransformation<I> {
    const fn new(inner: I) -> Self {
        Self { inner }
    }
}

impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
    for ScaledTransformation<I>
where
    I: Transformation<IP, OP>,
{
    fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
        ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
    }
}

impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
    for ScaledPackedField<IP, N>
where
    Self: PackedBinaryField,
    ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
    OP: PackedBinaryField,
    IP: PackedTransformationFactory<OP>,
{
    type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
        ScaledTransformation<IP::PackedTransformation<Data>>;

    fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
        transformation: FieldLinearTransformation<
            <ScaledPackedField<OP, N> as PackedField>::Scalar,
            Data,
        >,
    ) -> Self::PackedTransformation<Data> {
        ScaledTransformation::new(IP::make_packed_transformation(transformation))
    }
}

/// Declares a type alias `$name` for `ScaledPackedField<$inner, $size>` and implements the
/// arithmetic operations with a scalar right-hand side by broadcasting the scalar to the
/// inner packed type.
macro_rules! packed_scaled_field {
    ($name:ident = [$inner:ty;$size:literal]) => {
        pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

        impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v += broadcast;
                }

                self
            }
        }

        impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v += broadcast;
                }
            }
        }

        impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v -= broadcast;
                }

                self
            }
        }

        impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v -= broadcast;
                }
            }
        }

        impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v *= broadcast;
                }

                self
            }
        }

        impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for v in self.0.iter_mut() {
                    *v *= broadcast;
                }
            }
        }
    };
}

pub(crate) use packed_scaled_field;

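// Sketch of how `packed_scaled_field!` is meant to be invoked from an architecture module.
// The type names below are hypothetical placeholders, not declarations from this crate:
//
//     packed_scaled_field!(PackedExample256 = [PackedInner128; 2]);
//
// This expands to `pub type PackedExample256 = ScaledPackedField<PackedInner128, 2>` plus
// the scalar-operand `Add`/`Sub`/`Mul` (and `*Assign`) impls defined by the macro above.
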
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
    PT: WithUnderlier<Underlier: Pod>,
{
    type Underlier = ScaledUnderlier<PT::Underlier, N>;

    fn to_underlier(self) -> Self::Underlier {
        TransparentWrapper::peel(self)
    }

    fn to_underlier_ref(&self) -> &Self::Underlier {
        TransparentWrapper::peel_ref(self)
    }

    fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
        TransparentWrapper::peel_mut(self)
    }

    fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
        TransparentWrapper::peel_slice(val)
    }

    fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
        TransparentWrapper::peel_slice_mut(val)
    }

    fn from_underlier(val: Self::Underlier) -> Self {
        TransparentWrapper::wrap(val)
    }

    fn from_underlier_ref(val: &Self::Underlier) -> &Self {
        TransparentWrapper::wrap_ref(val)
    }

    fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
        TransparentWrapper::wrap_mut(val)
    }

    fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
        TransparentWrapper::wrap_slice(val)
    }

    fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
        TransparentWrapper::wrap_slice_mut(val)
    }
}

impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
    U: PackScalar<F> + UnderlierType + Pod,
    F: Field,
    ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
    type Packed = ScaledPackedField<U::Packed, N>;
}

unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
    for ScaledPackedField<PT, N>
where
    PT: WithUnderlier<Underlier = U>,
{
}