binius_field/arch/portable/packed_scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	iter::{Product, Sum},
6	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	checked_arithmetics::checked_log_2,
13};
14use bytemuck::{Pod, TransparentWrapper, Zeroable};
15use rand::{
16	Rng,
17	distr::{Distribution, StandardUniform},
18};
19
20use crate::{
21	Field, PackedField,
22	arithmetic_traits::MulAlpha,
23	as_packed_field::PackScalar,
24	linear_transformation::{
25		FieldLinearTransformation, PackedTransformationFactory, Transformation,
26	},
27	packed::PackedBinaryField,
28	underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
29};
30
31/// Packed field that just stores smaller packed field N times and performs all operations
32/// one by one.
33/// This makes sense for creating portable implementations for 256 and 512 packed sizes.
34#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
35#[repr(transparent)]
36pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
37
impl<PT, const N: usize> ScaledPackedField<PT, N> {
	/// Number of inner `PT` values stored (`N`).
	pub const WIDTH_IN_PT: usize = N;

	/// In general case PT != Self::Scalar, so this function has a different name from
	/// `PackedField::from_fn`
	pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
		Self(std::array::from_fn(f))
	}

	/// We put implementation here to be able to use in the generic code.
	/// (`PackedField` is only implemented for certain types via macro).
	///
	/// Spreads the sub-block with index `block_idx` (of `1 << log_block_len` scalars) across
	/// the full width of `Self` by delegating to the inner `PT::spread_unchecked`.
	///
	/// # Safety
	/// Caller must uphold the `PackedField::spread` contract: `log_block_len <= Self::LOG_WIDTH`
	/// and `block_idx < 1 << (Self::LOG_WIDTH - log_block_len)`. All element indices computed
	/// below are derived from that contract and accessed unchecked.
	#[inline]
	pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
	where
		PT: PackedField,
	{
		// N is required to be a power of two; `checked_log_2` enforces this.
		let log_n = checked_log_2(N);
		let values = if log_block_len >= PT::LOG_WIDTH {
			// Case 1: the selected block spans one or more whole inner `PT` values.
			let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
			let log_packed_block = log_block_len - PT::LOG_WIDTH;
			let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
			let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
			array::from_fn(|i| unsafe {
				// SAFETY: `offset + (i >> …)` stays within `0..N` under the caller's contract.
				self.0
					.get_unchecked(offset + (i >> (log_n - log_packed_block)))
					.spread_unchecked(
						log_smaller_block,
						(i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
					)
			})
		} else {
			// Case 2: the selected block lies entirely inside a single inner `PT` value.
			let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
			let log_inner_block_len = log_block_len.saturating_sub(log_n);
			// Position of the block within that inner value, rescaled to inner-block units.
			let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
			let block_offset = block_offset << (log_block_len - log_inner_block_len);

			array::from_fn(|i| unsafe {
				// SAFETY: `value_index < N` under the caller's contract.
				self.0.get_unchecked(value_index).spread_unchecked(
					log_inner_block_len,
					block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
				)
			})
		};

		Self(values)
	}
}
85
86impl<PT, const N: usize> SerializeBytes for ScaledPackedField<PT, N>
87where
88	PT: SerializeBytes,
89{
90	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
91		for elem in &self.0 {
92			elem.serialize(&mut write_buf)?;
93		}
94		Ok(())
95	}
96}
97
98impl<PT, const N: usize> DeserializeBytes for ScaledPackedField<PT, N>
99where
100	PT: DeserializeBytes,
101{
102	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError> {
103		let mut result = Vec::with_capacity(N);
104		for _ in 0..N {
105			result.push(PT::deserialize(&mut read_buf)?);
106		}
107
108		match result.try_into() {
109			Ok(arr) => Ok(Self(arr)),
110			Err(_) => Err(SerializationError::InvalidConstruction {
111				name: "ScaledPackedField",
112			}),
113		}
114	}
115}
116
117impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
118where
119	[PT; N]: Default,
120{
121	fn default() -> Self {
122		Self(Default::default())
123	}
124}
125
126impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
127where
128	PT: From<U>,
129{
130	fn from(value: [U; N]) -> Self {
131		Self(value.map(Into::into))
132	}
133}
134
135impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
136where
137	U: From<PT>,
138{
139	fn from(value: ScaledPackedField<PT, N>) -> Self {
140		value.0.map(Into::into)
141	}
142}
143
// SAFETY: `ScaledPackedField` is `repr(transparent)` over `[PT; N]`, and an array of
// `Zeroable` elements is itself zeroable.
unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

