1use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::checked_int_div;
6
7use super::{
8 U1, U2, U4,
9 underlier_type::{NumCast, UnderlierType},
10};
11use crate::Divisible;
12
13pub trait UnderlierWithBitOps:
15 UnderlierType
16 + BitAnd<Self, Output = Self>
17 + BitAndAssign<Self>
18 + BitOr<Self, Output = Self>
19 + BitOrAssign<Self>
20 + BitXor<Self, Output = Self>
21 + BitXorAssign<Self>
22 + Shr<usize, Output = Self>
23 + Shl<usize, Output = Self>
24 + Not<Output = Self>
25 + Divisible<U1>
26{
27 const ZERO: Self;
28 const ONE: Self;
29 const ONES: Self;
30
31 fn fill_with_bit(val: u8) -> Self {
34 Self::broadcast_subvalue(U1::new(val))
35 }
36
37 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
39
40 fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) {
42 assert!(log_block_len < Self::LOG_BITS);
43
44 for log_block_len in (log_block_len..Self::LOG_BITS).rev() {
45 (self, other) = self.interleave(other, log_block_len);
46 }
47
48 (self, other)
49 }
50
51 #[inline]
52 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self
53 where
54 T: UnderlierType,
55 Self: Divisible<T>,
56 {
57 Self::from_iter((0..<Self as Divisible<T>>::N).map(f))
58 }
59
60 #[inline]
63 fn broadcast_subvalue<T>(value: T) -> Self
64 where
65 T: UnderlierType,
66 Self: Divisible<T>,
67 {
68 Divisible::<T>::broadcast(value)
69 }
70
71 #[inline]
77 unsafe fn get_subvalue<T>(&self, i: usize) -> T
78 where
79 T: UnderlierType,
80 Self: Divisible<T>,
81 {
82 debug_assert!(
83 i < checked_int_div(Self::BITS, T::BITS),
84 "i: {} Self::BITS: {}, T::BITS: {}",
85 i,
86 Self::BITS,
87 T::BITS
88 );
89 Divisible::<T>::get(*self, i)
90 }
91
92 #[inline]
98 unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
99 where
100 T: UnderlierWithBitOps,
101 Self: Divisible<T>,
102 {
103 debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
104 *self = (*self).set(i, val);
105 }
106
107 #[inline]
114 unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
115 where
116 T: UnderlierWithBitOps,
117 Self: Divisible<T>,
118 {
119 unsafe { spread_fallback::<Self, T>(self, log_block_len, block_idx) }
120 }
121}
122
123pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
129where
130 U: UnderlierWithBitOps + Divisible<T>,
131 T: UnderlierWithBitOps,
132{
133 debug_assert!(
134 log_block_len + T::LOG_BITS <= U::LOG_BITS,
135 "log_block_len: {}, U::BITS: {}, T::BITS: {}",
136 log_block_len,
137 U::BITS,
138 T::BITS
139 );
140 debug_assert!(
141 block_idx < 1 << (U::LOG_BITS - log_block_len),
142 "block_idx: {}, U::BITS: {}, log_block_len: {}",
143 block_idx,
144 U::BITS,
145 log_block_len
146 );
147
148 let mut result = U::ZERO;
149 let block_offset = block_idx << log_block_len;
150 let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
151 for i in 0..1 << log_block_len {
152 unsafe {
153 result.set_subvalue(i << log_repeat, value.get_subvalue::<T>(block_offset + i));
154 }
155 }
156
157 for i in 0..log_repeat {
158 result |= result << (1 << (T::LOG_BITS + i));
159 }
160
161 result
162}
163
/// Test helper: builds a mask of `bits_count` consecutive set bits, starting from `T::ONE`.
///
/// When `bits_count == T::BITS` the result is `!T::ZERO`. Otherwise the run of ones
/// starting at `T::ONE` is doubled `log2(bits_count)` times. Presumably `bits_count` must be
/// a power of two for `checked_log_2`; confirm in `binius_utils`.
#[cfg(test)]
#[allow(unused)]
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
	use binius_utils::checked_arithmetics::checked_log_2;

	if bits_count == T::BITS {
		return !T::ZERO;
	}

	// Doubling step: a run of `2^height` ones becomes `2^(height + 1)` ones.
	(0..checked_log_2(bits_count)).fold(T::ONE, |mask, height| mask | (mask << (1 << height)))
}
180
/// Expands a sub-byte value into a full `u8` by repeating its bits.
///
/// The impls in this file cover `U1`, `U2` and `U4`.
pub(crate) trait SpreadToByte {
	/// Returns a byte made of repeated copies of `self`'s bits.
	fn spread_to_byte(self) -> u8;
}
185
186impl SpreadToByte for U1 {
187 #[inline(always)]
188 fn spread_to_byte(self) -> u8 {
189 u8::fill_with_bit(self.val())
190 }
191}
192
193impl SpreadToByte for U2 {
194 #[inline(always)]
195 fn spread_to_byte(self) -> u8 {
196 let mut result = self.val();
197 result |= result << 2;
198 result |= result << 4;
199
200 result
201 }
202}
203
204impl SpreadToByte for U4 {
205 #[inline(always)]
206 fn spread_to_byte(self) -> u8 {
207 let mut result = self.val();
208 result |= result << 4;
209
210 result
211 }
212}
213
214#[allow(unused)]
219#[inline(always)]
220pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
221 value: U,
222 block_idx: usize,
223) -> [T; BLOCK_LEN]
224where
225 U: UnderlierWithBitOps + From<T> + Divisible<T>,
226 T: UnderlierType + NumCast<U>,
227{
228 std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
229}
230
231#[allow(unused)]
236#[inline(always)]
237pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
238 value: U,
239 block_idx: usize,
240) -> [u8; BLOCK_LEN]
241where
242 U: UnderlierWithBitOps + From<T> + Divisible<T>,
243 T: UnderlierType + SpreadToByte + NumCast<U>,
244{
245 unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
246 .map(SpreadToByte::spread_to_byte)
247}
248
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	#[test]
	fn test_spread_fallback() {
		unsafe {
			// A single byte block, repeated four times.
			assert_eq!(spread_fallback::<u32, u8>(0x44332211, 0, 2), 0x33333333);
			// A two-byte block: each byte is repeated twice, in order.
			assert_eq!(spread_fallback::<u32, u8>(0x44332211, 1, 1), 0x44443333);
			// The block covers the whole value, so the result is unchanged.
			assert_eq!(spread_fallback::<u32, u8>(0x44332211, 2, 0), 0x44332211);
			// Nibbles 6 and 7, each repeated four times.
			assert_eq!(spread_fallback::<u32, U4>(0x76543210, 1, 3), 0x77776666);
			// Bit 1 is set, so the whole value is filled with ones.
			assert_eq!(spread_fallback::<u32, U1>(0b10, 0, 1), u32::MAX);
		}
	}

	#[test]
	fn test_single_element_mask_bits() {
		assert_eq!(single_element_mask_bits::<u32>(1), 0x1);
		assert_eq!(single_element_mask_bits::<u32>(4), 0xf);
		assert_eq!(single_element_mask_bits::<u32>(8), 0xff);
		assert_eq!(single_element_mask_bits::<u32>(32), u32::MAX);
	}

	#[test]
	fn test_spread_to_byte() {
		assert_eq!(U1::new(0).spread_to_byte(), 0x00);
		assert_eq!(U1::new(1).spread_to_byte(), 0xff);
		assert_eq!(U2::new(1).spread_to_byte(), 0x55);
		assert_eq!(U2::new(2).spread_to_byte(), 0xaa);
		assert_eq!(U4::new(3).spread_to_byte(), 0x33);
		assert_eq!(U4::new(10).spread_to_byte(), 0xaa);
	}

	// The index ranges cover every sub-value slot, up to and including the last one. The
	// value strategies cover every bit of the sub-value width.
	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}
}