// binius_field/arch/portable/packed_scaled.rs
1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	iter::{Product, Sum},
6	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::checked_arithmetics::checked_log_2;
10use bytemuck::{Pod, TransparentWrapper, Zeroable};
11use rand::RngCore;
12use subtle::ConstantTimeEq;
13
14use crate::{
15	arithmetic_traits::MulAlpha,
16	as_packed_field::PackScalar,
17	linear_transformation::{
18		FieldLinearTransformation, PackedTransformationFactory, Transformation,
19	},
20	packed::PackedBinaryField,
21	underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
22	Field, PackedField,
23};
24
/// Packed field that just stores smaller packed field N times and performs all operations
/// one by one.
/// This makes sense for creating portable implementations for 256 and 512 packed sizes.
///
/// `#[repr(transparent)]` guarantees this type has exactly the layout of `[PT; N]`; the
/// unsafe `Pod`/`Zeroable`/`TransparentWrapper` impls and the raw pointer cast in
/// `iter_slice` below all rely on that layout guarantee.
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
31
impl<PT, const N: usize> ScaledPackedField<PT, N> {
	pub const WIDTH_IN_PT: usize = N;

	/// In general case PT != Self::Scalar, so this function has a different name from `PackedField::from_fn`
	pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
		Self(std::array::from_fn(f))
	}

	/// We put implementation here to be able to use in the generic code.
	/// (`PackedField` is only implemented for certain types via macro).
	///
	/// # Safety
	/// The caller must guarantee that the `(log_block_len, block_idx)` pair addresses a
	/// valid sub-block of `self` (i.e. that the computed inner indices are in bounds);
	/// out-of-range arguments make the `get_unchecked` calls below undefined behavior.
	#[inline]
	pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
	where
		PT: PackedField,
	{
		let log_n = checked_log_2(N);
		let values = if log_block_len >= PT::LOG_WIDTH {
			// The selected block spans one or more whole inner `PT` elements: each output
			// element is a spread of a smaller block taken from one of those elements.
			let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
			let log_packed_block = log_block_len - PT::LOG_WIDTH;
			let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
			let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
			array::from_fn(|i| {
				self.0
					.get_unchecked(offset + (i >> (log_n - log_packed_block)))
					.spread_unchecked(
						log_smaller_block,
						(i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
					)
			})
		} else {
			// The selected block fits inside a single inner `PT` element (`value_index`);
			// every output element is produced by spreading sub-blocks of that one element.
			let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
			let log_inner_block_len = log_block_len.saturating_sub(log_n);
			let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
			let block_offset = block_offset << (log_block_len - log_inner_block_len);

			array::from_fn(|i| {
				self.0.get_unchecked(value_index).spread_unchecked(
					log_inner_block_len,
					block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
				)
			})
		};

		Self(values)
	}
}
78
79impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
80where
81	[PT; N]: Default,
82{
83	fn default() -> Self {
84		Self(Default::default())
85	}
86}
87
88impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
89where
90	PT: From<U>,
91{
92	fn from(value: [U; N]) -> Self {
93		Self(value.map(Into::into))
94	}
95}
96
97impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
98where
99	U: From<PT>,
100{
101	fn from(value: ScaledPackedField<PT, N>) -> Self {
102		value.0.map(Into::into)
103	}
104}
105
// SAFETY: `ScaledPackedField` is `#[repr(transparent)]` over `[PT; N]`, and an array of
// `Zeroable` elements is itself validly all-zeroes.
unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