// SAFETY: an array of `Pod` elements is `Pod` (no padding between elements), and the
// `repr(transparent)` wrapper preserves that layout.
unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
147
148impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
149where
150	Self: Default,
151{
152	type Output = Self;
153
154	fn add(self, rhs: Self) -> Self {
155		Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
156	}
157}
158
159impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
160where
161	Self: Default,
162{
163	fn add_assign(&mut self, rhs: Self) {
164		for i in 0..N {
165			self.0[i] += rhs.0[i];
166		}
167	}
168}
169
170impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
171where
172	Self: Default,
173{
174	type Output = Self;
175
176	fn sub(self, rhs: Self) -> Self {
177		Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
178	}
179}
180
181impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
182where
183	Self: Default,
184{
185	fn sub_assign(&mut self, rhs: Self) {
186		for i in 0..N {
187			self.0[i] -= rhs.0[i];
188		}
189	}
190}
191
192impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
193where
194	Self: Default,
195{
196	type Output = Self;
197
198	fn mul(self, rhs: Self) -> Self {
199		Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
200	}
201}
202
203impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
204where
205	Self: Default,
206{
207	fn mul_assign(&mut self, rhs: Self) {
208		for i in 0..N {
209			self.0[i] *= rhs.0[i];
210		}
211	}
212}
213
/// Currently we use this trait only in this file for compactness.
/// If it is useful in some other place it worth moving it to `arithmetic_traits.rs`
///
/// Bundles the six binary arithmetic operations against `Rhs` into a single bound,
/// used by the `PackedField` impl below to require scalar-operand arithmetic.
trait ArithmeticOps<Rhs>:
	Add<Rhs, Output = Self>
	+ AddAssign<Rhs>
	+ Sub<Rhs, Output = Self>
	+ SubAssign<Rhs>
	+ Mul<Rhs, Output = Self>
	+ MulAssign<Rhs>
{
}
225
// Blanket implementation: any type providing all six operations automatically
// satisfies `ArithmeticOps`.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
	T: Add<Rhs, Output = Self>
		+ AddAssign<Rhs>
		+ Sub<Rhs, Output = Self>
		+ SubAssign<Rhs>
		+ Mul<Rhs, Output = Self>
		+ MulAssign<Rhs>
{
}
235
236impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
237where
238	Self: Default,
239{
240	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
241		iter.fold(Self::default(), |l, r| l + r)
242	}
243}
244
245impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
246where
247	[PT; N]: Default,
248{
249	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
250		let one = Self([PT::one(); N]);
251		iter.fold(one, |l, r| l * r)
252	}
253}
254
impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
	Self: ArithmeticOps<PT::Scalar>,
{
	type Scalar = PT::Scalar;

	// Total log-width is the inner log-width plus log2(N); both are powers of two.
	const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

	/// # Safety
	/// Caller must ensure `i < Self::WIDTH`.
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		// Split the flat scalar index into (inner value index, index within that value).
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		// SAFETY: `i < Self::WIDTH` implies `outer_i < N` and `inner_i < PT::WIDTH`.
		unsafe { self.0.get_unchecked(outer_i).get_unchecked(inner_i) }
	}

	/// # Safety
	/// Caller must ensure `i < Self::WIDTH`.
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		// SAFETY: `i < Self::WIDTH` implies `outer_i < N` and `inner_i < PT::WIDTH`.
		unsafe {
			self.0
				.get_unchecked_mut(outer_i)
				.set_unchecked(inner_i, scalar);
		}
	}

	#[inline]
	fn zero() -> Self {
		Self(array::from_fn(|_| PT::zero()))
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		// Broadcast the scalar into each of the N inner packed values.
		Self(array::from_fn(|_| PT::broadcast(scalar)))
	}

	#[inline]
	fn square(self) -> Self {
		Self(self.0.map(|v| v.square()))
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		Self(self.0.map(|v| v.invert_or_zero()))
	}

	// Interleaves blocks of `1 << log_block_len` scalars between `self` and `other`.
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks span whole inner values: interleave by copying `block_in_pts`-sized
			// runs of `PT` values, alternating between `self` and `other`.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N).step_by(block_in_pts * 2) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
				first[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i..i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
				second[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
			}
		} else {
			// Blocks are smaller than one inner value: delegate to the inner interleave
			// on corresponding element pairs.
			for i in 0..N {
				(first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	// Separates even-indexed and odd-indexed blocks of `1 << log_block_len` scalars.
	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks span whole inner values: even blocks of `self` fill the first half
			// of `first`, odd blocks the first half of `second`.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}

			// Blocks of `other` fill the second halves.
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

				second[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}
		} else {
			// Blocks are smaller than one inner value: delegate to the inner unzip on
			// adjacent element pairs, `self` filling the first halves...
			for i in 0..N / 2 {
				(first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
			}

			// ...and `other` filling the second halves.
			for i in 0..N / 2 {
				(first[i + N / 2], second[i + N / 2]) =
					other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		// Delegates to the inherent method of the same name (inherent impls take
		// precedence over trait methods in path resolution, so this does not recurse).
		// SAFETY: forwards the caller's contract unchanged.
		unsafe { Self::spread_unchecked(self, log_block_len, block_idx) }
	}

	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		// Build each inner value from the corresponding contiguous range of flat indices.
		Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		// Safety: `Self` has the same layout as `[PT; N]` because it is a transparent wrapper.
		let cast_slice =
			unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

		// Flatten to a `&[PT]` and reuse the inner iteration, avoiding per-scalar indexing.
		PT::iter_slice(cast_slice.as_flattened())
	}
}
379
380impl<PT: PackedField, const N: usize> Distribution<ScaledPackedField<PT, N>> for StandardUniform
381where
382	[PT; N]: Default,
383{
384	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledPackedField<PT, N> {
385		ScaledPackedField(array::from_fn(|_| PT::random(&mut rng)))
386	}
387}
388
389impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
390where
391	[PT; N]: Default,
392{
393	#[inline]
394	fn mul_alpha(self) -> Self {
395		Self(self.0.map(|v| v.mul_alpha()))
396	}
397}
398
/// Per-element transformation as a scaled packed field.
pub struct ScaledTransformation<I> {
	// Transformation applied independently to each inner packed element.
	inner: I,
}
403
impl<I> ScaledTransformation<I> {
	/// Wraps an inner per-element transformation.
	const fn new(inner: I) -> Self {
		Self { inner }
	}
}
409
410impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
411	for ScaledTransformation<I>
412where
413	I: Transformation<IP, OP>,
414{
415	fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
416		ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
417	}
418}
419
// Builds packed transformations for scaled fields by constructing the inner factory's
// transformation once and applying it per element via `ScaledTransformation`.
impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
	for ScaledPackedField<IP, N>
where
	Self: PackedBinaryField,
	ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
	OP: PackedBinaryField,
	IP: PackedTransformationFactory<OP>,
{
	type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
		ScaledTransformation<IP::PackedTransformation<Data>>;

	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
		transformation: FieldLinearTransformation<
			<ScaledPackedField<OP, N> as PackedField>::Scalar,
			Data,
		>,
	) -> Self::PackedTransformation<Data> {
		// The scalar linear transformation is shared: the inner factory builds one packed
		// transformation, which `ScaledTransformation` then applies to every element.
		ScaledTransformation::new(IP::make_packed_transformation(transformation))
	}
}
440
/// The only thing that prevents us from having pure generic implementation of `ScaledPackedField`
/// is that we can't have generic operations both with `Self` and `PT::Scalar`
/// (it leads to `conflicting implementations of trait` error).
/// That's why we implement one of those in a macro.
///
/// `packed_scaled_field!(Name = [Inner; Size])` defines the type alias `Name` plus the six
/// scalar-right-hand-side arithmetic impls. Each impl broadcasts the scalar to a full
/// `Inner` value once, then applies the packed operation to every inner element.
macro_rules! packed_scaled_field {
	($name:ident = [$inner:ty;$size:literal]) => {
		pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

		// Scalar addition: broadcast once, then add to every inner packed value.
		impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}

				self
			}
		}

		// In-place scalar addition.
		impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}
			}
		}

		// Scalar subtraction: broadcast once, then subtract from every inner packed value.
		impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}

				self
			}
		}

		// In-place scalar subtraction.
		impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}
			}
		}

		// Scalar multiplication: broadcast once, then multiply every inner packed value.
		impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}

				self
			}
		}

		// In-place scalar multiplication.
		impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}
			}
		}
	};
}

