1use std::{
4 array,
5 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable};
10use rand::RngCore;
11use subtle::{Choice, ConstantTimeEq};
12
13use super::{Divisible, Random, UnderlierType, UnderlierWithBitOps};
14
/// A group of `N` underlier values of type `U`, handled as one wider value.
///
/// `#[repr(transparent)]` gives the wrapper the same memory layout as the inner `[U; N]`.
/// Equality and ordering are derived, so they follow the inner array's element order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
20
21impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
22 fn default() -> Self {
23 Self(array::from_fn(|_| U::default()))
24 }
25}
26
27impl<U: Random, const N: usize> Random for ScaledUnderlier<U, N> {
28 fn random(mut rng: impl RngCore) -> Self {
29 Self(array::from_fn(|_| U::random(&mut rng)))
30 }
31}
32
33impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
34 fn from(val: ScaledUnderlier<U, N>) -> Self {
35 val.0
36 }
37}
38
39impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
40 fn from(value: [T; N]) -> Self {
41 Self(value.map(U::from))
42 }
43}
44
45impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
46 fn from(value: [T; 4]) -> Self {
47 Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
48 }
49}
50
51impl<U: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledUnderlier<U, N> {
52 fn ct_eq(&self, other: &Self) -> Choice {
53 self.0.ct_eq(&other.0)
54 }
55}
56
57unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
58
59unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
60
61impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
62 const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
63}
64
65unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
66where
67 Self: UnderlierType,
68 U: UnderlierType,
69{
70 type Array = [U; N];
71
72 #[inline]
73 fn split_val(self) -> Self::Array {
74 self.0
75 }
76
77 #[inline]
78 fn split_ref(&self) -> &[U] {
79 &self.0
80 }
81
82 #[inline]
83 fn split_mut(&mut self) -> &mut [U] {
84 &mut self.0
85 }
86}
87
88unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
89where
90 Self: UnderlierType + NoUninit,
91 U: UnderlierType + Pod,
92{
93 type Array = [U; 4];
94
95 #[inline]
96 fn split_val(self) -> Self::Array {
97 bytemuck::must_cast(self)
98 }
99
100 #[inline]
101 fn split_ref(&self) -> &[U] {
102 must_cast_ref::<Self, [U; 4]>(self)
103 }
104
105 #[inline]
106 fn split_mut(&mut self) -> &mut [U] {
107 must_cast_mut::<Self, [U; 4]>(self)
108 }
109}
110
111impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
112 type Output = Self;
113
114 fn bitand(self, rhs: Self) -> Self::Output {
115 Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
116 }
117}
118
119impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
120 fn bitand_assign(&mut self, rhs: Self) {
121 for i in 0..N {
122 self.0[i] &= rhs.0[i];
123 }
124 }
125}
126
127impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
128 type Output = Self;
129
130 fn bitor(self, rhs: Self) -> Self::Output {
131 Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
132 }
133}
134
135impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
136 fn bitor_assign(&mut self, rhs: Self) {
137 for i in 0..N {
138 self.0[i] |= rhs.0[i];
139 }
140 }
141}
142
143impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
144 type Output = Self;
145
146 fn bitxor(self, rhs: Self) -> Self::Output {
147 Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
148 }
149}
150
151impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
152 fn bitxor_assign(&mut self, rhs: Self) {
153 for i in 0..N {
154 self.0[i] ^= rhs.0[i];
155 }
156 }
157}
158
159impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
160 type Output = Self;
161
162 fn shr(self, rhs: usize) -> Self::Output {
163 let mut result = Self::default();
164
165 let shift_in_items = rhs / U::BITS;
166 for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
167 if i + shift_in_items < N {
168 result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
169 }
170 if i + shift_in_items + 1 < N && rhs % U::BITS != 0 {
171 result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
172 }
173 }
174
175 result
176 }
177}
178
179impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
180 type Output = Self;
181
182 fn shl(self, rhs: usize) -> Self::Output {
183 let mut result = Self::default();
184
185 let shift_in_items = rhs / U::BITS;
186 for i in shift_in_items.saturating_sub(1)..N {
187 if i >= shift_in_items {
188 result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
189 }
190 if i > shift_in_items && rhs % U::BITS != 0 {
191 result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
192 }
193 }
194
195 result
196 }
197}
198
199impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
200 type Output = Self;
201
202 fn not(self) -> Self::Output {
203 Self(self.0.map(U::not))
204 }
205}
206
207impl<U: UnderlierWithBitOps + Pod, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N> {
208 const ZERO: Self = Self([U::ZERO; N]);
209 const ONE: Self = {
210 let mut arr = [U::ZERO; N];
211 arr[0] = U::ONE;
212 Self(arr)
213 };
214 const ONES: Self = Self([U::ONES; N]);
215
216 #[inline]
217 fn fill_with_bit(val: u8) -> Self {
218 Self(array::from_fn(|_| U::fill_with_bit(val)))
219 }
220
221 #[inline]
222 fn shl_128b_lanes(self, rhs: usize) -> Self {
223 assert!(U::BITS >= 128);
227
228 Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
229 }
230
231 #[inline]
232 fn shr_128b_lanes(self, rhs: usize) -> Self {
233 assert!(U::BITS >= 128);
237
238 Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
239 }
240
241 #[inline]
242 fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
243 assert!(U::BITS >= 128);
247
248 Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
249 }
250
251 #[inline]
252 fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
253 assert!(U::BITS >= 128);
257
258 Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps four bytes; small enough that expected bit patterns can be written by hand.
    fn bytes(items: [u8; 4]) -> ScaledUnderlier<u8, 4> {
        ScaledUnderlier(items)
    }

    /// Right shifts move bits from higher-indexed elements into lower-indexed ones.
    #[test]
    fn test_shr() {
        let val = bytes([0, 1, 2, 3]);

        // (shift amount, expected elements)
        let cases: [(usize, [u8; 4]); 4] = [
            (1, [0b10000000, 0b00000000, 0b10000001, 0b00000001]),
            (2, [0b01000000, 0b10000000, 0b11000000, 0b00000000]),
            (8, [0b00000001, 0b00000010, 0b00000011, 0b00000000]),
            (9, [0b00000000, 0b10000001, 0b00000001, 0b00000000]),
        ];

        for (shift, expected) in cases {
            assert_eq!(val >> shift, bytes(expected));
        }
    }

    /// Left shifts move bits from lower-indexed elements into higher-indexed ones.
    #[test]
    fn test_shl() {
        let val = bytes([0, 1, 2, 3]);

        // (shift amount, expected elements)
        let cases: [(usize, [u8; 4]); 4] = [
            (1, [0, 2, 4, 6]),
            (2, [0, 4, 8, 12]),
            (8, [0, 0, 1, 2]),
            (9, [0, 0, 2, 4]),
        ];

        for (shift, expected) in cases {
            assert_eq!(val << shift, bytes(expected));
        }
    }
}