binius_field/arch/portable/
packed_scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	iter::{Product, Sum},
6	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	checked_arithmetics::checked_log_2,
13};
14use bytemuck::{Pod, TransparentWrapper, Zeroable};
15use rand::{
16	Rng,
17	distr::{Distribution, StandardUniform},
18};
19
20use crate::{
21	Field, PackedField,
22	arithmetic_traits::MulAlpha,
23	as_packed_field::PackScalar,
24	linear_transformation::Transformation,
25	underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
26};
27
/// Packed field that just stores smaller packed field N times and performs all operations
/// one by one.
/// This makes sense for creating portable implementations for 256 and 512 packed sizes.
///
/// `#[repr(transparent)]` guarantees this wrapper has exactly the layout of `[PT; N]`;
/// the `iter_slice` pointer cast and the `TransparentWrapper` impl below rely on that.
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
34
impl<PT, const N: usize> ScaledPackedField<PT, N> {
	/// Number of inner `PT` values stored in this scaled field.
	pub const WIDTH_IN_PT: usize = N;

	/// In general case PT != Self::Scalar, so this function has a different name from
	/// `PackedField::from_fn`
	pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
		Self(std::array::from_fn(f))
	}

	/// We put implementation here to be able to use in the generic code.
	/// (`PackedField` is only implemented for certain types via macro).
	///
	/// Spreads the scalars of the block with index `block_idx` (of length
	/// `2^log_block_len`) over the whole result, delegating each inner element to
	/// `PT::spread_unchecked`.
	///
	/// # Safety
	/// The caller must ensure that `log_block_len <= PT::LOG_WIDTH + log2(N)` and
	/// `block_idx` addresses a valid block of that size; otherwise the unchecked
	/// indexing below reads out of bounds.
	#[inline]
	pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
	where
		PT: PackedField,
	{
		// `checked_log_2` enforces that N is a power of two.
		let log_n = checked_log_2(N);
		let values = if log_block_len >= PT::LOG_WIDTH {
			// The requested block spans one or more whole inner `PT` values.
			let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
			let log_packed_block = log_block_len - PT::LOG_WIDTH;
			let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
			let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
			array::from_fn(|i| unsafe {
				// SAFETY: caller guarantees `block_idx` is in range, so the computed
				// inner index stays below N and the sub-block index below PT::WIDTH.
				self.0
					.get_unchecked(offset + (i >> (log_n - log_packed_block)))
					.spread_unchecked(
						log_smaller_block,
						(i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
					)
			})
		} else {
			// The requested block lies entirely inside a single inner `PT` value.
			let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
			let log_inner_block_len = log_block_len.saturating_sub(log_n);
			let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
			let block_offset = block_offset << (log_block_len - log_inner_block_len);

			array::from_fn(|i| unsafe {
				// SAFETY: `value_index < N` follows from the caller's contract on
				// `block_idx`; the inner block index is derived from the same bound.
				self.0.get_unchecked(value_index).spread_unchecked(
					log_inner_block_len,
					block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
				)
			})
		};

		Self(values)
	}
}
82
83impl<PT, const N: usize> SerializeBytes for ScaledPackedField<PT, N>
84where
85	PT: SerializeBytes,
86{
87	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
88		for elem in &self.0 {
89			elem.serialize(&mut write_buf)?;
90		}
91		Ok(())
92	}
93}
94
95impl<PT, const N: usize> DeserializeBytes for ScaledPackedField<PT, N>
96where
97	PT: DeserializeBytes,
98{
99	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError> {
100		let mut result = Vec::with_capacity(N);
101		for _ in 0..N {
102			result.push(PT::deserialize(&mut read_buf)?);
103		}
104
105		match result.try_into() {
106			Ok(arr) => Ok(Self(arr)),
107			Err(_) => Err(SerializationError::InvalidConstruction {
108				name: "ScaledPackedField",
109			}),
110		}
111	}
112}
113
114impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
115where
116	[PT; N]: Default,
117{
118	fn default() -> Self {
119		Self(Default::default())
120	}
121}
122
123impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
124where
125	PT: From<U>,
126{
127	fn from(value: [U; N]) -> Self {
128		Self(value.map(Into::into))
129	}
130}
131
132impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
133where
134	U: From<PT>,
135{
136	fn from(value: ScaledPackedField<PT, N>) -> Self {
137		value.0.map(Into::into)
138	}
139}
140
// SAFETY: `ScaledPackedField` is `#[repr(transparent)]` over `[PT; N]`, and an array of
// `Zeroable` elements is itself zeroable.
unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}