// SAFETY: transparent wrapper over `[PT; N]`; an array of `Pod` elements has no padding
// and every bit pattern is valid, so the wrapper is `Pod` as well.
unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
109
impl<PT: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledPackedField<PT, N> {
	// Delegates to `subtle`'s constant-time slice/array comparison so equality checks do
	// not leak timing information about the compared values.
	fn ct_eq(&self, other: &Self) -> subtle::Choice {
		self.0.ct_eq(&other.0)
	}
}
115
116impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
117where
118	Self: Default,
119{
120	type Output = Self;
121
122	fn add(self, rhs: Self) -> Self {
123		Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
124	}
125}
126
127impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
128where
129	Self: Default,
130{
131	fn add_assign(&mut self, rhs: Self) {
132		for i in 0..N {
133			self.0[i] += rhs.0[i];
134		}
135	}
136}
137
138impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
139where
140	Self: Default,
141{
142	type Output = Self;
143
144	fn sub(self, rhs: Self) -> Self {
145		Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
146	}
147}
148
149impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
150where
151	Self: Default,
152{
153	fn sub_assign(&mut self, rhs: Self) {
154		for i in 0..N {
155			self.0[i] -= rhs.0[i];
156		}
157	}
158}
159
160impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
161where
162	Self: Default,
163{
164	type Output = Self;
165
166	fn mul(self, rhs: Self) -> Self {
167		Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
168	}
169}
170
171impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
172where
173	Self: Default,
174{
175	fn mul_assign(&mut self, rhs: Self) {
176		for i in 0..N {
177			self.0[i] *= rhs.0[i];
178		}
179	}
180}
181
/// Currently we use this trait only in this file for compactness.
/// If it is useful in some other place it worth moving it to `arithmetic_traits.rs`
///
/// Bundles the six binary-operator bounds against `Rhs` (`+`, `+=`, `-`, `-=`, `*`, `*=`)
/// so that the `PackedField` impl below can state a single `Self: ArithmeticOps<PT::Scalar>`
/// bound instead of repeating all six.
trait ArithmeticOps<Rhs>:
	Add<Rhs, Output = Self>
	+ AddAssign<Rhs>
	+ Sub<Rhs, Output = Self>
	+ SubAssign<Rhs>
	+ Mul<Rhs, Output = Self>
	+ MulAssign<Rhs>
{
}

