// binius_field/arch/portable/packed_scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	iter::{Product, Sum},
6	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes,
11	bytes::{Buf, BufMut},
12	checked_arithmetics::checked_log_2,
13};
14use bytemuck::{Pod, TransparentWrapper, Zeroable};
15use rand::RngCore;
16use subtle::ConstantTimeEq;
17
18use crate::{
19	Field, PackedField,
20	arithmetic_traits::MulAlpha,
21	as_packed_field::PackScalar,
22	linear_transformation::{
23		FieldLinearTransformation, PackedTransformationFactory, Transformation,
24	},
25	packed::PackedBinaryField,
26	underlier::{ScaledUnderlier, UnderlierType, WithUnderlier},
27};
28
/// Packed field that just stores smaller packed field N times and performs all operations
/// one by one.
/// This makes sense for creating portable implementations for 256 and 512 packed sizes.
///
/// `PT` is the inner packed field type and `N` is the number of stored copies.
/// `#[repr(transparent)]` guarantees the same layout as `[PT; N]`, which the
/// `TransparentWrapper`/`Pod` impls below rely on.
#[derive(PartialEq, Eq, Clone, Copy, Debug, bytemuck::TransparentWrapper)]
#[repr(transparent)]
pub struct ScaledPackedField<PT, const N: usize>(pub(super) [PT; N]);
35
impl<PT, const N: usize> ScaledPackedField<PT, N> {
	// Number of inner `PT` values stored in one `Self`.
	pub const WIDTH_IN_PT: usize = N;

	/// In general case PT != Self::Scalar, so this function has a different name from
	/// `PackedField::from_fn`
	pub fn from_direct_packed_fn(f: impl FnMut(usize) -> PT) -> Self {
		Self(std::array::from_fn(f))
	}