// SAFETY: `ScaledPackedField` is `#[repr(transparent)]` over `[PT; N]`, and an array of
// `Pod` elements has no padding and accepts any bit pattern.
unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
144
145impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
146where
147	Self: Default,
148{
149	type Output = Self;
150
151	fn add(self, rhs: Self) -> Self {
152		Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
153	}
154}
155
156impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
157where
158	Self: Default,
159{
160	fn add_assign(&mut self, rhs: Self) {
161		for i in 0..N {
162			self.0[i] += rhs.0[i];
163		}
164	}
165}
166
167impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
168where
169	Self: Default,
170{
171	type Output = Self;
172
173	fn sub(self, rhs: Self) -> Self {
174		Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
175	}
176}
177
178impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
179where
180	Self: Default,
181{
182	fn sub_assign(&mut self, rhs: Self) {
183		for i in 0..N {
184			self.0[i] -= rhs.0[i];
185		}
186	}
187}
188
189impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
190where
191	Self: Default,
192{
193	type Output = Self;
194
195	fn mul(self, rhs: Self) -> Self {
196		Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
197	}
198}
199
200impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
201where
202	Self: Default,
203{
204	fn mul_assign(&mut self, rhs: Self) {
205		for i in 0..N {
206			self.0[i] *= rhs.0[i];
207		}
208	}
209}
210
/// Currently we use this trait only in this file for compactness.
/// If it is useful in some other place it worth moving it to `arithmetic_traits.rs`
///
/// Bundles the six arithmetic operator bounds against `Rhs` (`+`, `+=`, `-`, `-=`,
/// `*`, `*=`) so the `PackedField` impl below can require them in one `where` clause.
trait ArithmeticOps<Rhs>:
	Add<Rhs, Output = Self>
	+ AddAssign<Rhs>
	+ Sub<Rhs, Output = Self>
	+ SubAssign<Rhs>
	+ Mul<Rhs, Output = Self>
	+ MulAssign<Rhs>
{
}

