use std::{
    array,
    ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable};
use rand::RngCore;
use subtle::{Choice, ConstantTimeEq};

use super::{Divisible, NumCast, Random, UnderlierType, UnderlierWithBitOps};
use crate::tower_levels::TowerLevel;

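/// An underlier built as an array of `N` smaller underliers `U`, scaling the bit
/// width of `U` by a factor of `N` (which must be a power of two). Element 0
/// holds the least significant bits.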
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);

impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
    fn default() -> Self {
        Self(array::from_fn(|_| U::default()))
    }
}

impl<U: Random, const N: usize> Random for ScaledUnderlier<U, N> {
    fn random(mut rng: impl RngCore) -> Self {
        Self(array::from_fn(|_| U::random(&mut rng)))
    }
}

impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
    fn from(val: ScaledUnderlier<U, N>) -> Self {
        val.0
    }
}

impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
    fn from(value: [T; N]) -> Self {
        Self(value.map(U::from))
    }
}

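// Group a flat 4-element array into two pairs so that it can be converted into a
// doubly scaled underlier such as `ScaledUnderlier<ScaledUnderlier<T, 2>, 2>`.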
impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
    fn from(value: [T; 4]) -> Self {
        Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
    }
}

impl<U: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledUnderlier<U, N> {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}

unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}

impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
    const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
}

unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
where
    Self: UnderlierType,
    U: UnderlierType,
{
    type Array = [U; N];

    #[inline]
    fn split_val(self) -> Self::Array {
        self.0
    }

    #[inline]
    fn split_ref(&self) -> &[U] {
        &self.0
    }

    #[inline]
    fn split_mut(&mut self) -> &mut [U] {
        &mut self.0
    }
}

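// For a doubly scaled underlier, `Pod` guarantees that the nested `[[U; 2]; 2]`
// has the same layout as a flat `[U; 4]`, so the splits are implemented as
// compile-time-checked `bytemuck` casts.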
unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
where
    Self: UnderlierType + NoUninit,
    U: UnderlierType + Pod,
{
    type Array = [U; 4];

    #[inline]
    fn split_val(self) -> Self::Array {
        bytemuck::must_cast(self)
    }

    #[inline]
    fn split_ref(&self) -> &[U] {
        must_cast_ref::<Self, [U; 4]>(self)
    }

    #[inline]
    fn split_mut(&mut self) -> &mut [U] {
        must_cast_mut::<Self, [U; 4]>(self)
    }
}

impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
    }
}

impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
    fn bitand_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] &= rhs.0[i];
        }
    }
}

impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
    }
}

impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
    fn bitor_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] |= rhs.0[i];
        }
    }
}

impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
    type Output = Self;

    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
    }
}

impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
    fn bitxor_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] ^= rhs.0[i];
        }
    }
}

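// Logical right shift across the whole array, treating element 0 as least
// significant: `rhs` splits into a whole-element offset plus a sub-element bit
// shift, and bits crossing an element boundary are carried in from the next
// higher element.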
impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
    type Output = Self;

    fn shr(self, rhs: usize) -> Self::Output {
        let mut result = Self::default();

        let shift_in_items = rhs / U::BITS;
        for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
            if i + shift_in_items < N {
                result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
            }
            if i + shift_in_items + 1 < N && rhs % U::BITS != 0 {
                result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
            }
        }

        result
    }
}

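// Logical left shift, mirroring `Shr`: each element takes its main bits from the
// element `shift_in_items` positions below and its carried bits from the element
// below that one.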
impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
    type Output = Self;

    fn shl(self, rhs: usize) -> Self::Output {
        let mut result = Self::default();

        let shift_in_items = rhs / U::BITS;
        for i in shift_in_items.saturating_sub(1)..N {
            if i >= shift_in_items {
                result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
            }
            if i > shift_in_items && rhs % U::BITS != 0 {
                result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
            }
        }

        result
    }
}

impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
    type Output = Self;

    fn not(self) -> Self::Output {
        Self(self.0.map(U::not))
    }
}

impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
where
    U: UnderlierWithBitOps + Pod + From<u8>,
    u8: NumCast<U>,
{
    const ZERO: Self = Self([U::ZERO; N]);
    const ONE: Self = {
        let mut arr = [U::ZERO; N];
        arr[0] = U::ONE;
        Self(arr)
    };
    const ONES: Self = Self([U::ONES; N]);

    #[inline]
    fn fill_with_bit(val: u8) -> Self {
        Self(array::from_fn(|_| U::fill_with_bit(val)))
    }

    #[inline]
    fn shl_128b_lanes(self, rhs: usize) -> Self {
        assert!(U::BITS >= 128);
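
        // Each inner underlier is at least 128 bits wide, so it contains only
        // whole 128-bit lanes and the shift can be delegated element-wise.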
        Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
    }

    #[inline]
    fn shr_128b_lanes(self, rhs: usize) -> Self {
        assert!(U::BITS >= 128);
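
        // As above, 128-bit lanes never straddle elements, so delegate per element.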
        Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
    }

    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        assert!(U::BITS >= 128);
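
        // Lanes stay within matching elements, so unpack each pair of elements.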
        Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
    }

    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        assert!(U::BITS >= 128);
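
        // Same element-wise delegation as `unpack_lo_128b_lanes`.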
        Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
    }

    #[inline]
    fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
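        // First transpose the bytes within each column of inner underliers.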
        for col in 0..N {
            let mut column = TL::from_fn(|row| values[row].0[col]);
            U::transpose_bytes_from_byte_sliced::<TL>(&mut column);
            for row in 0..TL::WIDTH {
                values[row].0[col] = column[row];
            }
        }

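        // Then redistribute the inner underliers: the flat index `row * N + col`
        // is re-read as `(index % TL::WIDTH, index / TL::WIDTH)`, transposing the
        // `TL::WIDTH x N` grid of inner underliers.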
        let mut result = TL::default::<Self>();
        for row in 0..TL::WIDTH {
            for col in 0..N {
                let index = row * N + col;

                result[row].0[col] = values[index % TL::WIDTH].0[index / TL::WIDTH];
            }
        }

        *values = result;
    }

    #[inline]
    fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
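        // Redistribute the inner underliers first: the flat index
        // `row + col * TL::WIDTH` is re-read as `(index / N, index % N)`,
        // inverting the reordering done in `transpose_bytes_from_byte_sliced`.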
        let mut result = TL::from_fn(|row| {
            Self(array::from_fn(|col| {
                let index = row + col * TL::WIDTH;

                values[index / N].0[index % N]
            }))
        });

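        // Then transpose the bytes within each column of inner underliers.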
        for col in 0..N {
            let mut column = TL::from_fn(|row| result[row].0[col]);
            U::transpose_bytes_to_byte_sliced::<TL>(&mut column);
            for row in 0..TL::WIDTH {
                result[row].0[col] = column[row];
            }
        }

        *values = result;
    }
}

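// Casting to `u8` reads element 0, the least significant element.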
impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
where
    Self: NumCast<U>,
{
    fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
        Self::num_cast_from(val.0[0])
    }
}

impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
where
    U: From<u8>,
{
    fn from(val: u8) -> Self {
        Self(array::from_fn(|_| U::from(val)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shr() {
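        // Element 0 is least significant, so `[0, 1, 2, 3]` is the 32-bit
        // little-endian value 0x03020100.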
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        assert_eq!(
            val >> 1,
            ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
        );
        assert_eq!(
            val >> 2,
            ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
        );
        assert_eq!(
            val >> 8,
            ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
        );
        assert_eq!(
            val >> 9,
            ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
        );
    }

    #[test]
    fn test_shl() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
        assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
        assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
        assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
    }
}