1use std::{
5 array,
6 fmt::{self, LowerHex},
7 mem,
8 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
9};
10
11use binius_utils::{
12 DeserializeBytes, SerializationError, SerializeBytes,
13 bytes::{Buf, BufMut},
14 checked_arithmetics::checked_log_2,
15};
16use bytemuck::{Pod, Zeroable};
17use rand::{
18 Rng,
19 distr::{Distribution, StandardUniform},
20};
21
22use super::{Divisible, NumCast, U1, UnderlierType, mapget};
23use crate::Random;
24
/// A wide underlier assembled from `N` limbs of a narrower underlier `U`.
///
/// The limbs live in a plain array. The arithmetic in this file treats limb 0 as the
/// least significant limb. `#[repr(transparent)]` gives the wrapper exactly the
/// in-memory layout of `[U; N]`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
30
31impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
32 fn default() -> Self {
33 Self(array::from_fn(|_| U::default()))
34 }
35}
36
37impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
38 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
39 ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
40 }
41}
42
43impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
44 fn from(val: ScaledUnderlier<U, N>) -> Self {
45 val.0
46 }
47}
48
49impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
50 fn from(value: [T; N]) -> Self {
51 Self(value.map(U::from))
52 }
53}
54
55impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
56 fn from(value: [T; 4]) -> Self {
57 Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
58 }
59}
60
61unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
62unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
63
64impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
65 type Output = Self;
66
67 fn bitand(self, rhs: Self) -> Self::Output {
68 Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
69 }
70}
71
72impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
73 fn bitand_assign(&mut self, rhs: Self) {
74 for i in 0..N {
75 self.0[i] &= rhs.0[i];
76 }
77 }
78}
79
80impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
81 type Output = Self;
82
83 fn bitor(self, rhs: Self) -> Self::Output {
84 Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
85 }
86}
87
88impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
89 fn bitor_assign(&mut self, rhs: Self) {
90 for i in 0..N {
91 self.0[i] |= rhs.0[i];
92 }
93 }
94}
95
96impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
97 type Output = Self;
98
99 fn bitxor(self, rhs: Self) -> Self::Output {
100 Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
101 }
102}
103
104impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
105 fn bitxor_assign(&mut self, rhs: Self) {
106 for i in 0..N {
107 self.0[i] ^= rhs.0[i];
108 }
109 }
110}
111
112impl<U: UnderlierType, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
113 type Output = Self;
114
115 fn shr(self, rhs: usize) -> Self::Output {
116 let mut result = Self::default();
117
118 let shift_in_items = rhs / U::BITS;
119 for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
120 if i + shift_in_items < N {
121 result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
122 }
123 if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
124 result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
125 }
126 }
127
128 result
129 }
130}
131
132impl<U: UnderlierType, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
133 type Output = Self;
134
135 fn shl(self, rhs: usize) -> Self::Output {
136 let mut result = Self::default();
137
138 let shift_in_items = rhs / U::BITS;
139 for i in shift_in_items.saturating_sub(1)..N {
140 if i >= shift_in_items {
141 result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
142 }
143 if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
144 result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
145 }
146 }
147
148 result
149 }
150}
151
152impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
153 type Output = Self;
154
155 fn not(self) -> Self::Output {
156 Self(self.0.map(U::not))
157 }
158}
159
160impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
161 const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
162
163 const ZERO: Self = Self([U::ZERO; N]);
164 const ONE: Self = {
165 let mut arr = [U::ZERO; N];
166 arr[0] = U::ONE;
167 Self(arr)
168 };
169 const ONES: Self = Self([U::ONES; N]);
170
171 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
172 if log_block_len < U::LOG_BITS {
173 let pairs: [(U, U); N] =
175 array::from_fn(|i| self.0[i].interleave(other.0[i], log_block_len));
176 (Self(array::from_fn(|i| pairs[i].0)), Self(array::from_fn(|i| pairs[i].1)))
177 } else {
178 let block_len = 1 << (log_block_len - U::LOG_BITS);
181
182 let mut a = self.0;
183 let mut b = other.0;
184 for super_block in 0..(N / (2 * block_len)) {
185 let base = super_block * 2 * block_len;
186 for offset in 0..block_len {
187 mem::swap(&mut a[base + block_len + offset], &mut b[base + offset]);
188 }
189 }
190
191 (Self(a), Self(b))
192 }
193 }
194}
195
196impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
197where
198 Self: NumCast<U>,
199{
200 fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
201 Self::num_cast_from(val.0[0])
202 }
203}
204
205impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for U1
206where
207 Self: NumCast<U>,
208{
209 fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
210 Self::num_cast_from(val.0[0])
211 }
212}
213
214impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for crate::arch::M128
217where
218 Self: NumCast<U>,
219{
220 fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
221 Self::num_cast_from(val.0[0])
222 }
223}
224
225impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
226where
227 U: From<u8>,
228{
229 fn from(val: u8) -> Self {
230 Self(array::from_fn(|_| U::from(val)))
231 }
232}
233
234impl<const N: usize> From<crate::arch::M128> for ScaledUnderlier<crate::arch::M128, N> {
240 fn from(val: crate::arch::M128) -> Self {
241 let mut limbs = [<crate::arch::M128 as UnderlierType>::ZERO; N];
242 limbs[0] = val;
243 Self(limbs)
244 }
245}
246
247impl<const N: usize> From<U1> for ScaledUnderlier<crate::arch::M128, N> {
249 fn from(val: U1) -> Self {
250 Self::from(crate::arch::M128::from(val))
251 }
252}
253
254impl<U: UnderlierType + LowerHex, const N: usize> LowerHex for ScaledUnderlier<U, N> {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 let width = U::BITS / 4;
259 let top = self
260 .0
261 .iter()
262 .rposition(|limb| *limb != U::ZERO)
263 .unwrap_or(0);
264 write!(f, "{:x}", self.0[top])?;
265 for limb in self.0[..top].iter().rev() {
266 write!(f, "{limb:0width$x}")?;
267 }
268 Ok(())
269 }
270}
271
272impl<U, T, const N: usize> Divisible<T> for ScaledUnderlier<U, N>
273where
274 U: Divisible<T> + Pod + Send + Sync,
275 T: Send + 'static,
276{
277 const LOG_N: usize = <U as Divisible<T>>::LOG_N + checked_log_2(N);
278
279 #[inline]
280 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone {
281 mapget::value_iter(value)
282 }
283
284 #[inline]
285 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
286 mapget::value_iter(*value)
287 }
288
289 #[inline]
290 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
291 mapget::slice_iter(slice)
292 }
293
294 #[inline]
295 unsafe fn get_unchecked(&self, index: usize) -> T {
296 let u_index = index >> <U as Divisible<T>>::LOG_N;
297 let sub_index = index & (<U as Divisible<T>>::N - 1);
298 unsafe { Divisible::<T>::get_unchecked(self.0.get_unchecked(u_index), sub_index) }
301 }
302
303 #[inline]
304 unsafe fn set_unchecked(&mut self, index: usize, val: T) {
305 let u_index = index >> <U as Divisible<T>>::LOG_N;
306 let sub_index = index & (<U as Divisible<T>>::N - 1);
307 unsafe { Divisible::<T>::set_unchecked(self.0.get_unchecked_mut(u_index), sub_index, val) };
310 }
311
312 #[inline]
313 fn broadcast(val: T) -> Self {
314 Self([Divisible::<T>::broadcast(val); N])
315 }
316
317 #[inline]
318 fn from_iter(mut iter: impl Iterator<Item = T>) -> Self {
319 Self(array::from_fn(|_| Divisible::<T>::from_iter(&mut iter)))
320 }
321}
322
323impl<U: SerializeBytes, const N: usize> SerializeBytes for ScaledUnderlier<U, N> {
324 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
325 self.0.serialize(write_buf)
326 }
327}
328
329impl<U: DeserializeBytes, const N: usize> DeserializeBytes for ScaledUnderlier<U, N> {
330 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError> {
331 <[U; N]>::deserialize(read_buf).map(Self)
332 }
333}
334
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}

	/// Shifting by zero must leave every limb untouched.
	#[test]
	fn test_shift_by_zero_is_identity() {
		let val = ScaledUnderlier::<u8, 4>([0x12, 0x34, 0x56, 0x78]);
		assert_eq!(val >> 0, val);
		assert_eq!(val << 0, val);
	}

	/// A partial-limb shift must carry bits across the limb boundary in both directions.
	#[test]
	fn test_shift_carries_across_limbs() {
		let val = ScaledUnderlier::<u8, 4>([0xFF, 0, 0, 0]);
		let shifted = val << 4;
		assert_eq!(shifted, ScaledUnderlier::<u8, 4>([0xF0, 0x0F, 0, 0]));
		assert_eq!(shifted >> 4, val);
	}

	/// Shifting by the full width (4 * 8 bits) or more clears the value.
	#[test]
	fn test_shift_by_full_width_or_more_is_zero() {
		let val = ScaledUnderlier::<u8, 4>([0xFF; 4]);
		let zero = ScaledUnderlier::<u8, 4>([0; 4]);
		for shift in [32, 33, 40, 1000] {
			assert_eq!(val >> shift, zero);
			assert_eq!(val << shift, zero);
		}
	}

	/// The largest in-range shift moves the lowest bit to the top and back.
	#[test]
	fn test_shift_by_width_minus_one() {
		let low = ScaledUnderlier::<u8, 4>([1, 0, 0, 0]);
		let high = ScaledUnderlier::<u8, 4>([0, 0, 0, 0x80]);
		assert_eq!(low << 31, high);
		assert_eq!(high >> 31, low);
	}

	/// Hex output is big-endian across limbs, with lower limbs zero-padded.
	#[test]
	fn test_lower_hex() {
		assert_eq!(format!("{:x}", ScaledUnderlier::<u8, 4>([0x0f, 0x01, 0, 0])), "10f");
		assert_eq!(format!("{:x}", ScaledUnderlier::<u8, 4>([0; 4])), "0");
		assert_eq!(format!("{:x}", ScaledUnderlier::<u8, 4>([0xab, 0, 0, 0x01])), "10000ab");
	}

	#[test]
	fn test_interleave_within_element() {
		let a = ScaledUnderlier::<u8, 4>([0b01010101, 0b11110000, 0b00001111, 0b10101010]);
		let b = ScaledUnderlier::<u8, 4>([0b10101010, 0b00001111, 0b11110000, 0b01010101]);

		let (c, d) = a.interleave(b, 0);

		for i in 0..4 {
			let (expected_c, expected_d) = a.0[i].interleave(b.0[i], 0);
			assert_eq!(c.0[i], expected_c);
			assert_eq!(d.0[i], expected_d);
		}
	}

	#[test]
	fn test_interleave_across_elements() {
		let a = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let b = ScaledUnderlier::<u8, 4>([4, 5, 6, 7]);

		let (c, d) = a.interleave(b, 3);
		assert_eq!(c.0, [0, 4, 2, 6]);
		assert_eq!(d.0, [1, 5, 3, 7]);

		let (c, d) = a.interleave(b, 4);
		assert_eq!(c.0, [0, 1, 4, 5]);
		assert_eq!(d.0, [2, 3, 6, 7]);
	}
}