1use std::{
4 array,
5 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9use bytemuck::{Pod, Zeroable};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14
15use super::{NumCast, UnderlierType, UnderlierWithBitOps};
16use crate::Random;
17
/// A wide underlier made of `N` limbs of a narrower underlier type `U`.
///
/// Limbs are little-endian: `self.0[0]` holds the least-significant bits. This is the order
/// assumed by the `ONE` constant and by the shift operators in this module.
///
/// NOTE: the derived `PartialOrd`/`Ord` compare the limb arrays lexicographically starting at
/// index 0, i.e. at the *least*-significant limb, so the ordering is not numeric order.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
23
24impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
25 fn default() -> Self {
26 Self(array::from_fn(|_| U::default()))
27 }
28}
29
30impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
31 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
32 ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
33 }
34}
35
36impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
37 fn from(val: ScaledUnderlier<U, N>) -> Self {
38 val.0
39 }
40}
41
42impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
43 fn from(value: [T; N]) -> Self {
44 Self(value.map(U::from))
45 }
46}
47
48impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
49 fn from(value: [T; 4]) -> Self {
50 Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
51 }
52}
53
54unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
55
56unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
57
58impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
59 const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
60}
61
62impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
63 type Output = Self;
64
65 fn bitand(self, rhs: Self) -> Self::Output {
66 Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
67 }
68}
69
70impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
71 fn bitand_assign(&mut self, rhs: Self) {
72 for i in 0..N {
73 self.0[i] &= rhs.0[i];
74 }
75 }
76}
77
78impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
79 type Output = Self;
80
81 fn bitor(self, rhs: Self) -> Self::Output {
82 Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
83 }
84}
85
86impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
87 fn bitor_assign(&mut self, rhs: Self) {
88 for i in 0..N {
89 self.0[i] |= rhs.0[i];
90 }
91 }
92}
93
94impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
95 type Output = Self;
96
97 fn bitxor(self, rhs: Self) -> Self::Output {
98 Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
99 }
100}
101
102impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
103 fn bitxor_assign(&mut self, rhs: Self) {
104 for i in 0..N {
105 self.0[i] ^= rhs.0[i];
106 }
107 }
108}
109
110impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
111 type Output = Self;
112
113 fn shr(self, rhs: usize) -> Self::Output {
114 let mut result = Self::default();
115
116 let shift_in_items = rhs / U::BITS;
117 for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
118 if i + shift_in_items < N {
119 result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
120 }
121 if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
122 result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
123 }
124 }
125
126 result
127 }
128}
129
130impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
131 type Output = Self;
132
133 fn shl(self, rhs: usize) -> Self::Output {
134 let mut result = Self::default();
135
136 let shift_in_items = rhs / U::BITS;
137 for i in shift_in_items.saturating_sub(1)..N {
138 if i >= shift_in_items {
139 result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
140 }
141 if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
142 result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
143 }
144 }
145
146 result
147 }
148}
149
150impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
151 type Output = Self;
152
153 fn not(self) -> Self::Output {
154 Self(self.0.map(U::not))
155 }
156}
157
158impl<U, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N>
159where
160 U: UnderlierWithBitOps + Pod + From<u8>,
161 u8: NumCast<U>,
162{
163 const ZERO: Self = Self([U::ZERO; N]);
164 const ONE: Self = {
165 let mut arr = [U::ZERO; N];
166 arr[0] = U::ONE;
167 Self(arr)
168 };
169 const ONES: Self = Self([U::ONES; N]);
170
171 #[inline]
172 fn fill_with_bit(val: u8) -> Self {
173 Self(array::from_fn(|_| U::fill_with_bit(val)))
174 }
175}
176
177impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
178where
179 Self: NumCast<U>,
180{
181 fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
182 Self::num_cast_from(val.0[0])
183 }
184}
185
186impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
187where
188 U: From<u8>,
189{
190 fn from(val: u8) -> Self {
191 Self(array::from_fn(|_| U::from(val)))
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Interprets a 4-limb `u8` underlier as a little-endian `u32` (limb 0 is least significant),
    /// giving a primitive reference to check the multi-limb shifts against.
    fn as_u32(val: ScaledUnderlier<u8, 4>) -> u32 {
        u32::from_le_bytes(val.0)
    }

    #[test]
    fn test_shr() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        assert_eq!(
            val >> 1,
            ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
        );
        assert_eq!(
            val >> 2,
            ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
        );
        assert_eq!(
            val >> 8,
            ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
        );
        assert_eq!(
            val >> 9,
            ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
        );
    }

    #[test]
    fn test_shr_boundaries() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        // A zero shift is the identity.
        assert_eq!(val >> 0, val);
        // One whole limb plus four bits.
        assert_eq!(val >> 12, ScaledUnderlier::<u8, 4>([0x20, 0x30, 0x00, 0x00]));
        // Shifting by the full width (or more) clears everything.
        assert_eq!(val >> 31, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
        assert_eq!(val >> 32, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
        assert_eq!(val >> 100, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
    }

    #[test]
    fn test_shl() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
        assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
        assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
        assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
    }

    #[test]
    fn test_shl_boundaries() {
        let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
        // A zero shift is the identity.
        assert_eq!(val << 0, val);
        // One whole limb plus four bits.
        assert_eq!(val << 12, ScaledUnderlier::<u8, 4>([0x00, 0x00, 0x10, 0x20]));
        // Exactly two limbs: only limb 1 survives, landing in the top limb.
        assert_eq!(val << 16, ScaledUnderlier::<u8, 4>([0, 0, 0, 1]));
        // Shifting by the full width (or more) clears everything.
        assert_eq!(val << 32, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
        assert_eq!(val << 100, ScaledUnderlier::<u8, 4>([0, 0, 0, 0]));
    }

    #[test]
    fn test_shifts_match_u32_reference() {
        let samples = [
            ScaledUnderlier::<u8, 4>([0, 1, 2, 3]),
            ScaledUnderlier::<u8, 4>([0xC3, 0xD2, 0xE1, 0xF0]),
        ];
        for val in samples {
            let wide = as_u32(val);
            for rhs in 0..40usize {
                // `checked_sh*` returns `None` exactly when `rhs >= 32`, where the wide shift
                // must yield zero.
                assert_eq!(
                    as_u32(val >> rhs),
                    wide.checked_shr(rhs as u32).unwrap_or(0),
                    "shr by {rhs}"
                );
                assert_eq!(
                    as_u32(val << rhs),
                    wide.checked_shl(rhs as u32).unwrap_or(0),
                    "shl by {rhs}"
                );
            }
        }
    }
}