// Blanket impl: any type providing all six operations automatically implements the bundle.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
	T: Add<Rhs, Output = Self>
		+ AddAssign<Rhs>
		+ Sub<Rhs, Output = Self>
		+ SubAssign<Rhs>
		+ Mul<Rhs, Output = Self>
		+ MulAssign<Rhs>
{
}
203
204impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
205where
206	Self: Default,
207{
208	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
209		iter.fold(Self::default(), |l, r| l + r)
210	}
211}
212
213impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
214where
215	[PT; N]: Default,
216{
217	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
218		let one = Self([PT::one(); N]);
219		iter.fold(one, |l, r| l * r)
220	}
221}
222
impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
	Self: ArithmeticOps<PT::Scalar>,
{
	type Scalar = PT::Scalar;

	// Total log-width is the inner log-width plus log2(N); `checked_log_2` enforces at
	// compile time that N is a power of two.
	const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

	// Scalar `i` lives in inner element `i / PT::WIDTH` at inner position `i % PT::WIDTH`.
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		self.0.get_unchecked(outer_i).get_unchecked(inner_i)
	}

	// Mirror of `get_unchecked`: same outer/inner index decomposition.
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		self.0
			.get_unchecked_mut(outer_i)
			.set_unchecked(inner_i, scalar);
	}

	#[inline]
	fn zero() -> Self {
		Self(array::from_fn(|_| PT::zero()))
	}

	fn random(mut rng: impl RngCore) -> Self {
		Self(array::from_fn(|_| PT::random(&mut rng)))
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		Self(array::from_fn(|_| PT::broadcast(scalar)))
	}

	// `square`/`invert_or_zero` are purely element-wise, so they map over the inner array.
	#[inline]
	fn square(self) -> Self {
		Self(self.0.map(|v| v.square()))
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		Self(self.0.map(|v| v.invert_or_zero()))
	}

	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks are whole multiples of `PT`, so interleaving reduces to copying
			// `block_in_pts`-sized chunks, alternating sources between `self` and `other`.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N).step_by(block_in_pts * 2) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
				first[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i..i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
				second[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
			}
		} else {
			// Blocks are smaller than a single `PT`, so delegate to the inner
			// interleave pairwise across corresponding elements.
			for i in 0..N {
				(first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Whole-`PT` blocks: even-position chunks of `self` go to the low half of
			// `first`, odd-position chunks to the low half of `second`; `other` fills
			// the high halves symmetrically.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}

			for i in (0..N / 2).step_by(block_in_pts) {
				first[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

				second[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}
		} else {
			// Sub-`PT` blocks: unzip adjacent element pairs of `self` into the low halves
			// and adjacent pairs of `other` into the high halves.
			for i in 0..N / 2 {
				(first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
			}

			for i in 0..N / 2 {
				(first[i + N / 2], second[i + N / 2]) =
					other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	// Delegates to the inherent `spread_unchecked` defined above so the same code can be
	// reused from generic contexts that don't go through `PackedField`.
	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		Self::spread_unchecked(self, log_block_len, block_idx)
	}

	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		// Safety: `Self` has the same layout as `[PT; N]` because it is a transparent wrapper.
		// Flattening lets the inner `PT::iter_slice` drive iteration without per-element
		// wrapper handling.
		let cast_slice =
			unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

		PT::iter_slice(cast_slice.as_flattened())
	}
}
349
350impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
351where
352	[PT; N]: Default,
353{
354	#[inline]
355	fn mul_alpha(self) -> Self {
356		Self(self.0.map(|v| v.mul_alpha()))
357	}
358}
359
/// Per-element transformation as a scaled packed field.
///
/// Wraps a transformation over the inner packed type and applies it independently to each
/// of the `N` inner elements of a `ScaledPackedField`.
pub struct ScaledTransformation<I> {
	inner: I,
}

impl<I> ScaledTransformation<I> {
	// Private constructor: instances are only created through
	// `PackedTransformationFactory::make_packed_transformation`.
	const fn new(inner: I) -> Self {
		Self { inner }
	}
}
370
371impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
372	for ScaledTransformation<I>
373where
374	I: Transformation<IP, OP>,
375{
376	fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
377		ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
378	}
379}
380
// A scaled packed field can build a packed transformation by delegating to the factory of
// its inner packed type and wrapping the result in `ScaledTransformation`, which applies it
// element-wise. The scalar types of the input and output scaled fields must agree with the
// inner types' scalars for the delegation to be sound.
impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
	for ScaledPackedField<IP, N>
where
	Self: PackedBinaryField,
	ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
	OP: PackedBinaryField,
	IP: PackedTransformationFactory<OP>,
{
	type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
		ScaledTransformation<IP::PackedTransformation<Data>>;

	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
		transformation: FieldLinearTransformation<
			<ScaledPackedField<OP, N> as PackedField>::Scalar,
			Data,
		>,
	) -> Self::PackedTransformation<Data> {
		ScaledTransformation::new(IP::make_packed_transformation(transformation))
	}
}
401
/// The only thing that prevents us from having pure generic implementation of `ScaledPackedField`
/// is that we can't have generic operations both with `Self` and `PT::Scalar`
/// (it leads to `conflicting implementations of trait` error).
/// That's why we implement one of those in a macro.
///
/// For each alias `$name = ScaledPackedField<$inner, $size>` this macro emits the six
/// scalar-operand operator impls (`Add`/`AddAssign`/`Sub`/`SubAssign`/`Mul`/`MulAssign`
/// against `<$inner as PackedField>::Scalar`). Each is implemented by broadcasting the
/// scalar into a full `$inner` once and applying the packed operation to every element.
macro_rules! packed_scaled_field {
	($name:ident = [$inner:ty;$size:literal]) => {
		pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

		impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}

				self
			}
		}

		impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}
			}
		}

		impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}

				self
			}
		}

		impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}
			}
		}

		impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}

				self
			}
		}

		impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}
			}
		}
	};
}
483
484pub(crate) use packed_scaled_field;
485
// The underlier of a scaled packed field is the correspondingly scaled underlier of the
// inner type. All conversions are zero-cost `TransparentWrapper` wraps/peels, justified by
// the `TransparentWrapper<ScaledUnderlier<U, N>>` impl at the bottom of this file.
// NOTE(review): the `Underlier: Pod` bound presumably satisfies a requirement of
// `WithUnderlier`/`ScaledUnderlier` declared elsewhere — confirm at the trait definition.
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier: Pod>,
{
	type Underlier = ScaledUnderlier<PT::Underlier, N>;

	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}
532
// A scaled underlier packs scalars of `F` via the scaled packed field built from the inner
// underlier's packing. The `where` clause verifies that the resulting packed type really
// has scalar `F` and round-trips through `Self` as its underlier.
impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
	U: PackScalar<F> + UnderlierType + Pod,
	F: Field,
	ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
	type Packed = ScaledPackedField<U::Packed, N>;
}
541
// SAFETY: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`, and each
// `PT` is an underlier-wrapper of `U` (via `WithUnderlier`), so the overall layout is
// assumed to match `ScaledUnderlier<U, N>`.
// NOTE(review): this additionally relies on `ScaledUnderlier<U, N>` being layout-compatible
// with `[U; N]` and on `WithUnderlier` implying a transparent layout for `PT` — both are
// declared elsewhere (`underlier` module); confirm there.
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
	for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier = U>,
{
}