1use std::{
4 array,
5 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9use bytemuck::{NoUninit, Pod, Zeroable, must_cast_mut, must_cast_ref};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14
15use super::{Divisible, NumCast, UnderlierType, UnderlierWithBitOps};
16use crate::Random;
17
/// An underlier made of `N` side-by-side copies of the underlier `U`.
///
/// Element `0` holds the least significant bits: the `Shl`/`Shr` impls carry bits from element
/// `i` into element `i + 1` when shifting left, and the other way when shifting right.
///
/// Note: the derived `PartialOrd`/`Ord` compare the elements lexicographically starting at index
/// `0`, so the ordering is NOT the numeric ordering of the concatenated value.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
23
24impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
25 fn default() -> Self {
26 Self(array::from_fn(|_| U::default()))
27 }
28}
29
30impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
31 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
32 ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
33 }
34}
35
36impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
37 fn from(val: ScaledUnderlier<U, N>) -> Self {
38 val.0
39 }
40}
41
42impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
43 fn from(value: [T; N]) -> Self {
44 Self(value.map(U::from))
45 }
46}
47
48impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
49 fn from(value: [T; 4]) -> Self {
50 Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
51 }
52}
53
54unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
55
56unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
57
58impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
59 const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
60}
61
62unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
63where
64 Self: UnderlierType,
65 U: UnderlierType,
66{
67 type Array = [U; N];
68
69 #[inline]
70 fn split_val(self) -> Self::Array {
71 self.0
72 }
73
74 #[inline]
75 fn split_ref(&self) -> &[U] {
76 &self.0
77 }
78
79 #[inline]
80 fn split_mut(&mut self) -> &mut [U] {
81 &mut self.0
82 }
83}
84
85unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
86where
87 Self: UnderlierType + NoUninit,
88 U: UnderlierType + Pod,
89{
90 type Array = [U; 4];
91
92 #[inline]
93 fn split_val(self) -> Self::Array {
94 bytemuck::must_cast(self)
95 }
96
97 #[inline]
98 fn split_ref(&self) -> &[U] {
99 must_cast_ref::<Self, [U; 4]>(self)
100 }
101
102 #[inline]
103 fn split_mut(&mut self) -> &mut [U] {
104 must_cast_mut::<Self, [U; 4]>(self)
105 }
106}
107
108impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
109 type Output = Self;
110
111 fn bitand(self, rhs: Self) -> Self::Output {
112 Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
113 }
114}
115
116impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
117 fn bitand_assign(&mut self, rhs: Self) {
118 for i in 0..N {
119 self.0[i] &= rhs.0[i];
120 }
121 }
122}
123
124impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
125 type Output = Self;
126
127 fn bitor(self, rhs: Self) -> Self::Output {
128 Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
129 }
130}
131
132impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
133 fn bitor_assign(&mut self, rhs: Self) {
134 for i in 0..N {
135 self.0[i] |= rhs.0[i];
136 }
137 }
138}
139
140impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
141 type Output = Self;
142
143 fn bitxor(self, rhs: Self) -> Self::Output {
144 Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
145 }
146}
147
148impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
149 fn bitxor_assign(&mut self, rhs: Self) {
150 for i in 0..N {
151 self.0[i] ^= rhs.0[i];
152 }
153 }
154}
155
156impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
157 type Output = Self;
158
159 fn shr(self, rhs: usize) -> Self::Output {
160 let mut result = Self::default();
161
162 let shift_in_items = rhs / U::BITS;
163 for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
164 if i + shift_in_items < N {
165 result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
166 }
167 if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
168 result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
169 }
170 }
171
172 result
173 }
174}
175
176impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
177 type Output = Self;
178
179 fn shl(self, rhs: usize) -> Self::Output {
180 let mut result = Self::default();
181
182 let shift_in_items = rhs / U::BITS;
183 for i in shift_in_items.saturating_sub(1)..N {
184 if i >= shift_in_items {
185 result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
186 }
187 if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
188 result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
189 }
190 }
191
192 result
193 }
194}
195
196impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
197 type Output = Self;
198
199 fn not(self) -> Self::Output {
200 Self(self.0.map(U::not))
201 }
202}
203
204impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
205where
206 U: UnderlierWithBitOps + Pod + From<u8>,
207 u8: NumCast<U>,
208{
209 const ZERO: Self = Self([U::ZERO; N]);
210 const ONE: Self = {
211 let mut arr = [U::ZERO; N];
212 arr[0] = U::ONE;
213 Self(arr)
214 };
215 const ONES: Self = Self([U::ONES; N]);
216
217 #[inline]
218 fn fill_with_bit(val: u8) -> Self {
219 Self(array::from_fn(|_| U::fill_with_bit(val)))
220 }
221
222 #[inline]
223 fn shl_128b_lanes(self, rhs: usize) -> Self {
224 assert!(U::BITS >= 128);
229
230 Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
231 }
232
233 #[inline]
234 fn shr_128b_lanes(self, rhs: usize) -> Self {
235 assert!(U::BITS >= 128);
240
241 Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
242 }
243
244 #[inline]
245 fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
246 assert!(U::BITS >= 128);
251
252 Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
253 }
254
255 #[inline]
256 fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
257 assert!(U::BITS >= 128);
262
263 Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
264 }
265}
266
267impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
268where
269 Self: NumCast<U>,
270{
271 fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
272 Self::num_cast_from(val.0[0])
273 }
274}
275
276impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
277where
278 U: From<u8>,
279{
280 fn from(val: u8) -> Self {
281 Self(array::from_fn(|_| U::from(val)))
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Value `0x0302_0100` laid out least-significant element first.
    fn sample() -> ScaledUnderlier<u8, 4> {
        ScaledUnderlier::<u8, 4>([0, 1, 2, 3])
    }

    #[test]
    fn test_shr() {
        let val = sample();
        assert_eq!(
            val >> 1,
            ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
        );
        assert_eq!(
            val >> 2,
            ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
        );
        assert_eq!(
            val >> 8,
            ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
        );
        assert_eq!(
            val >> 9,
            ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
        );
    }

    #[test]
    fn test_shl() {
        let val = sample();
        assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
        assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
        assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
        assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
    }

    /// A zero shift must leave the value untouched in both directions.
    #[test]
    fn test_shift_by_zero_is_identity() {
        let val = sample();
        assert_eq!(val >> 0, val);
        assert_eq!(val << 0, val);
    }

    /// Shifts of at least the full 32-bit width (including `usize::MAX`, which exercises the
    /// `saturating_sub` loop bounds) must clear everything.
    #[test]
    fn test_shift_by_full_width_or_more_is_zero() {
        let val = ScaledUnderlier::<u8, 4>([0xff; 4]);
        let zero = ScaledUnderlier::<u8, 4>([0; 4]);
        for shift in [32usize, 33, 39, 40, 100, usize::MAX] {
            assert_eq!(val >> shift, zero, ">> {shift}");
            assert_eq!(val << shift, zero, "<< {shift}");
        }
    }

    /// Shifts spanning several elements, with and without a sub-element remainder.
    #[test]
    fn test_shift_across_multiple_elements() {
        let val = sample();
        assert_eq!(val >> 24, ScaledUnderlier::<u8, 4>([3, 0, 0, 0]));
        assert_eq!(val >> 15, ScaledUnderlier::<u8, 4>([0x04, 0x06, 0x00, 0x00]));
        assert_eq!(val << 7, ScaledUnderlier::<u8, 4>([0x00, 0x80, 0x00, 0x81]));

        let low_byte = ScaledUnderlier::<u8, 4>([0xab, 0, 0, 0]);
        assert_eq!(low_byte << 24, ScaledUnderlier::<u8, 4>([0, 0, 0, 0xab]));
        assert_eq!((low_byte << 24) >> 24, low_byte);
    }

    /// Every in-range shift must agree with native `u32` shifts, using little-endian element
    /// order (element 0 is least significant).
    #[test]
    fn test_shifts_match_u32() {
        for word in [0x0302_0100u32, 0xdead_beef, 0x8000_0001, u32::MAX] {
            let val = ScaledUnderlier::<u8, 4>(word.to_le_bytes());
            for shift in 0usize..32 {
                assert_eq!(
                    u32::from_le_bytes((val << shift).0),
                    word << shift,
                    "{word:#x} << {shift}"
                );
                assert_eq!(
                    u32::from_le_bytes((val >> shift).0),
                    word >> shift,
                    "{word:#x} >> {shift}"
                );
            }
        }
    }
}
318}