1use std::{
4 array,
5 iter::{Product, Sum},
6 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::checked_arithmetics::checked_log_2;
10use bytemuck::{Pod, TransparentWrapper, Zeroable};
11use rand::RngCore;
12use subtle::ConstantTimeEq;
13
14use crate::{
15 arithmetic_traits::MulAlpha,
16 as_packed_field::PackScalar,
17 linear_transformation::{
18 FieldLinearTransformation, PackedTransformationFactory, Transformation,
19 },
20 packed::PackedBinaryField,
21 underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
22 Field, PackedField,
23};
24
25#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
29#[repr(transparent)]
30pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
31
32impl<PT, const N: usize> ScaledPackedField<PT, N> {
33 pub const WIDTH_IN_PT: usize = N;
34
35 pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
37 Self(std::array::from_fn(f))
38 }
39}
40
41impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
42where
43 [PT; N]: Default,
44{
45 fn default() -> Self {
46 Self(Default::default())
47 }
48}
49
50impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
51where
52 PT: From<U>,
53{
54 fn from(value: [U; N]) -> Self {
55 Self(value.map(Into::into))
56 }
57}
58
59impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
60where
61 U: From<PT>,
62{
63 fn from(value: ScaledPackedField<PT, N>) -> Self {
64 value.0.map(Into::into)
65 }
66}
67
68unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}
69
70unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
71
72impl<PT: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledPackedField<PT, N> {
73 fn ct_eq(&self, other: &Self) -> subtle::Choice {
74 self.0.ct_eq(&other.0)
75 }
76}
77
78impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
79where
80 Self: Default,
81{
82 type Output = Self;
83
84 fn add(self, rhs: Self) -> Self {
85 Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
86 }
87}
88
89impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
90where
91 Self: Default,
92{
93 fn add_assign(&mut self, rhs: Self) {
94 for i in 0..N {
95 self.0[i] += rhs.0[i];
96 }
97 }
98}
99
100impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
101where
102 Self: Default,
103{
104 type Output = Self;
105
106 fn sub(self, rhs: Self) -> Self {
107 Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
108 }
109}
110
111impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
112where
113 Self: Default,
114{
115 fn sub_assign(&mut self, rhs: Self) {
116 for i in 0..N {
117 self.0[i] -= rhs.0[i];
118 }
119 }
120}
121
122impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
123where
124 Self: Default,
125{
126 type Output = Self;
127
128 fn mul(self, rhs: Self) -> Self {
129 Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
130 }
131}
132
133impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
134where
135 Self: Default,
136{
137 fn mul_assign(&mut self, rhs: Self) {
138 for i in 0..N {
139 self.0[i] *= rhs.0[i];
140 }
141 }
142}
143
/// Shorthand for a type that supports `+`, `-`, `*` and their compound-assignment forms with a
/// right-hand side of type `Rhs`, where every binary operator returns `Self`.
trait ArithmeticOps<Rhs>:
    Add<Rhs, Output = Self>
    + AddAssign<Rhs>
    + Sub<Rhs, Output = Self>
    + SubAssign<Rhs>
    + Mul<Rhs, Output = Self>
    + MulAssign<Rhs>
{
}

