binius_field/underlier/scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9use bytemuck::{NoUninit, Pod, Zeroable, must_cast_mut, must_cast_ref};
10use rand::RngCore;
11use subtle::{Choice, ConstantTimeEq};
12
13use super::{Divisible, NumCast, Random, UnderlierType, UnderlierWithBitOps};
14use crate::tower_levels::TowerLevel;
15
/// An underlier made of `N` consecutive values of the same underlier type `U`.
///
/// It is used as the underlier of the `ScaledPackedField` type. The bit-shift implementations
/// for this type treat element `0` as the least-significant word of the combined value.
///
/// Note: the derived `PartialOrd`/`Ord` compare the elements lexicographically starting from
/// index `0`, i.e. from the least-significant word, so this ordering is not the numeric order
/// of the combined value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
21
22impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
23	fn default() -> Self {
24		Self(array::from_fn(|_| U::default()))
25	}
26}
27
28impl<U: Random, const N: usize> Random for ScaledUnderlier<U, N> {
29	fn random(mut rng: impl RngCore) -> Self {
30		Self(array::from_fn(|_| U::random(&mut rng)))
31	}
32}
33
34impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
35	fn from(val: ScaledUnderlier<U, N>) -> Self {
36		val.0
37	}
38}
39
40impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
41	fn from(value: [T; N]) -> Self {
42		Self(value.map(U::from))
43	}
44}
45
46impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
47	fn from(value: [T; 4]) -> Self {
48		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
49	}
50}
51
52impl<U: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledUnderlier<U, N> {
53	fn ct_eq(&self, other: &Self) -> Choice {
54		self.0.ct_eq(&other.0)
55	}
56}
57
58unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
59
60unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
61
62impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
63	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
64}
65
66unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
67where
68	Self: UnderlierType,
69	U: UnderlierType,
70{
71	type Array = [U; N];
72
73	#[inline]
74	fn split_val(self) -> Self::Array {
75		self.0
76	}
77
78	#[inline]
79	fn split_ref(&self) -> &[U] {
80		&self.0
81	}
82
83	#[inline]
84	fn split_mut(&mut self) -> &mut [U] {
85		&mut self.0
86	}
87}
88
89unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
90where
91	Self: UnderlierType + NoUninit,
92	U: UnderlierType + Pod,
93{
94	type Array = [U; 4];
95
96	#[inline]
97	fn split_val(self) -> Self::Array {
98		bytemuck::must_cast(self)
99	}
100
101	#[inline]
102	fn split_ref(&self) -> &[U] {
103		must_cast_ref::<Self, [U; 4]>(self)
104	}
105
106	#[inline]
107	fn split_mut(&mut self) -> &mut [U] {
108		must_cast_mut::<Self, [U; 4]>(self)
109	}
110}
111
112impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
113	type Output = Self;
114
115	fn bitand(self, rhs: Self) -> Self::Output {
116		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
117	}
118}
119
120impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
121	fn bitand_assign(&mut self, rhs: Self) {
122		for i in 0..N {
123			self.0[i] &= rhs.0[i];
124		}
125	}
126}
127
128impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
129	type Output = Self;
130
131	fn bitor(self, rhs: Self) -> Self::Output {
132		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
133	}
134}
135
136impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
137	fn bitor_assign(&mut self, rhs: Self) {
138		for i in 0..N {
139			self.0[i] |= rhs.0[i];
140		}
141	}
142}
143
144impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
145	type Output = Self;
146
147	fn bitxor(self, rhs: Self) -> Self::Output {
148		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
149	}
150}
151
152impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
153	fn bitxor_assign(&mut self, rhs: Self) {
154		for i in 0..N {
155			self.0[i] ^= rhs.0[i];
156		}
157	}
158}
159
160impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
161	type Output = Self;
162
163	fn shr(self, rhs: usize) -> Self::Output {
164		let mut result = Self::default();
165
166		let shift_in_items = rhs / U::BITS;
167		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
168			if i + shift_in_items < N {
169				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
170			}
171			if i + shift_in_items + 1 < N && rhs % U::BITS != 0 {
172				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
173			}
174		}
175
176		result
177	}
178}
179
180impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
181	type Output = Self;
182
183	fn shl(self, rhs: usize) -> Self::Output {
184		let mut result = Self::default();
185
186		let shift_in_items = rhs / U::BITS;
187		for i in shift_in_items.saturating_sub(1)..N {
188			if i >= shift_in_items {
189				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
190			}
191			if i > shift_in_items && rhs % U::BITS != 0 {
192				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
193			}
194		}
195
196		result
197	}
198}
199
200impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
201	type Output = Self;
202
203	fn not(self) -> Self::Output {
204		Self(self.0.map(U::not))
205	}
206}
207
/// Bit-operation support for `ScaledUnderlier`, built by applying the corresponding `U`
/// operation to each of the `N` elements (plus an element permutation for the byte transposes).
impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
where
	U: UnderlierWithBitOps + Pod + From<u8>,
	u8: NumCast<U>,
{
	/// Every element is `U::ZERO`.
	const ZERO: Self = Self([U::ZERO; N]);
	/// Element `0` is `U::ONE` and all others are `U::ZERO`. Element `0` is the least-significant
	/// word, matching the `Shl`/`Shr` impls for this type. Requires `N >= 1`.
	const ONE: Self = {
		let mut arr = [U::ZERO; N];
		arr[0] = U::ONE;
		Self(arr)
	};
	/// Every element is `U::ONES`.
	const ONES: Self = Self([U::ONES; N]);

	/// Fills every element with `U::fill_with_bit(val)`.
	#[inline]
	fn fill_with_bit(val: u8) -> Self {
		Self(array::from_fn(|_| U::fill_with_bit(val)))
	}

	/// Applies `U::shl_128b_lanes` to each element independently.
	#[inline]
	fn shl_128b_lanes(self, rhs: usize) -> Self {
		// We assume that the underlier type has at least 128 bits as the current implementation
		// is valid for this case only: applying the lane operation element by element is only
		// correct when every 128-bit lane lies within a single element.
		// In practice, we don't use scaled underliers with underlier types that have less than 128
		// bits.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
	}

	/// Applies `U::shr_128b_lanes` to each element independently.
	#[inline]
	fn shr_128b_lanes(self, rhs: usize) -> Self {
		// We assume that the underlier type has at least 128 bits as the current implementation
		// is valid for this case only (lanes must not straddle elements).
		// In practice, we don't use scaled underliers with underlier types that have less than 128
		// bits.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
	}

	/// Applies `U::unpack_lo_128b_lanes` to each pair of same-index elements of `self` and `other`.
	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// We assume that the underlier type has at least 128 bits as the current implementation
		// is valid for this case only (lanes must not straddle elements).
		// In practice, we don't use scaled underliers with underlier types that have less than 128
		// bits.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
	}

	/// Applies `U::unpack_hi_128b_lanes` to each pair of same-index elements of `self` and `other`.
	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// We assume that the underlier type has at least 128 bits as the current implementation
		// is valid for this case only (lanes must not straddle elements).
		// In practice, we don't use scaled underliers with underlier types that have less than 128
		// bits.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
	}

	/// Works in two steps:
	/// 1. For each element index `col`, gathers `values[row].0[col]` over all `TL::WIDTH` rows,
	///    transposes that column with `U::transpose_bytes_from_byte_sliced`, and writes it back.
	/// 2. Permutes elements across rows: output element `(row, col)` is taken from flat index
	///    `row * N + col`, read column-major as `values[index % TL::WIDTH].0[index / TL::WIDTH]`.
	///
	/// The permutation in step 2 is the inverse of the one in `transpose_bytes_to_byte_sliced`.
	/// NOTE(review): the two methods are full inverses only if `U`'s own transposes are.
	#[inline]
	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		// Step 1: per-column transpose using the inner underlier's implementation.
		for col in 0..N {
			let mut column = TL::from_fn(|row| values[row].0[col]);
			U::transpose_bytes_from_byte_sliced::<TL>(&mut column);
			for row in 0..TL::WIDTH {
				values[row].0[col] = column[row];
			}
		}

		// Step 2: element permutation; built into a fresh buffer because it reads across rows.
		let mut result = TL::default::<Self>();
		for row in 0..TL::WIDTH {
			for col in 0..N {
				let index = row * N + col;

				result[row].0[col] = values[index % TL::WIDTH].0[index / TL::WIDTH];
			}
		}

		*values = result;
	}

	/// Works in two steps (the reverse order of `transpose_bytes_from_byte_sliced`):
	/// 1. Permutes elements: output element `(row, col)` is taken from flat index
	///    `row + col * TL::WIDTH`, read row-major as `values[index / N].0[index % N]`.
	/// 2. For each element index `col`, transposes the column of `result[row].0[col]` values
	///    with `U::transpose_bytes_to_byte_sliced` and writes it back.
	#[inline]
	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		// Step 1: element permutation into a fresh buffer.
		let mut result = TL::from_fn(|row| {
			Self(array::from_fn(|col| {
				let index = row + col * TL::WIDTH;

				values[index / N].0[index % N]
			}))
		});

		// Step 2: per-column transpose using the inner underlier's implementation.
		for col in 0..N {
			let mut column = TL::from_fn(|row| result[row].0[col]);
			U::transpose_bytes_to_byte_sliced::<TL>(&mut column);
			for row in 0..TL::WIDTH {
				result[row].0[col] = column[row];
			}
		}

		*values = result;
	}
}
321
322impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
323where
324	Self: NumCast<U>,
325{
326	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
327		Self::num_cast_from(val.0[0])
328	}
329}
330
331impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
332where
333	U: From<u8>,
334{
335	fn from(val: u8) -> Self {
336		Self(array::from_fn(|_| U::from(val)))
337	}
338}
339
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}

	/// Shifts by zero are the identity, shifts by the full width (or more) clear everything, and
	/// cross-word shifts carry bits between elements.
	#[test]
	fn test_shift_boundaries() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let zero = ScaledUnderlier::<u8, 4>([0; 4]);
		assert_eq!(val >> 0, val);
		assert_eq!(val << 0, val);
		assert_eq!(val >> 15, ScaledUnderlier::<u8, 4>([4, 6, 0, 0]));
		assert_eq!(val << 15, ScaledUnderlier::<u8, 4>([0, 0, 0x80, 0]));
		assert_eq!(val >> 16, ScaledUnderlier::<u8, 4>([2, 3, 0, 0]));
		assert_eq!(val << 16, ScaledUnderlier::<u8, 4>([0, 0, 0, 1]));
		assert_eq!(val >> 32, zero);
		assert_eq!(val << 32, zero);
		assert_eq!(val >> 100, zero);
		assert_eq!(val << 100, zero);

		let ones = ScaledUnderlier::<u8, 4>([0xFF; 4]);
		assert_eq!(ones >> 4, ScaledUnderlier::<u8, 4>([0xFF, 0xFF, 0xFF, 0x0F]));
		assert_eq!(ones << 4, ScaledUnderlier::<u8, 4>([0xF0, 0xFF, 0xFF, 0xFF]));
		assert_eq!(ones >> 31, ScaledUnderlier::<u8, 4>([0x01, 0x00, 0x00, 0x00]));
		assert_eq!(ones << 31, ScaledUnderlier::<u8, 4>([0x00, 0x00, 0x00, 0x80]));
	}

	/// `ScaledUnderlier<u8, 4>` must shift exactly like the little-endian `u32` it mirrors.
	#[test]
	fn test_shifts_match_u32() {
		for x in [0u32, 1, 0x0302_0100, 0x8000_0001, 0xDEAD_BEEF, u32::MAX] {
			let scaled = ScaledUnderlier::<u8, 4>(x.to_le_bytes());
			for shift in 0..32usize {
				assert_eq!(
					(scaled >> shift).0,
					(x >> shift).to_le_bytes(),
					"{:#x} >> {}",
					x,
					shift
				);
				assert_eq!(
					(scaled << shift).0,
					(x << shift).to_le_bytes(),
					"{:#x} << {}",
					x,
					shift
				);
			}
		}
	}
}