binius_field/underlier/scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9use bytemuck::{Pod, Zeroable};
10use rand::{
11	Rng,
12	distr::{Distribution, StandardUniform},
13};
14
15use super::{NumCast, UnderlierType, UnderlierWithBitOps};
16use crate::Random;
17
/// A fixed-size group of `N` values of the same underlier type `U`, treated as a single
/// underlier that is `N` times as wide as `U`.
///
/// It is used as the underlier for the `ScaledPackedField` type. Element 0 holds the
/// least-significant limb: `ONE` sets only element 0, and the shift operators move bits from
/// higher indices toward lower ones on `>>` (and the reverse on `<<`).
///
/// `#[repr(transparent)]` makes the memory layout exactly that of `[U; N]`, which the
/// `Zeroable`/`Pod` impls rely on.
///
/// NOTE: the derived `PartialOrd`/`Ord` compare the array lexicographically starting at
/// element 0, i.e. from the *least*-significant limb. That is not numeric order of the
/// combined value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
23
24impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
25	fn default() -> Self {
26		Self(array::from_fn(|_| U::default()))
27	}
28}
29
30impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
31	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
32		ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
33	}
34}
35
36impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
37	fn from(val: ScaledUnderlier<U, N>) -> Self {
38		val.0
39	}
40}
41
42impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
43	fn from(value: [T; N]) -> Self {
44		Self(value.map(U::from))
45	}
46}
47
48impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
49	fn from(value: [T; 4]) -> Self {
50		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
51	}
52}
53
// SAFETY: `ScaledUnderlier` is `#[repr(transparent)]` over `[U; N]`, so its all-zero bit
// pattern is an all-zero `[U; N]`. That is a valid value because `U: Zeroable`.
unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
55
// SAFETY: `#[repr(transparent)]` gives `ScaledUnderlier` exactly the layout of `[U; N]`. An
// array of `Pod` elements has no padding and accepts any bit pattern. `U: Pod` also implies
// `U: Copy + 'static`, so the derived `Copy` applies and the type is `'static`.
unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
57
58impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
59	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
60}
61
62impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
63	type Output = Self;
64
65	fn bitand(self, rhs: Self) -> Self::Output {
66		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
67	}
68}
69
70impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
71	fn bitand_assign(&mut self, rhs: Self) {
72		for i in 0..N {
73			self.0[i] &= rhs.0[i];
74		}
75	}
76}
77
78impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
79	type Output = Self;
80
81	fn bitor(self, rhs: Self) -> Self::Output {
82		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
83	}
84}
85
86impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
87	fn bitor_assign(&mut self, rhs: Self) {
88		for i in 0..N {
89			self.0[i] |= rhs.0[i];
90		}
91	}
92}
93
94impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
95	type Output = Self;
96
97	fn bitxor(self, rhs: Self) -> Self::Output {
98		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
99	}
100}
101
102impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
103	fn bitxor_assign(&mut self, rhs: Self) {
104		for i in 0..N {
105			self.0[i] ^= rhs.0[i];
106		}
107	}
108}
109
110impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
111	type Output = Self;
112
113	fn shr(self, rhs: usize) -> Self::Output {
114		let mut result = Self::default();
115
116		let shift_in_items = rhs / U::BITS;
117		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
118			if i + shift_in_items < N {
119				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
120			}
121			if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
122				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
123			}
124		}
125
126		result
127	}
128}
129
130impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
131	type Output = Self;
132
133	fn shl(self, rhs: usize) -> Self::Output {
134		let mut result = Self::default();
135
136		let shift_in_items = rhs / U::BITS;
137		for i in shift_in_items.saturating_sub(1)..N {
138			if i >= shift_in_items {
139				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
140			}
141			if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
142				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
143			}
144		}
145
146		result
147	}
148}
149
150impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
151	type Output = Self;
152
153	fn not(self) -> Self::Output {
154		Self(self.0.map(U::not))
155	}
156}
157
158impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
159where
160	U: UnderlierWithBitOps + Pod + From<u8>,
161	u8: NumCast<U>,
162{
163	const ZERO: Self = Self([U::ZERO; N]);
164	const ONE: Self = {
165		let mut arr = [U::ZERO; N];
166		arr[0] = U::ONE;
167		Self(arr)
168	};
169	const ONES: Self = Self([U::ONES; N]);
170
171	#[inline]
172	fn fill_with_bit(val: u8) -> Self {
173		Self(array::from_fn(|_| U::fill_with_bit(val)))
174	}
175}
176
177impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
178where
179	Self: NumCast<U>,
180{
181	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
182		Self::num_cast_from(val.0[0])
183	}
184}
185
186impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
187where
188	U: From<u8>,
189{
190	fn from(val: u8) -> Self {
191		Self(array::from_fn(|_| U::from(val)))
192	}
193}
194
#[cfg(test)]
mod tests {
	use super::*;

	/// Splits a `u32` into four `u8` limbs, least-significant limb first. This matches the
	/// limb order the shift operators assume.
	fn limbs_of(x: u32) -> ScaledUnderlier<u8, 4> {
		ScaledUnderlier(x.to_le_bytes())
	}

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}

	#[test]
	fn test_shifts_match_u32_reference() {
		for x in [0x0302_0100_u32, 0xdead_beef, 0x8000_0001, u32::MAX] {
			let val = limbs_of(x);
			// Covers zero, sub-limb, whole-limb, mixed and out-of-range shift amounts.
			for shift in 0..40usize {
				let (shr, shl) = if shift < 32 {
					(x >> shift, x << shift)
				} else {
					(0, 0)
				};
				assert_eq!(val >> shift, limbs_of(shr), "{x:#010x} >> {shift}");
				assert_eq!(val << shift, limbs_of(shl), "{x:#010x} << {shift}");
			}
			assert_eq!(val >> usize::MAX, limbs_of(0));
			assert_eq!(val << usize::MAX, limbs_of(0));
		}
	}

	#[test]
	fn test_bitwise_ops_are_limbwise() {
		let (a, b) = (0xf0f0_1234_u32, 0x0ff0_4321_u32);
		let (sa, sb) = (limbs_of(a), limbs_of(b));

		assert_eq!(sa & sb, limbs_of(a & b));
		assert_eq!(sa | sb, limbs_of(a | b));
		assert_eq!(sa ^ sb, limbs_of(a ^ b));
		assert_eq!(!sa, limbs_of(!a));

		let mut acc = sa;
		acc &= sb;
		assert_eq!(acc, limbs_of(a & b));
		let mut acc = sa;
		acc |= sb;
		assert_eq!(acc, limbs_of(a | b));
		let mut acc = sa;
		acc ^= sb;
		assert_eq!(acc, limbs_of(a ^ b));
	}

	#[test]
	fn test_conversions() {
		let val = limbs_of(0x0403_0201);
		assert_eq!(<[u8; 4]>::from(val), [1, 2, 3, 4]);
		assert_eq!(
			ScaledUnderlier::<u16, 4>::from([1u8, 2, 3, 4]),
			ScaledUnderlier([1u16, 2, 3, 4])
		);
		assert_eq!(ScaledUnderlier::<u8, 4>::from(0xab_u8), ScaledUnderlier([0xab_u8; 4]));
		assert_eq!(ScaledUnderlier::<u8, 3>::default(), ScaledUnderlier([0_u8; 3]));
	}
}