binius_field/arch/portable/packed_scaled.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	array,
	iter::{Product, Sum},
	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{Pod, TransparentWrapper, Zeroable};
use rand::RngCore;
use subtle::ConstantTimeEq;

use crate::{
	arithmetic_traits::MulAlpha,
	as_packed_field::PackScalar,
	linear_transformation::{
		FieldLinearTransformation, PackedTransformationFactory, Transformation,
	},
	packed::PackedBinaryField,
	underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
	Field, PackedField,
};
/// Packed field that stores a smaller packed field `N` times and performs all operations
/// element by element.
/// This is useful for creating portable implementations of the 256- and 512-bit packed sizes.
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);

impl<PT, const N: usize> ScaledPackedField<PT, N> {
	pub const WIDTH_IN_PT: usize = N;

	/// In the general case `PT != Self::Scalar`, so this function has a different name from
	/// `PackedField::from_fn`.
	pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
		Self(std::array::from_fn(f))
	}
}
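
// A minimal usage sketch (the `Packed256` alias is hypothetical; the crate's concrete
// scaled types are declared via the `packed_scaled_field!` macro at the bottom of this
// file):
//
//     type Packed256 = ScaledPackedField<PackedBinaryField1x128b, 2>;
//     let zero = Packed256::from_direct_packed_fn(|_| PackedBinaryField1x128b::zero());
//
// Every 256-bit operation then reduces to two 128-bit operations applied one by one.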

impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
{
	fn default() -> Self {
		Self(Default::default())
	}
}

impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
where
	PT: From<U>,
{
	fn from(value: [U; N]) -> Self {
		Self(value.map(Into::into))
	}
}

impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
where
	U: From<PT>,
{
	fn from(value: ScaledPackedField<PT, N>) -> Self {
		value.0.map(Into::into)
	}
}

unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}

impl<PT: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledPackedField<PT, N> {
	fn ct_eq(&self, other: &Self) -> subtle::Choice {
		self.0.ct_eq(&other.0)
	}
}

impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
where
	Self: Default,
{
	type Output = Self;

	fn add(self, rhs: Self) -> Self {
		Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
	}
}

impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn add_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] += rhs.0[i];
		}
	}
}

impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
where
	Self: Default,
{
	type Output = Self;

	fn sub(self, rhs: Self) -> Self {
		Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
	}
}

impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn sub_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] -= rhs.0[i];
		}
	}
}

impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
where
	Self: Default,
{
	type Output = Self;

	fn mul(self, rhs: Self) -> Self {
		Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
	}
}

impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn mul_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] *= rhs.0[i];
		}
	}
}

/// Currently we use this trait only in this file for compactness.
/// If it turns out to be useful elsewhere, it is worth moving to `arithmetic_traits.rs`.
trait ArithmeticOps<Rhs>:
	Add<Rhs, Output = Self>
	+ AddAssign<Rhs>
	+ Sub<Rhs, Output = Self>
	+ SubAssign<Rhs>
	+ Mul<Rhs, Output = Self>
	+ MulAssign<Rhs>
{
}

impl<T, Rhs> ArithmeticOps<Rhs> for T where
	T: Add<Rhs, Output = Self>
		+ AddAssign<Rhs>
		+ Sub<Rhs, Output = Self>
		+ SubAssign<Rhs>
		+ Mul<Rhs, Output = Self>
		+ MulAssign<Rhs>
{
}
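
// The `PackedField` implementation below uses `ArithmeticOps<PT::Scalar>` as a single
// bound covering all six operations with a scalar; those operations themselves are
// implemented per concrete alias by the `packed_scaled_field!` macro at the bottom of
// this file.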

impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
where
	Self: Default,
{
	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
		iter.fold(Self::default(), |l, r| l + r)
	}
}

impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
{
	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
		let one = Self([PT::one(); N]);
		iter.fold(one, |l, r| l * r)
	}
}

impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
	Self: ArithmeticOps<PT::Scalar>,
{
	type Scalar = PT::Scalar;

	const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		self.0.get_unchecked(outer_i).get_unchecked(inner_i)
	}

	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		self.0
			.get_unchecked_mut(outer_i)
			.set_unchecked(inner_i, scalar);
	}

	#[inline]
	fn zero() -> Self {
		Self(array::from_fn(|_| PT::zero()))
	}

	fn random(mut rng: impl RngCore) -> Self {
		Self(array::from_fn(|_| PT::random(&mut rng)))
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		Self(array::from_fn(|_| PT::broadcast(scalar)))
	}

	#[inline]
	fn square(self) -> Self {
		Self(self.0.map(|v| v.square()))
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		Self(self.0.map(|v| v.invert_or_zero()))
	}

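	// If an interleaved block covers at least one whole inner element, whole inner
	// elements are copied between the arrays; otherwise each pair of inner elements is
	// interleaved by delegating to `PT::interleave`.
	//
	// Sketch for N = 2 and log_block_len == PT::LOG_WIDTH:
	// ([a0, a1], [b0, b1]) interleaves to ([a0, b0], [a1, b1]).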
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N).step_by(block_in_pts * 2) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
				first[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i..i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
				second[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
			}
		} else {
			for i in 0..N {
				(first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

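	// Deinterleaves blocks of size 2^log_block_len from the concatenation of `self` and
	// `other`: `first` collects the even-indexed blocks, `second` the odd-indexed ones.
	//
	// Sketch for N = 2 and log_block_len == PT::LOG_WIDTH:
	// ([a0, a1], [b0, b1]) unzips to ([a0, b0], [a1, b1]).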
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}

			for i in (0..N / 2).step_by(block_in_pts) {
				first[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

				second[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}
		} else {
			for i in 0..N / 2 {
				(first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
			}

			for i in 0..N / 2 {
				(first[i + N / 2], second[i + N / 2]) =
					other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

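	// Spreads block `block_idx` (2^log_block_len scalars) across the full width so that
	// each scalar of the block is repeated 2^(LOG_WIDTH - log_block_len) times. When the
	// block spans whole inner elements, each output element spreads a sub-block of one
	// inner element; otherwise the entire block lies within a single inner element.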
	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		let log_n = checked_log_2(N);
		let values = if log_block_len >= PT::LOG_WIDTH {
			let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
			let log_packed_block = log_block_len - PT::LOG_WIDTH;
			let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
			let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
			array::from_fn(|i| {
				self.0
					.get_unchecked(offset + (i >> (log_n - log_packed_block)))
					.spread_unchecked(log_smaller_block, i & smaller_block_index_mask)
			})
		} else {
			let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
			let log_inner_block_len = log_block_len.saturating_sub(log_n);
			let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
			let block_offset = block_offset << (log_block_len - log_inner_block_len);

			array::from_fn(|i| {
				self.0.get_unchecked(value_index).spread_unchecked(
					log_inner_block_len,
					block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
				)
			})
		};

		Self(values)
	}

	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		// Safety: `Self` has the same layout as `[PT; N]` because it is a transparent wrapper.
		let cast_slice =
			unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

		PT::iter_slice(cast_slice.as_flattened())
	}
}
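
// A quick sanity sketch of the scalar indexing above, for any inner packed type `PT`:
//
//     let p = <ScaledPackedField<PT, 2>>::broadcast(s);
//     // p.get(i) == s for every i < 2 * PT::WIDTH, since scalar index i maps to inner
//     // element i / PT::WIDTH at position i % PT::WIDTH.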

impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
{
	#[inline]
	fn mul_alpha(self) -> Self {
		Self(self.0.map(|v| v.mul_alpha()))
	}
}

/// Per-element transformation of a scaled packed field.
pub struct ScaledTransformation<I> {
	inner: I,
}

impl<I> ScaledTransformation<I> {
	const fn new(inner: I) -> Self {
		Self { inner }
	}
}

impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
	for ScaledTransformation<I>
where
	I: Transformation<IP, OP>,
{
	fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
		ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
	}
}

impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
	for ScaledPackedField<IP, N>
where
	Self: PackedBinaryField,
	ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
	OP: PackedBinaryField,
	IP: PackedTransformationFactory<OP>,
{
	type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
		ScaledTransformation<IP::PackedTransformation<Data>>;

	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
		transformation: FieldLinearTransformation<
			<ScaledPackedField<OP, N> as PackedField>::Scalar,
			Data,
		>,
	) -> Self::PackedTransformation<Data> {
		ScaledTransformation::new(IP::make_packed_transformation(transformation))
	}
}

/// The only thing that prevents us from having a purely generic implementation of
/// `ScaledPackedField` is that we can't implement operations generically for both `Self` and
/// `PT::Scalar` (doing so leads to a `conflicting implementations of trait` error).
/// That's why we implement the scalar operations in this macro.
macro_rules! packed_scaled_field {
	($name:ident = [$inner:ty;$size:literal]) => {
		pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

		impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}

				self
			}
		}

		impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}
			}
		}

		impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}

				self
			}
		}

		impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}
			}
		}

		impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}

				self
			}
		}

		impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}
			}
		}
	};
}

pub(crate) use packed_scaled_field;
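
// A hypothetical invocation (the crate's actual aliases are declared where this macro
// is used, e.g. in the portable 256-/512-bit modules; the names below are illustrative
// only):
//
//     packed_scaled_field!(PackedScaled2x128b = [PackedBinaryField1x128b; 2]);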

unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier: Pod>,
{
	type Underlier = ScaledUnderlier<PT::Underlier, N>;

	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}

impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
	U: PackScalar<F> + UnderlierType + Pod,
	F: Field,
	ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
	type Packed = ScaledPackedField<U::Packed, N>;
}

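// Safety: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`, and the
// `WithUnderlier` bound guarantees `PT` shares its representation with `U`, so the
// layouts of `ScaledPackedField<PT, N>` and `ScaledUnderlier<U, N>` match.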
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
	for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier = U>,
{
}