	/// Spreads the scalars of block `block_idx` (of length `2^log_block_len`) across the
	/// full width of `Self`, delegating per-element work to `PT::spread_unchecked`.
	///
	/// We put implementation here to be able to use in the generic code.
	/// (`PackedField` is only implemented for certain types via macro).
	///
	/// # Safety
	/// The caller must uphold the `PackedField::spread_unchecked` contract:
	/// `log_block_len <= Self::LOG_WIDTH` and
	/// `block_idx < 2^(Self::LOG_WIDTH - log_block_len)`; the array indices derived from
	/// these values below are used without bounds checks.
	#[inline]
	pub(crate) unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self
	where
		PT: PackedField,
	{
		let log_n = checked_log_2(N);
		let values = if log_block_len >= PT::LOG_WIDTH {
			// The requested block spans one or more whole inner `PT` values starting at
			// `offset`; each output element takes a sub-block from the appropriate source.
			let offset = block_idx << (log_block_len - PT::LOG_WIDTH);
			let log_packed_block = log_block_len - PT::LOG_WIDTH;
			let log_smaller_block = PT::LOG_WIDTH.saturating_sub(log_n - log_packed_block);
			let smaller_block_index_mask = (1 << (PT::LOG_WIDTH - log_smaller_block)) - 1;
			array::from_fn(|i| unsafe {
				self.0
					.get_unchecked(offset + (i >> (log_n - log_packed_block)))
					.spread_unchecked(
						log_smaller_block,
						(i >> log_n.saturating_sub(log_block_len)) & smaller_block_index_mask,
					)
			})
		} else {
			// The requested block lies inside a single inner `PT` value (`value_index`);
			// spread its sub-blocks across all `N` output elements.
			let value_index = block_idx >> (PT::LOG_WIDTH - log_block_len);
			let log_inner_block_len = log_block_len.saturating_sub(log_n);
			let block_offset = block_idx & ((1 << (PT::LOG_WIDTH - log_block_len)) - 1);
			let block_offset = block_offset << (log_block_len - log_inner_block_len);

			array::from_fn(|i| unsafe {
				self.0.get_unchecked(value_index).spread_unchecked(
					log_inner_block_len,
					block_offset + (i >> (log_n + log_inner_block_len - log_block_len)),
				)
			})
		};

		Self(values)
	}
}
83
84impl<PT, const N: usize> SerializeBytes for ScaledPackedField<PT, N>
85where
86	PT: SerializeBytes,
87{
88	fn serialize(
89		&self,
90		mut write_buf: impl BufMut,
91		mode: SerializationMode,
92	) -> Result<(), SerializationError> {
93		for elem in &self.0 {
94			elem.serialize(&mut write_buf, mode)?;
95		}
96		Ok(())
97	}
98}
99
100impl<PT, const N: usize> DeserializeBytes for ScaledPackedField<PT, N>
101where
102	PT: DeserializeBytes,
103{
104	fn deserialize(
105		mut read_buf: impl Buf,
106		mode: SerializationMode,
107	) -> Result<Self, SerializationError> {
108		let mut result = Vec::with_capacity(N);
109		for _ in 0..N {
110			result.push(PT::deserialize(&mut read_buf, mode)?);
111		}
112
113		match result.try_into() {
114			Ok(arr) => Ok(Self(arr)),
115			Err(_) => Err(SerializationError::InvalidConstruction {
116				name: "ScaledPackedField",
117			}),
118		}
119	}
120}
121
122impl<PT, const N: usize> Default for ScaledPackedField<PT, N>
123where
124	[PT; N]: Default,
125{
126	fn default() -> Self {
127		Self(Default::default())
128	}
129}
130
131impl<U, PT, const N: usize> From<[U; N]> for ScaledPackedField<PT, N>
132where
133	PT: From<U>,
134{
135	fn from(value: [U; N]) -> Self {
136		Self(value.map(Into::into))
137	}
138}
139
140impl<U, PT, const N: usize> From<ScaledPackedField<PT, N>> for [U; N]
141where
142	U: From<PT>,
143{
144	fn from(value: ScaledPackedField<PT, N>) -> Self {
145		value.0.map(Into::into)
146	}
147}
148
149unsafe impl<PT: Zeroable, const N: usize> Zeroable for ScaledPackedField<PT, N> {}
150
151unsafe impl<PT: Pod, const N: usize> Pod for ScaledPackedField<PT, N> {}
152
153impl<PT: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledPackedField<PT, N> {
154	fn ct_eq(&self, other: &Self) -> subtle::Choice {
155		self.0.ct_eq(&other.0)
156	}
157}
158
159impl<PT: Copy + Add<Output = PT>, const N: usize> Add for ScaledPackedField<PT, N>
160where
161	Self: Default,
162{
163	type Output = Self;
164
165	fn add(self, rhs: Self) -> Self {
166		Self::from_direct_packed_fn(|i| self.0[i] + rhs.0[i])
167	}
168}
169
170impl<PT: Copy + AddAssign, const N: usize> AddAssign for ScaledPackedField<PT, N>
171where
172	Self: Default,
173{
174	fn add_assign(&mut self, rhs: Self) {
175		for i in 0..N {
176			self.0[i] += rhs.0[i];
177		}
178	}
179}
180
181impl<PT: Copy + Sub<Output = PT>, const N: usize> Sub for ScaledPackedField<PT, N>
182where
183	Self: Default,
184{
185	type Output = Self;
186
187	fn sub(self, rhs: Self) -> Self {
188		Self::from_direct_packed_fn(|i| self.0[i] - rhs.0[i])
189	}
190}
191
192impl<PT: Copy + SubAssign, const N: usize> SubAssign for ScaledPackedField<PT, N>
193where
194	Self: Default,
195{
196	fn sub_assign(&mut self, rhs: Self) {
197		for i in 0..N {
198			self.0[i] -= rhs.0[i];
199		}
200	}
201}
202
203impl<PT: Copy + Mul<Output = PT>, const N: usize> Mul for ScaledPackedField<PT, N>
204where
205	Self: Default,
206{
207	type Output = Self;
208
209	fn mul(self, rhs: Self) -> Self {
210		Self::from_direct_packed_fn(|i| self.0[i] * rhs.0[i])
211	}
212}
213
214impl<PT: Copy + MulAssign, const N: usize> MulAssign for ScaledPackedField<PT, N>
215where
216	Self: Default,
217{
218	fn mul_assign(&mut self, rhs: Self) {
219		for i in 0..N {
220			self.0[i] *= rhs.0[i];
221		}
222	}
223}
224
/// Currently we use this trait only in this file for compactness.
/// If it is useful in some other place it worth moving it to `arithmetic_traits.rs`
///
/// A trait-alias-style bound bundling all binary arithmetic ops against `Rhs`.
trait ArithmeticOps<Rhs>:
	Add<Rhs, Output = Self>
	+ AddAssign<Rhs>
	+ Sub<Rhs, Output = Self>
	+ SubAssign<Rhs>
	+ Mul<Rhs, Output = Self>
	+ MulAssign<Rhs>
{
}
236
// Blanket impl: any type with all the required ops automatically satisfies `ArithmeticOps`.
impl<T, Rhs> ArithmeticOps<Rhs> for T where
	T: Add<Rhs, Output = Self>
		+ AddAssign<Rhs>
		+ Sub<Rhs, Output = Self>
		+ SubAssign<Rhs>
		+ Mul<Rhs, Output = Self>
		+ MulAssign<Rhs>
{
}
246
247impl<PT: Add<Output = PT> + Copy, const N: usize> Sum for ScaledPackedField<PT, N>
248where
249	Self: Default,
250{
251	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
252		iter.fold(Self::default(), |l, r| l + r)
253	}
254}
255
256impl<PT: PackedField, const N: usize> Product for ScaledPackedField<PT, N>
257where
258	[PT; N]: Default,
259{
260	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
261		let one = Self([PT::one(); N]);
262		iter.fold(one, |l, r| l * r)
263	}
264}
265
impl<PT: PackedField, const N: usize> PackedField for ScaledPackedField<PT, N>
where
	[PT; N]: Default,
	Self: ArithmeticOps<PT::Scalar>,
{
	type Scalar = PT::Scalar;

	// Total width is the inner width times N; both are powers of two.
	const LOG_WIDTH: usize = PT::LOG_WIDTH + checked_log_2(N);

	/// # Safety
	/// `i` must be less than `Self::WIDTH`.
	#[inline]
	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar {
		// Split the flat scalar index into (inner value, position within that value).
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		unsafe { self.0.get_unchecked(outer_i).get_unchecked(inner_i) }
	}

	/// # Safety
	/// `i` must be less than `Self::WIDTH`.
	#[inline]
	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) {
		// Split the flat scalar index into (inner value, position within that value).
		let outer_i = i / PT::WIDTH;
		let inner_i = i % PT::WIDTH;
		unsafe {
			self.0
				.get_unchecked_mut(outer_i)
				.set_unchecked(inner_i, scalar);
		}
	}

	#[inline]
	fn zero() -> Self {
		Self(array::from_fn(|_| PT::zero()))
	}

	fn random(mut rng: impl RngCore) -> Self {
		Self(array::from_fn(|_| PT::random(&mut rng)))
	}

	#[inline]
	fn broadcast(scalar: Self::Scalar) -> Self {
		Self(array::from_fn(|_| PT::broadcast(scalar)))
	}

	#[inline]
	fn square(self) -> Self {
		Self(self.0.map(|v| v.square()))
	}

	#[inline]
	fn invert_or_zero(self) -> Self {
		Self(self.0.map(|v| v.invert_or_zero()))
	}

	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks are at least one whole inner `PT` value, so interleaving reduces
			// to swapping whole runs of inner values between `self` and `other`.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N).step_by(block_in_pts * 2) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[i..i + block_in_pts]);
				first[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i..i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[i + block_in_pts..i + 2 * block_in_pts]);
				second[i + block_in_pts..i + 2 * block_in_pts]
					.copy_from_slice(&other.0[i + block_in_pts..i + 2 * block_in_pts]);
			}
		} else {
			// Blocks are smaller than one inner value: delegate per-element.
			for i in 0..N {
				(first[i], second[i]) = self.0[i].interleave(other.0[i], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) {
		let mut first = [Default::default(); N];
		let mut second = [Default::default(); N];

		if log_block_len >= PT::LOG_WIDTH {
			// Blocks are whole inner values: even-indexed blocks of `self` then `other`
			// go to `first`, odd-indexed blocks to `second`.
			let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH);
			for i in (0..N / 2).step_by(block_in_pts) {
				first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]);

				second[i..i + block_in_pts]
					.copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}

			for i in (0..N / 2).step_by(block_in_pts) {
				first[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]);

				second[i + N / 2..i + N / 2 + block_in_pts]
					.copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]);
			}
		} else {
			// Blocks are smaller than one inner value: unzip pairs of adjacent inner
			// values, filling the first half of the outputs from `self`...
			for i in 0..N / 2 {
				(first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len);
			}

			// ...and the second half from `other`.
			for i in 0..N / 2 {
				(first[i + N / 2], second[i + N / 2]) =
					other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len);
			}
		}

		(Self(first), Self(second))
	}

	#[inline]
	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
		// Delegate to the inherent method defined above (shared with generic code).
		unsafe { Self::spread_unchecked(self, log_block_len, block_idx) }
	}

	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
		Self(array::from_fn(|i| PT::from_fn(|j| f(i * PT::WIDTH + j))))
	}

	#[inline]
	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
		// Safety: `Self` has the same layout as `[PT; N]` because it is a transparent wrapper.
		let cast_slice =
			unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const [PT; N], slice.len()) };

		PT::iter_slice(cast_slice.as_flattened())
	}
}
394
395impl<PT: PackedField + MulAlpha, const N: usize> MulAlpha for ScaledPackedField<PT, N>
396where
397	[PT; N]: Default,
398{
399	#[inline]
400	fn mul_alpha(self) -> Self {
401		Self(self.0.map(|v| v.mul_alpha()))
402	}
403}
404
/// Per-element transformation as a scaled packed field.
///
/// Wraps an inner transformation and applies it to each of the `N` inner packed values.
pub struct ScaledTransformation<I> {
	inner: I,
}
409
impl<I> ScaledTransformation<I> {
	/// Wraps the given inner transformation.
	const fn new(inner: I) -> Self {
		Self { inner }
	}
}
415
416impl<OP, IP, const N: usize, I> Transformation<ScaledPackedField<IP, N>, ScaledPackedField<OP, N>>
417	for ScaledTransformation<I>
418where
419	I: Transformation<IP, OP>,
420{
421	fn transform(&self, data: &ScaledPackedField<IP, N>) -> ScaledPackedField<OP, N> {
422		ScaledPackedField::from_direct_packed_fn(|i| self.inner.transform(&data.0[i]))
423	}
424}
425
426impl<OP, IP, const N: usize> PackedTransformationFactory<ScaledPackedField<OP, N>>
427	for ScaledPackedField<IP, N>
428where
429	Self: PackedBinaryField,
430	ScaledPackedField<OP, N>: PackedBinaryField<Scalar = OP::Scalar>,
431	OP: PackedBinaryField,
432	IP: PackedTransformationFactory<OP>,
433{
434	type PackedTransformation<Data: AsRef<[OP::Scalar]> + Sync> =
435		ScaledTransformation<IP::PackedTransformation<Data>>;
436
437	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
438		transformation: FieldLinearTransformation<
439			<ScaledPackedField<OP, N> as PackedField>::Scalar,
440			Data,
441		>,
442	) -> Self::PackedTransformation<Data> {
443		ScaledTransformation::new(IP::make_packed_transformation(transformation))
444	}
445}
446
/// The only thing that prevents us from having pure generic implementation of `ScaledPackedField`
/// is that we can't have generic operations both with `Self` and `PT::Scalar`
/// (it leads to `conflicting implementations of trait` error).
/// That's why we implement one of those in a macro.
///
/// Usage: `packed_scaled_field!(Name = [InnerPackedType; SIZE]);` defines a type alias and
/// the scalar-operand arithmetic impls (`Add`/`Sub`/`Mul` and their assign variants),
/// each implemented by broadcasting the scalar and applying the inner packed operation.
macro_rules! packed_scaled_field {
	($name:ident = [$inner:ty;$size:literal]) => {
		pub type $name = $crate::arch::portable::packed_scaled::ScaledPackedField<$inner, $size>;

		impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				// Broadcast the scalar once, then add it to every inner element.
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}

				self
			}
		}

		impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v += broadcast;
				}
			}
		}

		impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}

				self
			}
		}

		impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v -= broadcast;
				}
			}
		}

		impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			type Output = Self;

			#[inline]
			fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}

				self
			}
		}

		impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name {
			#[inline]
			fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) {
				let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs);
				for v in self.0.iter_mut() {
					*v *= broadcast;
				}
			}
		}
	};
}
528
529pub(crate) use packed_scaled_field;
530
// All conversions below delegate to `bytemuck::TransparentWrapper`, which is sound because
// `Self` is a transparent wrapper over `[PT; N]` and `ScaledUnderlier` mirrors that layout
// (see the `TransparentWrapper` impl at the bottom of this file).
unsafe impl<PT, const N: usize> WithUnderlier for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier: Pod>,
{
	type Underlier = ScaledUnderlier<PT::Underlier, N>;