// Re-export so sibling modules in the crate can invoke the macro.
pub(crate) use packed_scaled_field;
524
// SAFETY: `ScaledPackedField<PT, N>` is `repr(transparent)` over `[PT; N]`, and the
// `TransparentWrapper<ScaledUnderlier<PT::Underlier, N>>` impl at the bottom of this file
// asserts layout compatibility with the underlier type, so every conversion below is a
// zero-cost reinterpretation. NOTE(review): this relies on `ScaledUnderlier<U, N>` having
// the layout of `[U; N]` — confirm at its definition site.
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier: Pod>,
{
	type Underlier = ScaledUnderlier<PT::Underlier, N>;

	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}
571
// A scaled underlier packs scalars of `F` via the scaled packed field built over `U`'s
// own packing of `F`.
impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
	U: PackScalar<F> + UnderlierType + Pod,
	F: Field,
	ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
	type Packed = ScaledPackedField<U::Packed, N>;
}
580
// SAFETY: `ScaledPackedField<PT, N>` is `repr(transparent)` over `[PT; N]`, and each `PT`
// has underlier `U` via `WithUnderlier`. NOTE(review): soundness additionally requires
// `ScaledUnderlier<U, N>` to have the same layout as `[U; N]` and `PT` to be layout-
// compatible with `U` — neither is visible in this file; confirm at their definitions.
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
	for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier = U>,
{
}