// Blanket impl: any type providing all six operator impls gets `ArithmeticOps` for free.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
	T: Add<Rhs, Output = Self>
		+ AddAssign<Rhs>
		+ Sub<Rhs, Output = Self>
		+ SubAssign<Rhs>
		+ Mul<Rhs, Output = Self>
		+ MulAssign<Rhs>
{
}
232
233impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
234where
235	Self: Default,
236{
237	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
238		iter.fold(Self::default(), |l, r| l + r)
239	}
240}
241
242impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
243where
244	[PT; N]: Default,
245{
246	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
247		let one = Self([PT::one(); N]);
248		iter.fold(one, |l, r| l * r)
249	}
250}
251
impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
	Self: ArithmeticOps<PT::Scalar>,
{
	type Scalar = PT::Scalar;

	// Total width is the inner width times N; `checked_log_2` enforces at compile
	// time that N is a power of two.
	const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

	/// Returns the scalar at flat index `i` without bounds checking.
	///
	/// # Safety
	/// The caller must ensure `i < Self::WIDTH`.
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		// Split the flat index into inner-value index and scalar index within it.
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		// SAFETY: `i < Self::WIDTH` implies `outer_i < N` and `inner_i < PT::WIDTH`.
		unsafe { self.0.get_unchecked(outer_i).get_unchecked(inner_i) }
	}

	/// Sets the scalar at flat index `i` without bounds checking.
	///
	/// # Safety
	/// The caller must ensure `i < Self::WIDTH`.
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		// SAFETY: `i < Self::WIDTH` implies `outer_i < N` and `inner_i < PT::WIDTH`.
		unsafe {
			self.0
				.get_unchecked_mut(outer_i)
				.set_unchecked(inner_i, scalar);
		}
	}

	#[inline]
	fn zero() -> Self {
		Self(array::from_fn(|_| PT::zero()))
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		// Broadcast the scalar into every inner packed value.
		Self(array::from_fn(|_| PT::broadcast(scalar)))
	}

	#[inline]
	fn square(self) -> Self {
		// Square each inner packed value independently.
		Self(self.0.map(|v| v.square()))
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		// Invert each inner packed value independently (zero maps to zero).
		Self(self.0.map(|v| v.invert_or_zero()))
	}

	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks span whole inner values: interleave by copying runs of `PT`s.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N).step_by(block_in_pts * 2) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
				first[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i..i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
				second[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
			}
		} else {
			// Blocks are smaller than one inner value: delegate pairwise to `PT`.
			for i in 0..N {
				(first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks span whole inner values: even-numbered blocks go to `first`,
			// odd-numbered blocks to `second`; `self` fills the lower halves,
			// `other` the upper halves.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}

			for i in (0..N / 2).step_by(block_in_pts) {
				first[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

				second[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}
		} else {
			// Blocks are smaller than one inner value: delegate to the inner `unzip`,
			// consuming inner values in adjacent pairs.
			for i in 0..N / 2 {
				(first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
			}

			for i in 0..N / 2 {
				(first[i + N / 2], second[i + N / 2]) =
					other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		// SAFETY: forwards to the inherent `spread_unchecked`, which has the same contract.
		unsafe { Self::spread_unchecked(self, log_block_len, block_idx) }
	}

	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		// Each inner value takes the next contiguous range of PT::WIDTH scalar indices.
		Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		// Safety: `Self` has the same layout as `[PT; N]` because it is a transparent wrapper.
		let cast_slice =
			unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

		PT::iter_slice(cast_slice.as_flattened())
	}
}
376
377impl<PT: PackedField, const N: usize> Distribution<ScaledPackedField<PT, N>> for StandardUniform
378where
379	[PT; N]: Default,
380{
381	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledPackedField<PT, N> {
382		ScaledPackedField(array::from_fn(|_| PT::random(&mut rng)))
383	}
384}
385
386impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
387where
388	[PT; N]: Default,
389{
390	#[inline]
391	fn mul_alpha(self) -> Self {
392		Self(self.0.map(|v| v.mul_alpha()))
393	}
394}
395
/// Per-element transformation as a scaled packed field.
///
/// Wraps a transformation over the inner packed type and applies it to each of the
/// `N` inner values independently.
pub struct ScaledTransformation<I> {
	// The transformation applied to each inner packed element.
	inner: I,
}
400
impl<I> ScaledTransformation<I> {
	/// Wraps `inner` as a per-element transformation over the scaled packed field.
	const fn new(inner: I) -> Self {
		Self { inner }
	}
}
406
407impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
408	for ScaledTransformation<I>
409where
410	I: Transformation<IP, OP>,
411{
412	fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
413		ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
414	}
415}
416
/// The only thing that prevents us from having pure generic implementation of `ScaledPackedField`
/// is that we can't have generic operations both with `Self` and `PT::Scalar`
/// (it leads to `conflicting implementations of trait` error).
/// That's why we implement one of those in a macro.
///
/// Expands to a type alias `$name` for `ScaledPackedField<$inner, $size>` plus the six
/// scalar-operand operator impls (`Add`/`AddAssign`/`Sub`/`SubAssign`/`Mul`/`MulAssign`
/// against `$inner::Scalar`). Each operator broadcasts the scalar to the inner packed
/// type once and then applies the packed operation to every inner element.
macro_rules! packed_scaled_field {
	($name:ident = [$inner:ty;$size:literal]) => {
		pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

		impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				// Broadcast once, then add element-wise.
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}

				self
			}
		}

		impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}
			}
		}

		impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				// Broadcast once, then subtract element-wise.
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}

				self
			}
		}

		impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}
			}
		}

		impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				// Broadcast once, then multiply element-wise.
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}

				self
			}
		}

		impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}
			}
		}
	};
}
498
499pub(crate) use packed_scaled_field;
500
// SAFETY: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`.
// NOTE(review): this also assumes `ScaledUnderlier<PT::Underlier, N>` has the same
// layout as `[PT::Underlier; N]` — confirm against the `underlier` module.
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier: Pod>,
{
	type Underlier = ScaledUnderlier<PT::Underlier, N>;
}
507
// Lets `ScaledUnderlier<U, N>` pack scalars of field `F` whenever the inner underlier
// `U` can, using the scaled packed field as the packed representation.
impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
	U: PackScalar<F> + UnderlierType + Pod,
	F: Field,
	ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
	type Packed = ScaledPackedField<U::Packed, N>;
}
516
// SAFETY: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`, and `PT`
// is a transparent wrapper over `U` (via `WithUnderlier`).
// NOTE(review): soundness also requires `ScaledUnderlier<U, N>` to be layout-identical
// to `[U; N]` — confirm against the `underlier` module.
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
	for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier = U>,
{
}