binius_field/underlier/scaled.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	array,
	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable};
use rand::RngCore;
use subtle::{Choice, ConstantTimeEq};

use super::{Divisible, Random, UnderlierType, UnderlierWithBitOps};

/// A type that represents an array of `N` elements of the same underlier type.
/// We use it as an underlier for the `ScaledPackedField` type.
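/// The limbs are stored in little-endian order: index 0 holds the least
/// significant bits, which the `Shl`/`Shr` implementations below rely on.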
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);

impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
	fn default() -> Self {
		Self(array::from_fn(|_| U::default()))
	}
}

impl<U: Random, const N: usize> Random for ScaledUnderlier<U, N> {
	fn random(mut rng: impl RngCore) -> Self {
		Self(array::from_fn(|_| U::random(&mut rng)))
	}
}

impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
	fn from(val: ScaledUnderlier<U, N>) -> Self {
		val.0
	}
}

impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
	fn from(value: [T; N]) -> Self {
		Self(value.map(U::from))
	}
}

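// Group a flat array of four elements into two adjacent pairs: elements 0 and 1
// become limb 0, and elements 2 and 3 become limb 1.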
impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
	fn from(value: [T; 4]) -> Self {
		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
	}
}

impl<U: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledUnderlier<U, N> {
	fn ct_eq(&self, other: &Self) -> Choice {
		self.0.ct_eq(&other.0)
	}
}

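// SAFETY: `ScaledUnderlier` is `repr(transparent)` over `[U; N]`, so it is
// `Zeroable`/`Pod` whenever `U` is.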
unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}

unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}

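// The total width is N * U::BITS bits, so the log2 width is the sum of the
// component logs. `checked_log_2(N)` only accepts exact powers of two, and in
// this const context the check happens at compile time.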
impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
}

unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
where
	Self: UnderlierType,
	U: UnderlierType,
{
	type Array = [U; N];

	#[inline]
	fn split_val(self) -> Self::Array {
		self.0
	}

	#[inline]
	fn split_ref(&self) -> &[U] {
		&self.0
	}

	#[inline]
	fn split_mut(&mut self) -> &mut [U] {
		&mut self.0
	}
}

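// A 2x2 nested scaled underlier has the same memory layout as `[U; 4]` (the
// wrapper is `repr(transparent)`), so bytemuck's `must_cast` family can flatten
// it directly, with the size checks performed at compile time.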
unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
where
	Self: UnderlierType + NoUninit,
	U: UnderlierType + Pod,
{
	type Array = [U; 4];

	#[inline]
	fn split_val(self) -> Self::Array {
		bytemuck::must_cast(self)
	}

	#[inline]
	fn split_ref(&self) -> &[U] {
		must_cast_ref::<Self, [U; 4]>(self)
	}

	#[inline]
	fn split_mut(&mut self) -> &mut [U] {
		must_cast_mut::<Self, [U; 4]>(self)
	}
}

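// Bitwise operations and their assign variants apply limb-wise.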
impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitand(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
	}
}

impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
	fn bitand_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] &= rhs.0[i];
		}
	}
}

impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
	}
}

impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
	fn bitor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] |= rhs.0[i];
		}
	}
}

impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitxor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
	}
}

impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
	fn bitxor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] ^= rhs.0[i];
		}
	}
}

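// Logical right shift of the whole scaled value, with limb 0 as the least
// significant. A shift by `rhs` moves data down by `rhs / U::BITS` whole limbs;
// the remaining `rhs % U::BITS` bits of each output limb are combined from two
// adjacent input limbs.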
impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shr(self, rhs: usize) -> Self::Output {
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
			if i + shift_in_items < N {
				// Low bits come from the limb `shift_in_items` positions above.
				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
			}
			if i + shift_in_items + 1 < N && rhs % U::BITS != 0 {
				// High bits are carried in from the next limb up.
				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

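// Logical left shift, the mirror image of `shr` above: data moves up by
// `rhs / U::BITS` whole limbs, with carries taken from the limbs below.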
impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shl(self, rhs: usize) -> Self::Output {
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in shift_in_items.saturating_sub(1)..N {
			if i >= shift_in_items {
				// Main bits come from the limb `shift_in_items` positions below.
				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
			}
			if i > shift_in_items && rhs % U::BITS != 0 {
				// Low bits are carried in from one limb further down.
				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
	type Output = Self;

	fn not(self) -> Self::Output {
		Self(self.0.map(U::not))
	}
}

impl<U: UnderlierWithBitOps + Pod, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N> {
	const ZERO: Self = Self([U::ZERO; N]);
	// With little-endian limb order, the value 1 lives entirely in the lowest limb.
	const ONE: Self = {
		let mut arr = [U::ZERO; N];
		arr[0] = U::ONE;
		Self(arr)
	};
	const ONES: Self = Self([U::ONES; N]);

	#[inline]
	fn fill_with_bit(val: u8) -> Self {
		Self(array::from_fn(|_| U::fill_with_bit(val)))
	}

	#[inline]
	fn shl_128b_lanes(self, rhs: usize) -> Self {
		// The current implementation is only valid when the inner underlier type
		// is at least 128 bits wide. In practice, we don't use scaled underliers
		// over types narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
	}

	#[inline]
	fn shr_128b_lanes(self, rhs: usize) -> Self {
		// The current implementation is only valid when the inner underlier type
		// is at least 128 bits wide. In practice, we don't use scaled underliers
		// over types narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
	}

	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// The current implementation is only valid when the inner underlier type
		// is at least 128 bits wide. In practice, we don't use scaled underliers
		// over types narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
	}

	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// The current implementation is only valid when the inner underlier type
		// is at least 128 bits wide. In practice, we don't use scaled underliers
		// over types narrower than 128 bits.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_shr() {
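		// The four u8 limbs behave like a single little-endian 32-bit word:
		// [0, 1, 2, 3] is 0x03020100, so each result matches the u32 shift.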
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
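		// Same little-endian view: e.g. 0x03020100 << 8 == 0x02010000, i.e. [0, 0, 1, 2].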
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}
}
295}