	fn to_underlier(self) -> Self::Underlier {
		TransparentWrapper::peel(self)
	}

	fn to_underlier_ref(&self) -> &Self::Underlier {
		TransparentWrapper::peel_ref(self)
	}

	fn to_underlier_ref_mut(&mut self) -> &mut Self::Underlier {
		TransparentWrapper::peel_mut(self)
	}

	fn to_underliers_ref(val: &[Self]) -> &[Self::Underlier] {
		TransparentWrapper::peel_slice(val)
	}

	fn to_underliers_ref_mut(val: &mut [Self]) -> &mut [Self::Underlier] {
		TransparentWrapper::peel_slice_mut(val)
	}

	fn from_underlier(val: Self::Underlier) -> Self {
		TransparentWrapper::wrap(val)
	}

	fn from_underlier_ref(val: &Self::Underlier) -> &Self {
		TransparentWrapper::wrap_ref(val)
	}

	fn from_underlier_ref_mut(val: &mut Self::Underlier) -> &mut Self {
		TransparentWrapper::wrap_mut(val)
	}

	fn from_underliers_ref(val: &[Self::Underlier]) -> &[Self] {
		TransparentWrapper::wrap_slice(val)
	}

	fn from_underliers_ref_mut(val: &mut [Self::Underlier]) -> &mut [Self] {
		TransparentWrapper::wrap_slice_mut(val)
	}
}
577
// A scaled underlier packs scalars of `F` via the scaled packed field over `U`'s packing.
impl<U, F, const N: usize> PackScalar<F> for ScaledUnderlier<U, N>
where
	U: PackScalar<F> + UnderlierType + Pod,
	F: Field,
	ScaledPackedField<U::Packed, N>: PackedField<Scalar = F> + WithUnderlier<Underlier = Self>,
{
	type Packed = ScaledPackedField<U::Packed, N>;
}
586
// SAFETY: `ScaledPackedField<PT, N>` is `#[repr(transparent)]` over `[PT; N]`, and each `PT`
// is in turn layout-compatible with its underlier `U`, so the whole value is
// layout-compatible with `ScaledUnderlier<U, N>`.
unsafe impl<PT, U, const N: usize> TransparentWrapper<ScaledUnderlier<U, N>>
	for ScaledPackedField<PT, N>
