binius_field/underlier/scaled.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	array,
	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{NoUninit, Pod, Zeroable, must_cast_mut, must_cast_ref};
use rand::{
	Rng,
	distr::{Distribution, StandardUniform},
};

use super::{Divisible, NumCast, UnderlierType, UnderlierWithBitOps};
use crate::Random;

/// A type that represents `N` elements of the same underlier type `U`.
/// We use it as an underlier for the `ScaledPackedField` type.
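///
/// A minimal usage sketch (illustrative; it relies only on the `From` conversions
/// defined below):
///
/// ```ignore
/// // Two 64-bit halves acting as one 128-bit underlier; index 0 is least significant.
/// let pair = ScaledUnderlier::<u64, 2>::from([1u64, 2u64]);
/// let halves: [u64; 2] = pair.into();
/// assert_eq!(halves, [1, 2]);
/// ```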
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);

impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
	fn default() -> Self {
		Self(array::from_fn(|_| U::default()))
	}
}

impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
		ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
	}
}

impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
	fn from(val: ScaledUnderlier<U, N>) -> Self {
		val.0
	}
}

impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
	fn from(value: [T; N]) -> Self {
		Self(value.map(U::from))
	}
}

/// Groups four elements pairwise: `[a, b, c, d]` becomes `[[a, b], [c, d]]`.
impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
	fn from(value: [T; 4]) -> Self {
		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
	}
}

unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}

unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}

impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
	// The total width is `U::BITS * N` bits, e.g. two `u64`s give `LOG_BITS = 6 + 1 = 7`,
	// i.e. 128 bits. `N` must be a power of two for `checked_log_2` to succeed.
	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
}

unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
where
	Self: UnderlierType,
	U: UnderlierType,
{
	type Array = [U; N];

	#[inline]
	fn split_val(self) -> Self::Array {
		self.0
	}

	#[inline]
	fn split_ref(&self) -> &[U] {
		&self.0
	}

	#[inline]
	fn split_mut(&mut self) -> &mut [U] {
		&mut self.0
	}
}

unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
where
	Self: UnderlierType + NoUninit,
	U: UnderlierType + Pod,
{
	type Array = [U; 4];

	#[inline]
	fn split_val(self) -> Self::Array {
		// `repr(transparent)` guarantees that the nested `[[U; 2]; 2]` representation
		// has the same layout as a flat `[U; 4]`, so these casts are valid and are
		// checked for size at compile time.
		bytemuck::must_cast(self)
	}

	#[inline]
	fn split_ref(&self) -> &[U] {
		must_cast_ref::<Self, [U; 4]>(self)
	}

	#[inline]
	fn split_mut(&mut self) -> &mut [U] {
		must_cast_mut::<Self, [U; 4]>(self)
	}
}

impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitand(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
	}
}

impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
	fn bitand_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] &= rhs.0[i];
		}
	}
}

impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
	}
}

impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
	fn bitor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] |= rhs.0[i];
		}
	}
}

impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitxor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
	}
}

impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
	fn bitxor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] ^= rhs.0[i];
		}
	}
}

impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shr(self, rhs: usize) -> Self::Output {
		// The array is treated as a little-endian multi-word integer: element 0 holds
		// the least significant bits. Shift whole elements down by `rhs / U::BITS`
		// positions, then shift by the remaining `rhs % U::BITS` bits, carrying bits
		// in from the next more-significant element.
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
			if i + shift_in_items < N {
				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
			}
			if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shl(self, rhs: usize) -> Self::Output {
		// Mirror image of `shr`: shift whole elements up by `rhs / U::BITS` positions,
		// then shift by the remaining `rhs % U::BITS` bits, carrying bits in from the
		// next less-significant element.
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in shift_in_items.saturating_sub(1)..N {
			if i >= shift_in_items {
				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
			}
			if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
	type Output = Self;

	fn not(self) -> Self::Output {
		Self(self.0.map(U::not))
	}
}

impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
where
	U: UnderlierWithBitOps + Pod + From<u8>,
	u8: NumCast<U>,
{
	const ZERO: Self = Self([U::ZERO; N]);
	const ONE: Self = {
		// Element 0 is the least significant, so the value 1 sets only the first element.
		let mut arr = [U::ZERO; N];
		arr[0] = U::ONE;
		Self(arr)
	};
	const ONES: Self = Self([U::ONES; N]);

	#[inline]
	fn fill_with_bit(val: u8) -> Self {
		Self(array::from_fn(|_| U::fill_with_bit(val)))
	}

	#[inline]
	fn shl_128b_lanes(self, rhs: usize) -> Self {
		// This implementation is valid only when the underlier type has at least 128
		// bits. In practice, we don't use scaled underliers over underlier types
		// narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
	}

	#[inline]
	fn shr_128b_lanes(self, rhs: usize) -> Self {
		// This implementation is valid only when the underlier type has at least 128
		// bits. In practice, we don't use scaled underliers over underlier types
		// narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
	}

	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// This implementation is valid only when the underlier type has at least 128
		// bits. In practice, we don't use scaled underliers over underlier types
		// narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
	}

	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// This implementation is valid only when the underlier type has at least 128
		// bits. In practice, we don't use scaled underliers over underlier types
		// narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
	}
}

impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
where
	Self: NumCast<U>,
{
	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
		// Truncating cast: take the least significant element.
		Self::num_cast_from(val.0[0])
	}
}

impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
where
	U: From<u8>,
{
	fn from(val: u8) -> Self {
		// Broadcast the converted value to every element.
		Self(array::from_fn(|_| U::from(val)))
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}
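
	// Illustrative sketch: the bitwise operators apply element-wise, which follows
	// directly from the `BitAnd`/`BitOr`/`BitXor`/`Not` impls above.
	#[test]
	fn test_bit_ops_elementwise() {
		let a = ScaledUnderlier::<u8, 2>([0b1100, 0b1010]);
		let b = ScaledUnderlier::<u8, 2>([0b1010, 0b0110]);
		assert_eq!(a & b, ScaledUnderlier::<u8, 2>([0b1000, 0b0010]));
		assert_eq!(a | b, ScaledUnderlier::<u8, 2>([0b1110, 0b1110]));
		assert_eq!(a ^ b, ScaledUnderlier::<u8, 2>([0b0110, 0b1100]));
		assert_eq!(!a, ScaledUnderlier::<u8, 2>([0b11110011, 0b11110101]));
	}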

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}
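
	// Illustrative sketch: `Divisible` flattens a nested pair-of-pairs into a flat
	// array, relying on the `repr(transparent)` layout. Assumes `u8` implements the
	// underlier traits, as elsewhere in this crate.
	#[test]
	fn test_split_nested() {
		let val = ScaledUnderlier::<ScaledUnderlier<u8, 2>, 2>::from([0u8, 1, 2, 3]);
		// Disambiguate: `Divisible` is implemented both for the inner pairs and for `u8`.
		let flat: [u8; 4] = Divisible::<u8>::split_val(val);
		assert_eq!(flat, [0, 1, 2, 3]);
		assert_eq!(Divisible::<u8>::split_ref(&val), &[0u8, 1, 2, 3]);
	}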
}