use std::{
    array,
    ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{NoUninit, Pod, Zeroable, must_cast_mut, must_cast_ref};
use rand::RngCore;
use subtle::{Choice, ConstantTimeEq};

use super::{Divisible, NumCast, Random, UnderlierType, UnderlierWithBitOps};
use crate::tower_levels::TowerLevel;

/// An underlier composed of `N` adjacent elements of a smaller underlier `U`, with
/// element 0 holding the least significant bits. Used as the underlier for scaled
/// packed field types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);

impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
    fn default() -> Self {
        Self(array::from_fn(|_| U::default()))
    }
}

impl<U: Random, const N: usize> Random for ScaledUnderlier<U, N> {
    fn random(mut rng: impl RngCore) -> Self {
        Self(array::from_fn(|_| U::random(&mut rng)))
    }
}

impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
    fn from(val: ScaledUnderlier<U, N>) -> Self {
        val.0
    }
}

impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
    fn from(value: [T; N]) -> Self {
        Self(value.map(U::from))
    }
}

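// Builds a two-element scaled underlier from four halves, pairing adjacent values:
// `[a, b, c, d]` becomes `[U::from([a, b]), U::from([c, d])]`.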
impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
    fn from(value: [T; 4]) -> Self {
        Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
    }
}

impl<U: ConstantTimeEq, const N: usize> ConstantTimeEq for ScaledUnderlier<U, N> {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}

unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}

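// The log of the total bit width is the sum of the logs: e.g. `ScaledUnderlier<u64, 2>`
// has `LOG_BITS = 6 + 1 = 7`, i.e. 128 bits. `checked_log_2(N)` only accepts powers of
// two, so an invalid `N` fails at const evaluation.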
impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
    const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
}

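// A scaled underlier divides evenly into its `N` components; the reference variants
// reborrow the inner array without copying.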
unsafe impl<U, const N: usize> Divisible<U> for ScaledUnderlier<U, N>
where
    Self: UnderlierType,
    U: UnderlierType,
{
    type Array = [U; N];

    #[inline]
    fn split_val(self) -> Self::Array {
        self.0
    }

    #[inline]
    fn split_ref(&self) -> &[U] {
        &self.0
    }

    #[inline]
    fn split_mut(&mut self) -> &mut [U] {
        &mut self.0
    }
}

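// A doubly-scaled underlier can also be viewed as a flat array of four `U`s. Both layers
// are `#[repr(transparent)]` wrappers over arrays, and `bytemuck::must_cast` verifies at
// compile time that the sizes match.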
unsafe impl<U> Divisible<U> for ScaledUnderlier<ScaledUnderlier<U, 2>, 2>
where
    Self: UnderlierType + NoUninit,
    U: UnderlierType + Pod,
{
    type Array = [U; 4];

    #[inline]
    fn split_val(self) -> Self::Array {
        bytemuck::must_cast(self)
    }

    #[inline]
    fn split_ref(&self) -> &[U] {
        must_cast_ref::<Self, [U; 4]>(self)
    }

    #[inline]
    fn split_mut(&mut self) -> &mut [U] {
        must_cast_mut::<Self, [U; 4]>(self)
    }
}

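// The bitwise operators below apply element-wise; no bits cross element boundaries.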
impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
    }
}

impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
    fn bitand_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] &= rhs.0[i];
        }
    }
}

impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
    }
}

impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
    fn bitor_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] |= rhs.0[i];
        }
    }
}

impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
    type Output = Self;

    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
    }
}

impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
    fn bitxor_assign(&mut self, rhs: Self) {
        for i in 0..N {
            self.0[i] ^= rhs.0[i];
        }
    }
}

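// Logical shifts treat element 0 as the least significant limb. A shift amount splits
// into a whole-element part (`rhs / U::BITS`) and a bit remainder; for a nonzero
// remainder, each result limb combines bits from two adjacent source limbs.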
impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
    type Output = Self;

    fn shr(self, rhs: usize) -> Self::Output {
        let mut result = Self::default();

        let shift_in_items = rhs / U::BITS;
        for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
            if i + shift_in_items < N {
                result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
            }
            if i + shift_in_items + 1 < N && rhs % U::BITS != 0 {
                result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
            }
        }

        result
    }
}

impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
    type Output = Self;

    fn shl(self, rhs: usize) -> Self::Output {
        let mut result = Self::default();

        let shift_in_items = rhs / U::BITS;
        for i in shift_in_items.saturating_sub(1)..N {
            if i >= shift_in_items {
                result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
            }
            if i > shift_in_items && rhs % U::BITS != 0 {
                result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
            }
        }

        result
    }
}

impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
    type Output = Self;

    fn not(self) -> Self::Output {
        Self(self.0.map(U::not))
    }
}

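// `ZERO`/`ONE`/`ONES` follow the little-endian element convention above: `ONE` sets only
// the lowest element. The 128-bit-lane operations delegate to the elements; since widths
// are powers of two, `U::BITS >= 128` means each element holds a whole number of lanes.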
impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
where
    U: UnderlierWithBitOps + Pod + From<u8>,
    u8: NumCast<U>,
{
    const ZERO: Self = Self([U::ZERO; N]);
    const ONE: Self = {
        let mut arr = [U::ZERO; N];
        arr[0] = U::ONE;
        Self(arr)
    };
    const ONES: Self = Self([U::ONES; N]);

    #[inline]
    fn fill_with_bit(val: u8) -> Self {
        Self(array::from_fn(|_| U::fill_with_bit(val)))
    }

    #[inline]
    fn shl_128b_lanes(self, rhs: usize) -> Self {
        // Lane-wise shifts are valid element-wise only when each element spans at least
        // one full 128-bit lane.
        assert!(U::BITS >= 128);

        Self(self.0.map(|x| x.shl_128b_lanes(rhs)))
    }

    #[inline]
    fn shr_128b_lanes(self, rhs: usize) -> Self {
        // See `shl_128b_lanes` above for why delegation requires 128-bit elements.
        assert!(U::BITS >= 128);

        Self(self.0.map(|x| x.shr_128b_lanes(rhs)))
    }

    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        // Interleaving low halves is likewise element-wise, given whole 128-bit lanes
        // per element.
        assert!(U::BITS >= 128);

        Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len)))
    }

    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        // Same element-wise delegation for the high halves.
        assert!(U::BITS >= 128);

        Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len)))
    }

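    // Byte transposition works in two steps: transpose the bytes within each column of
    // `U` components, then permute whole components across the rows of the tower level.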
    #[inline]
    fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        for col in 0..N {
            let mut column = TL::from_fn(|row| values[row].0[col]);
            U::transpose_bytes_from_byte_sliced::<TL>(&mut column);
            for row in 0..TL::WIDTH {
                values[row].0[col] = column[row];
            }
        }

        let mut result = TL::default::<Self>();
        for row in 0..TL::WIDTH {
            for col in 0..N {
                let index = row * N + col;

                result[row].0[col] = values[index % TL::WIDTH].0[index / TL::WIDTH];
            }
        }

        *values = result;
    }

    #[inline]
    fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        let mut result = TL::from_fn(|row| {
            Self(array::from_fn(|col| {
                let index = row + col * TL::WIDTH;

                values[index / N].0[index % N]
            }))
        });

        for col in 0..N {
            let mut column = TL::from_fn(|row| result[row].0[col]);
            U::transpose_bytes_to_byte_sliced::<TL>(&mut column);
            for row in 0..TL::WIDTH {
                result[row].0[col] = column[row];
            }
        }

        *values = result;
    }
}

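// Casting to `u8` reads the least significant element, consistent with element 0 holding
// the low-order bits; `From<u8>` broadcasts the byte to every element.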
impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
where
    Self: NumCast<U>,
{
    fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
        Self::num_cast_from(val.0[0])
    }
}

impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
where
    U: From<u8>,
{
    fn from(val: u8) -> Self {
        Self(array::from_fn(|_| U::from(val)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shr() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        assert_eq!(
            val >> 1,
            ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
        );
        assert_eq!(
            val >> 2,
            ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
        );
        assert_eq!(
            val >> 8,
            ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
        );
        assert_eq!(
            val >> 9,
            ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
        );
    }
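
    // An extra illustrative case (not part of the original suite): shifting by a
    // multiple of `U::BITS` exercises the whole-element path, and shifting past the
    // total width must produce zero.
    #[test]
    fn test_shift_by_whole_elements() {
        let val = ScaledUnderlier::<u8, 4>([1, 2, 3, 4]);
        assert_eq!(val << 16, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
        assert_eq!(val >> 16, ScaledUnderlier::<u8, 4>([3, 4, 0, 0]));
        assert_eq!(val << 32, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
        assert_eq!(val >> 32, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
    }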

    #[test]
    fn test_shl() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
        assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
        assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
        assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
    }
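
    // Illustrative sketches added for documentation purposes: bitwise operators act
    // element-wise, and a doubly-scaled underlier flattens to four components in order.
    #[test]
    fn test_bit_ops_elementwise() {
        let a = ScaledUnderlier::<u8, 2>([0b1100, 0b1010]);
        let b = ScaledUnderlier::<u8, 2>([0b1010, 0b0110]);
        assert_eq!(a & b, ScaledUnderlier::<u8, 2>([0b1000, 0b0010]));
        assert_eq!(a | b, ScaledUnderlier::<u8, 2>([0b1110, 0b1110]));
        assert_eq!(a ^ b, ScaledUnderlier::<u8, 2>([0b0110, 0b1100]));
    }

    #[test]
    fn test_nested_split() {
        let val = ScaledUnderlier::<ScaledUnderlier<u8, 2>, 2>::from([1u8, 2, 3, 4]);
        // Fully qualified call to disambiguate between the two `Divisible` impls.
        let parts = <ScaledUnderlier<ScaledUnderlier<u8, 2>, 2> as Divisible<u8>>::split_val(val);
        assert_eq!(parts, [1, 2, 3, 4]);
    }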
}