where
	PT: WithUnderlier<Underlier = U>,
{
}
593
#[cfg(test)]
mod tests {
	use binius_utils::{SerializationMode, SerializeBytes, bytes::BytesMut};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::ScaledPackedField;
	use crate::{PackedBinaryField2x4b, PackedBinaryField4x4b};

	/// A 4x4b packed field and a scaled pair of 2x4b packed fields holding the same bits
	/// must serialize to identical byte sequences.
	#[test]
	fn test_equivalent_serialization_between_packed_representations() {
		let mode = SerializationMode::Native;
		let mut rng = StdRng::seed_from_u64(0);

		let lo: u8 = rng.r#gen();
		let hi: u8 = rng.r#gen();

		// Build the same 16 bits both ways: as one underlier and as two scaled bytes.
		let combined = u16::from(lo) | (u16::from(hi) << 8);
		let packed = PackedBinaryField4x4b::from_underlier(combined);
		let scaled = ScaledPackedField::<PackedBinaryField2x4b, 2>::from([lo, hi]);

		let mut packed_bytes = BytesMut::new();
		let mut scaled_bytes = BytesMut::new();

		packed.serialize(&mut packed_bytes, mode).unwrap();
		scaled.serialize(&mut scaled_bytes, mode).unwrap();

		assert_eq!(packed_bytes, scaled_bytes);
	}
}
627}