use std::{
	array, mem,
	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::{
	DeserializeBytes, SerializationError, SerializeBytes,
	bytes::{Buf, BufMut},
	checked_arithmetics::checked_log_2,
};
use bytemuck::{Pod, Zeroable};
use rand::{
	Rng,
	distr::{Distribution, StandardUniform},
};

use super::{Divisible, NumCast, UnderlierType, UnderlierWithBitOps, mapget};
use crate::Random;

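/// An underlier that concatenates `N` underliers of type `U` into a single
/// value of `N * U::BITS` bits. Element `0` of the array holds the least
/// significant bits.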
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);

impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
	fn default() -> Self {
		Self(array::from_fn(|_| U::default()))
	}
}

impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
		ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
	}
}

impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
	fn from(val: ScaledUnderlier<U, N>) -> Self {
		val.0
	}
}

impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
	fn from(value: [T; N]) -> Self {
		Self(value.map(U::from))
	}
}

impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
	fn from(value: [T; 4]) -> Self {
		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
	}
}

unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}

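// The total bit width is `N * U::BITS`, so the log width adds `log2(N)`;
// `checked_log_2` requires `N` to be a power of two.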
impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
}

impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitand(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
	}
}

impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
	fn bitand_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] &= rhs.0[i];
		}
	}
}

impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
	}
}

impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
	fn bitor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] |= rhs.0[i];
		}
	}
}

impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
	type Output = Self;

	fn bitxor(self, rhs: Self) -> Self::Output {
		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
	}
}

impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
	fn bitxor_assign(&mut self, rhs: Self) {
		for i in 0..N {
			self.0[i] ^= rhs.0[i];
		}
	}
}

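// Logical right shift over the full `N * U::BITS`-bit value: the shift splits
// into a whole-element part (`rhs / U::BITS`) and a bit remainder
// (`rhs % U::BITS`), so each result element combines bits from at most two
// source elements.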
impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shr(self, rhs: usize) -> Self::Output {
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
			if i + shift_in_items < N {
				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
			}
			if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

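// Logical left shift, the mirror image of `Shr` above: elements move toward
// higher indices, and the bit remainder carries bits in from the preceding
// lower element.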
impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
	type Output = Self;

	fn shl(self, rhs: usize) -> Self::Output {
		let mut result = Self::default();

		let shift_in_items = rhs / U::BITS;
		for i in shift_in_items.saturating_sub(1)..N {
			if i >= shift_in_items {
				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
			}
			if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
			}
		}

		result
	}
}

impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
	type Output = Self;

	fn not(self) -> Self::Output {
		Self(self.0.map(U::not))
	}
}

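// For block sizes smaller than a single `U`, interleaving happens element-wise
// inside each `U`. For block sizes of one or more whole `U` elements, it
// reduces to swapping blocks of elements between the two values.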
impl<U: UnderlierWithBitOps + Pod, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N> {
	const ZERO: Self = Self([U::ZERO; N]);
	const ONE: Self = {
		let mut arr = [U::ZERO; N];
		arr[0] = U::ONE;
		Self(arr)
	};
	const ONES: Self = Self([U::ONES; N]);

	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		if log_block_len < U::LOG_BITS {
			let pairs: [(U, U); N] =
				array::from_fn(|i| self.0[i].interleave(other.0[i], log_block_len));
			(Self(array::from_fn(|i| pairs[i].0)), Self(array::from_fn(|i| pairs[i].1)))
		} else {
			let block_len = 1 << (log_block_len - U::LOG_BITS);
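			// `block_len` is the block size measured in `U` elements. Within each
			// super-block of `2 * block_len` elements, swap the upper half of `a`
			// with the lower half of `b`.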
			let mut a = self.0;
			let mut b = other.0;
			for super_block in 0..(N / (2 * block_len)) {
				let base = super_block * 2 * block_len;
				for offset in 0..block_len {
					mem::swap(&mut a[base + block_len + offset], &mut b[base + offset]);
				}
			}

			(Self(a), Self(b))
		}
	}
}

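// A cast to `u8` reads from the lowest element only, consistent with
// truncation on the `N * U::BITS`-bit value.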
impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
where
	Self: NumCast<U>,
{
	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
		Self::num_cast_from(val.0[0])
	}
}

impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
where
	U: From<u8>,
{
	fn from(val: u8) -> Self {
		Self(array::from_fn(|_| U::from(val)))
	}
}

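// Indexing into the `T`-subdivision: the high bits of `index` select the `U`
// element, the low bits select the position within it.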
impl<U, T, const N: usize> Divisible<T> for ScaledUnderlier<U, N>
where
	U: Divisible<T> + Pod + Send + Sync,
	T: Send + 'static,
{
	const LOG_N: usize = <U as Divisible<T>>::LOG_N + checked_log_2(N);

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	#[inline]
	fn get(self, index: usize) -> T {
		let u_index = index >> <U as Divisible<T>>::LOG_N;
		let sub_index = index & (<U as Divisible<T>>::N - 1);
		Divisible::<T>::get(self.0[u_index], sub_index)
	}

	#[inline]
	fn set(self, index: usize, val: T) -> Self {
		let u_index = index >> <U as Divisible<T>>::LOG_N;
		let sub_index = index & (<U as Divisible<T>>::N - 1);
		let mut arr = self.0;
		arr[u_index] = Divisible::<T>::set(arr[u_index], sub_index, val);
		Self(arr)
	}

	#[inline]
	fn broadcast(val: T) -> Self {
		Self([Divisible::<T>::broadcast(val); N])
	}

	#[inline]
	fn from_iter(mut iter: impl Iterator<Item = T>) -> Self {
		Self(array::from_fn(|_| Divisible::<T>::from_iter(&mut iter)))
	}
}

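// (De)serialization simply delegates to the `[U; N]` implementations.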
impl<U: SerializeBytes, const N: usize> SerializeBytes for ScaledUnderlier<U, N> {
	fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
		self.0.serialize(write_buf)
	}
}

impl<U: DeserializeBytes, const N: usize> DeserializeBytes for ScaledUnderlier<U, N> {
	fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError> {
		<[U; N]>::deserialize(read_buf).map(Self)
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}

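	// Roundtrip smoke test for the `SerializeBytes`/`DeserializeBytes` impls
	// above. A minimal sketch: it assumes, as with the `bytes` crate that
	// `binius_utils::bytes` re-exports, that `Vec<u8>` implements `BufMut` and
	// `&[u8]` implements `Buf`.
	#[test]
	fn test_serialization_roundtrip() {
		let val = ScaledUnderlier::<u8, 4>([0x12, 0x34, 0x56, 0x78]);
		let mut buf = Vec::new();
		val.serialize(&mut buf).expect("serialization should succeed");
		let deserialized = ScaledUnderlier::<u8, 4>::deserialize(buf.as_slice())
			.expect("deserialization should succeed");
		assert_eq!(val, deserialized);
	}
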
	#[test]
	fn test_interleave_within_element() {
		let a = ScaledUnderlier::<u8, 4>([0b01010101, 0b11110000, 0b00001111, 0b10101010]);
		let b = ScaledUnderlier::<u8, 4>([0b10101010, 0b00001111, 0b11110000, 0b01010101]);

		let (c, d) = a.interleave(b, 0);

		for i in 0..4 {
			let (expected_c, expected_d) = a.0[i].interleave(b.0[i], 0);
			assert_eq!(c.0[i], expected_c);
			assert_eq!(d.0[i], expected_d);
		}
	}

	#[test]
	fn test_interleave_across_elements() {
		let a = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let b = ScaledUnderlier::<u8, 4>([4, 5, 6, 7]);

		let (c, d) = a.interleave(b, 3);
		assert_eq!(c.0, [0, 4, 2, 6]);
		assert_eq!(d.0, [1, 5, 3, 7]);

		let (c, d) = a.interleave(b, 4);
		assert_eq!(c.0, [0, 1, 4, 5]);
		assert_eq!(d.0, [2, 3, 6, 7]);
	}
}
343}