binius_field/underlier/scaled.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	array,
	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable};
use rand::RngCore;
use subtle::{Choice, ConstantTimeEq};

use super::{Divisible, NumCast, Random, UnderlierType, UnderlierWithBitOps};
use crate::tower_levels::TowerLevel;

/// A type that represents an array of `N` elements of the same underlier type.
/// We use it as the underlier for the `ScaledPackedField` type.
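///
/// The element at index 0 is the least significant one: the `Shr` and `Shl`
/// implementations below treat higher indices as more significant, so, for example,
/// `ScaledUnderlier<u64, 2>` behaves as a little-endian 128-bit value.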
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);

impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
	fn default() -> Self {
		Self(array::from_fn(|_| U::default()))
	}
}

impl<U: Random, const N: usize> Random for ScaledUnderlier<U, N> {
	fn random(mut rng: impl RngCore) -> Self {
		Self(array::from_fn(|_| U::random(&mut rng)))
	}
}

impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
	fn from(val: ScaledUnderlier<U, N>) -> Self {
		val.0
	}
}

impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
	fn from(value: [T; N]) -> Self {
		Self(value.map(U::from))
	}
}

impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
	fn from(value: [T; 4]) -> Self {
		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
	}
}

impl<U: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledUnderlier<U, N> {
	fn ct_eq(&self, other: &Self) -> Choice {
		self.0.ct_eq(&other.0)
	}
}

unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}

unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}

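// `N` must be a power of two for the compile-time `checked_log_2(N)` to succeed;
// the total width is then `U::BITS * N` bits, e.g. `ScaledUnderlier<u64, 2>` has
// `LOG_BITS = 6 + 1 = 7`, i.e. 128 bits.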
impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
}

unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
where
	Self: UnderlierType,
	U: UnderlierType,
{
	type Array = [U; N];

	#[inline]
	fn split_val(self) -> Self::Array {
		self.0
	}

	#[inline]
	fn split_ref(&self) -> &[U] {
		&self.0
	}

	#[inline]
	fn split_mut(&mut self) -> &mut [U] {
		&mut self.0
	}
}

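// A `ScaledUnderlier<ScaledUnderlier<U, 2>, 2>` is `#[repr(transparent)]` over
// `[[U; 2]; 2]`, which has the same layout as `[U; 4]`, so the `bytemuck` casts
// below are sound; `must_cast*` additionally verifies the sizes at compile time.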
unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
where
	Self: UnderlierType + NoUninit,
	U: UnderlierType + Pod,
{
	type Array = [U; 4];

	#[inline]
	fn split_val(self) -> Self::Array {
		bytemuck::must_cast(self)
	}

	#[inline]
	fn split_ref(&self) -> &[U] {
		must_cast_ref::<Self, [U; 4]>(self)
	}

	#[inline]
	fn split_mut(&mut self) -> &mut [U] {
		must_cast_mut::<Self, [U; 4]>(self)
	}
}

impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitand(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
	}
}

impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
	fn bitand_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] &= rhs.0[i];
		}
	}
}

impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
	}
}

impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
	fn bitor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] |= rhs.0[i];
		}
	}
}

impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitxor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
	}
}

impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
	fn bitxor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] ^= rhs.0[i];
		}
	}
}

impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shr(self, rhs: usize) -> Self::Output {
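		// Logical right shift across limbs (limb 0 is least significant): limb `i` of
		// the result combines the high bits of limb `i + shift_in_items` (shifted
		// down) with, for non-limb-aligned shifts, the low bits of limb
		// `i + shift_in_items + 1`.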
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
			if i + shift_in_items < N {
				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
			}
			if i + shift_in_items + 1 < N && rhs % U::BITS != 0 {
				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shl(self, rhs: usize) -> Self::Output {
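		// Mirror image of `Shr`: limb `i` of the result combines the low bits of limb
		// `i - shift_in_items` (shifted up) with, for non-limb-aligned shifts, the
		// high bits of limb `i - shift_in_items - 1`.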
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in shift_in_items.saturating_sub(1)..N {
			if i >= shift_in_items {
				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
			}
			if i > shift_in_items && rhs % U::BITS != 0 {
				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
	type Output = Self;

	fn not(self) -> Self::Output {
		Self(self.0.map(U::not))
	}
}

impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
where
	U: UnderlierWithBitOps + Pod + From<u8>,
	u8: NumCast<U>,
{
	const ZERO: Self = Self([U::ZERO; N]);
	const ONE: Self = {
		let mut arr = [U::ZERO; N];
		arr[0] = U::ONE;
		Self(arr)
	};
	const ONES: Self = Self([U::ONES; N]);

	#[inline]
	fn fill_with_bit(val: u8) -> Self {
		Self(array::from_fn(|_| U::fill_with_bit(val)))
	}

	#[inline]
	fn shl_128b_lanes(self, rhs: usize) -> Self {
		// The implementation below is valid only when the underlier type has at least
		// 128 bits. In practice, we never use scaled underliers over smaller underliers.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
	}

	#[inline]
	fn shr_128b_lanes(self, rhs: usize) -> Self {
		// The implementation below is valid only when the underlier type has at least
		// 128 bits. In practice, we never use scaled underliers over smaller underliers.
		assert!(U::BITS >= 128);

		Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
	}

	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// The implementation below is valid only when the underlier type has at least
		// 128 bits. In practice, we never use scaled underliers over smaller underliers.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
	}

	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		// The implementation below is valid only when the underlier type has at least
		// 128 bits. In practice, we never use scaled underliers over smaller underliers.
		assert!(U::BITS >= 128);

		Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
	}

	#[inline]
	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
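		// Two passes: first transpose bytes within each column of `U` limbs, then
		// permute whole limbs, rewriting the flat (row, limb) grid from column-major
		// to row-major order.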
		for col in 0..N {
			let mut column = TL::from_fn(|row| values[row].0[col]);
			U::transpose_bytes_from_byte_sliced::<TL>(&mut column);
			for row in 0..TL::WIDTH {
				values[row].0[col] = column[row];
			}
		}

		let mut result = TL::default::<Self>();
		for row in 0..TL::WIDTH {
			for col in 0..N {
				let index = row * N + col;

				result[row].0[col] = values[index % TL::WIDTH].0[index / TL::WIDTH];
			}
		}

		*values = result;
	}

	#[inline]
	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
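		// Inverse of `transpose_bytes_from_byte_sliced`: first undo the limb
		// permutation, then transpose bytes within each column of `U` limbs.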
		let mut result = TL::from_fn(|row| {
			Self(array::from_fn(|col| {
				let index = row + col * TL::WIDTH;

				values[index / N].0[index % N]
			}))
		});

		for col in 0..N {
			let mut column = TL::from_fn(|row| result[row].0[col]);
			U::transpose_bytes_to_byte_sliced::<TL>(&mut column);
			for row in 0..TL::WIDTH {
				result[row].0[col] = column[row];
			}
		}

		*values = result;
	}
}

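// Casting a scaled underlier to `u8` reads from limb 0, the least significant limb.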
impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
where
	Self: NumCast<U>,
{
	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
		Self::num_cast_from(val.0[0])
	}
}

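// Conversion from `u8` broadcasts the byte to every limb.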
impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
where
	U: From<u8>,
{
	fn from(val: u8) -> Self {
		Self(array::from_fn(|_| U::from(val)))
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}
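
	// Extra sanity checks, written as a sketch: they assume the primitive `u8`
	// underlier implements `UnderlierType`, `UnderlierWithBitOps`, and `NumCast<u8>`
	// as it does elsewhere in this crate.
	#[test]
	fn test_bitwise_ops_and_constants() {
		let a = ScaledUnderlier::<u8, 2>([0b1100, 0b1010]);
		let b = ScaledUnderlier::<u8, 2>([0b1010, 0b0110]);
		// All bitwise operations apply limb-wise.
		assert_eq!(a & b, ScaledUnderlier::<u8, 2>([0b1000, 0b0010]));
		assert_eq!(a | b, ScaledUnderlier::<u8, 2>([0b1110, 0b1110]));
		assert_eq!(a ^ b, ScaledUnderlier::<u8, 2>([0b0110, 0b1100]));
		// `ONE` sets only the least significant limb.
		assert_eq!(ScaledUnderlier::<u8, 2>::ONE, ScaledUnderlier::<u8, 2>([1, 0]));
		assert_eq!(ScaledUnderlier::<u8, 2>::ONES, ScaledUnderlier::<u8, 2>([0xFF, 0xFF]));
	}

	#[test]
	fn test_nested_split() {
		// `From<[T; 4]>` fills the nested pairs in order; `Divisible<u8>` flattens
		// them back with a `bytemuck` cast.
		let val = ScaledUnderlier::<ScaledUnderlier<u8, 2>, 2>::from([1u8, 2, 3, 4]);
		let limbs =
			<ScaledUnderlier<ScaledUnderlier<u8, 2>, 2> as Divisible<u8>>::split_val(val);
		assert_eq!(limbs, [1, 2, 3, 4]);
	}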
}