/// Every type that meets all six operator bounds is automatically an `ArithmeticOps<Rhs>`.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
    T: Add<Rhs, Output = T>
        + AddAssign<Rhs>
        + Sub<Rhs, Output = T>
        + SubAssign<Rhs>
        + Mul<Rhs, Output = T>
        + MulAssign<Rhs>
{
}
165
166impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
167where
168 Self: Default,
169{
170 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
171 iter.fold(Self::default(), |l, r| l + r)
172 }
173}
174
175impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
176where
177 [PT; N]: Default,
178{
179 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
180 let one = Self([PT::one(); N]);
181 iter.fold(one, |l, r| l * r)
182 }
183}
184
185impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
186where
187 [PT; N]: Default,
188 Self: ArithmeticOps<PT::Scalar>,
189{
190 type Scalar = PT::Scalar;
191
192 const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);
193
194 #[inline]
195 unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
196 let outer_i = i / PT::WIDTH;
197 let inner_i = i % PT::WIDTH;
198 self.0.get_unchecked(outer_i).get_unchecked(inner_i)
199 }
200
201 #[inline]
202 unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
203 let outer_i = i / PT::WIDTH;
204 let inner_i = i % PT::WIDTH;
205 self.0
206 .get_unchecked_mut(outer_i)
207 .set_unchecked(inner_i, scalar);
208 }
209
210 #[inline]
211 fn zero() -> Self {
212 Self(array::from_fn(|_| PT::zero()))
213 }
214
215 fn random(mut rng: impl RngCore) -> Self {
216 Self(array::from_fn(|_| PT::random(&mut rng)))
217 }
218
219 #[inline]
220 fn broadcast(scalar: Self::Scalar) -> Self {
221 Self(array::from_fn(|_| PT::broadcast(scalar)))
222 }
223
224 #[inline]
225 fn square(self) -> Self {
226 Self(self.0.map(|v| v.square()))
227 }
228
229 #[inline]
230 fn invert_or_zero(self) -> Self {
231 Self(self.0.map(|v| v.invert_or_zero()))
232 }
233
234 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
235 let mut first = [Default::default(); N];
236 let mut second = [Default::default(); N];
237
238 if log_block_len >= PT::LOG_WIDTH {
239 let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
240 for i in (0..N).step_by(block_in_pts * 2) {
241 first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
242 first[i + block_in_pts..i + 2 * block_in_pts]
243 .copy_from_slice(&other.0[i..i + block_in_pts]);
244
245 second[i..i + block_in_pts]
246 .copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
247 second[i + block_in_pts..i + 2 * block_in_pts]
248 .copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
249 }
250 } else {
251 for i in 0..N {
252 (first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
253 }
254 }
255
256 (Self(first), Self(second))
257 }
258
259 fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
260 let mut first = [Default::default(); N];
261 let mut second = [Default::default(); N];
262
263 if log_block_len >= PT::LOG_WIDTH {
264 let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
265 for i in (0..N / 2).step_by(block_in_pts) {
266 first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);
267
268 second[i..i + block_in_pts]
269 .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
270 }
271
272 for i in (0..N / 2).step_by(block_in_pts) {
273 first[i + N / 2..i + N / 2 + block_in_pts]
274 .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);
275
276 second[i + N / 2..i + N / 2 + block_in_pts]
277 .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
278 }
279 } else {
280 for i in 0..N / 2 {
281 (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
282 }
283
284 for i in 0..N / 2 {
285 (first[i + N / 2], second[i + N / 2]) =
286 other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
287 }
288 }
289
290 (Self(first), Self(second))
291 }
292
293 #[inline]
294 unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
295 let log_n = checked_log_2(N);
296 let values = if log_block_len >= PT::LOG_WIDTH {
297 let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
298 let log_packed_block = log_block_len - PT::LOG_WIDTH;
299 let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
300 let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
301 array::from_fn(|i| {
302 self.0
303 .get_unchecked(offset + (i >> (log_n - log_packed_block)))
304 .spread_unchecked(log_smaller_block, i & smaller_block_index_mask)
305 })
306 } else {
307 let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
308 let log_inner_block_len = log_block_len.saturating_sub(log_n);
309 let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
310 let block_offset = block_offset << (log_block_len - log_inner_block_len);
311
312 array::from_fn(|i| {
313 self.0.get_unchecked(value_index).spread_unchecked(
314 log_inner_block_len,
315 block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
316 )
317 })
318 };
319
320 Self(values)
321 }
322
323 fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
324 Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
325 }
326
327 #[inline]
328 fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
329 let cast_slice =
331 unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };
332
333 PT::iter_slice(cast_slice.as_flattened())
334 }
335}
336
337impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
338where
339 [PT; N]: Default,
340{
341 #[inline]
342 fn mul_alpha(self) -> Self {
343 Self(self.0.map(|v| v.mul_alpha()))
344 }
345}
346
/// Wraps an inner transformation `I` so that it can be applied to each of the `N` inner packed
/// values of a [`ScaledPackedField`] independently.
pub struct ScaledTransformation<I> {
    /// The transformation applied to every inner packed value.
    inner: I,
}

impl<I> ScaledTransformation<I> {
    /// Wraps `inner` as a scaled transformation.
    const fn new(inner: I) -> Self {
        Self { inner }
    }
}
357
358impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
359 for ScaledTransformation<I>
360where
361 I: Transformation<IP, OP>,
362{
363 fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
364 ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
365 }
366}
367
368impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
369 for ScaledPackedField<IP, N>
370where
371 Self: PackedBinaryField,
372 ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
373 OP: PackedBinaryField,
374 IP: PackedTransformationFactory<OP>,
375{
376 type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
377 ScaledTransformation<IP::PackedTransformation<Data>>;
378
379 fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
380 transformation: FieldLinearTransformation<
381 <ScaledPackedField<OP, N> as PackedField>::Scalar,
382 Data,
383 >,
384 ) -> Self::PackedTransformation<Data> {
385 ScaledTransformation::new(IP::make_packed_transformation(transformation))
386 }
387}
388
/// Declares `$name` as a `ScaledPackedField` of `$size` copies of `$inner`.
///
/// It also implements `+`, `-` and `*`, plus their `*Assign` forms, between `$name` and a single
/// scalar of `$inner`. Each compound-assignment operator broadcasts the scalar into an `$inner`
/// value once and then combines it with every inner value in turn. The plain binary operators
/// delegate to the matching compound-assignment operator.
macro_rules! packed_scaled_field {
    ($name:ident = [$inner:ty;$size:literal]) => {
        pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

        impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let packed_rhs = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for inner_value in &mut self.0 {
                    *inner_value += packed_rhs;
                }
            }
        }

        impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                self += rhs;
                self
            }
        }

        impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let packed_rhs = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for inner_value in &mut self.0 {
                    *inner_value -= packed_rhs;
                }
            }
        }

        impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                self -= rhs;
                self
            }
        }

        impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            #[inline]
            fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
                let packed_rhs = <$inner as $crate::packed::PackedField>::broadcast(rhs);
                for inner_value in &mut self.0 {
                    *inner_value *= packed_rhs;
                }
            }
        }

        impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
            type Output = Self;

            #[inline]
            fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
                self *= rhs;
                self
            }
        }
    };
}

pub(crate) use packed_scaled_field;
472
473unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
474where
475 PT: WithUnderlier<Underlier: Pod>,
476{
477 type Underlier = ScaledUnderlier<PT::Underlier, N>;
478
479 fn to_underlier(self) -> Self::Underlier {
480 TransparentWrapper::peel(self)
481 }
482
483 fn to_underlier_ref(&self) -> &Self::Underlier {
484 TransparentWrapper::peel_ref(self)
485 }
486
487 fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
488 TransparentWrapper::peel_mut(self)
489 }
490
491 fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
492 TransparentWrapper::peel_slice(val)
493 }
494
495 fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
496 TransparentWrapper::peel_slice_mut(val)
497 }
498
499 fn from_underlier(val: Self::Underlier) -> Self {
500 TransparentWrapper::wrap(val)
501 }
502
503 fn from_underlier_ref(val: &Self::Underlier) -> &Self {
504 TransparentWrapper::wrap_ref(val)
505 }
506
507 fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
508 TransparentWrapper::wrap_mut(val)
509 }
510
511 fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
512 TransparentWrapper::wrap_slice(val)
513 }
514
515 fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
516 TransparentWrapper::wrap_slice_mut(val)
517 }
518}
519
520impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
521where
522 U: PackScalar<F> + UnderlierType + Pod,
523 F: Field,
524 ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
525{
526 type Packed = ScaledPackedField<U::Packed, N>;
527}
528
529unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
530 for ScaledPackedField<PT, N>
531where
532 PT: WithUnderlier<Underlier = U>,
533{
534}