1use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8 U1, U2, U4,
9 underlier_type::{NumCast, UnderlierType},
10};
11use crate::Divisible;
12
13pub trait UnderlierWithBitOps:
15 UnderlierType
16 + BitAnd<Self, Output = Self>
17 + BitAndAssign<Self>
18 + BitOr<Self, Output = Self>
19 + BitOrAssign<Self>
20 + BitXor<Self, Output = Self>
21 + BitXorAssign<Self>
22 + Shr<usize, Output = Self>
23 + Shl<usize, Output = Self>
24 + Not<Output = Self>
25{
26 const ZERO: Self;
27 const ONE: Self;
28 const ONES: Self;
29
30 fn fill_with_bit(val: u8) -> Self;
33
34 #[inline]
35 fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
36 where
37 T: UnderlierType,
38 Self: From<T>,
39 {
40 let mut result = Self::default();
43 let width = checked_int_div(Self::BITS, T::BITS);
44 for i in 0..width {
45 result |= Self::from(f(i)) << (i * T::BITS);
46 }
47
48 result
49 }
50
51 #[inline]
54 fn broadcast_subvalue<T>(value: T) -> Self
55 where
56 T: UnderlierType,
57 Self: From<T>,
58 {
59 let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
62 let mut result = Self::from(value);
63 for i in 0..height {
64 result |= result << ((1 << i) * T::BITS);
65 }
66
67 result
68 }
69
70 #[inline]
76 unsafe fn get_subvalue<T>(&self, i: usize) -> T
77 where
78 T: UnderlierType,
79 Self: Divisible<T>,
80 {
81 debug_assert!(
82 i < checked_int_div(Self::BITS, T::BITS),
83 "i: {} Self::BITS: {}, T::BITS: {}",
84 i,
85 Self::BITS,
86 T::BITS
87 );
88 Divisible::<T>::get(*self, i)
89 }
90
91 #[inline]
97 unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
98 where
99 T: UnderlierWithBitOps,
100 Self: Divisible<T>,
101 {
102 debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
103 *self = (*self).set(i, val);
104 }
105
106 #[inline]
113 unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
114 where
115 T: UnderlierWithBitOps + NumCast<Self>,
116 Self: Divisible<T> + From<T>,
117 {
118 unsafe { spread_fallback(self, log_block_len, block_idx) }
119 }
120}
121
122pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
128where
129 U: UnderlierWithBitOps + From<T> + Divisible<T>,
130 T: UnderlierWithBitOps + NumCast<U>,
131{
132 debug_assert!(
133 log_block_len + T::LOG_BITS <= U::LOG_BITS,
134 "log_block_len: {}, U::BITS: {}, T::BITS: {}",
135 log_block_len,
136 U::BITS,
137 T::BITS
138 );
139 debug_assert!(
140 block_idx < 1 << (U::LOG_BITS - log_block_len),
141 "block_idx: {}, U::BITS: {}, log_block_len: {}",
142 block_idx,
143 U::BITS,
144 log_block_len
145 );
146
147 let mut result = U::ZERO;
148 let block_offset = block_idx << log_block_len;
149 let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
150 for i in 0..1 << log_block_len {
151 unsafe {
152 result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
153 }
154 }
155
156 for i in 0..log_repeat {
157 result |= result << (1 << (T::LOG_BITS + i));
158 }
159
160 result
161}
162
#[cfg(test)]
// Test helper; not every test configuration calls it.
#[allow(unused)]
/// Returns a mask with the lowest `bits_count` bits set.
///
/// `bits_count` must be a power of two (enforced by `checked_log_2`) unless it equals
/// `T::BITS`, in which case all bits are set.
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
    if bits_count == T::BITS {
        return !T::ZERO;
    }

    // Start with a single set bit and double the run of ones on every step.
    (0..checked_log_2(bits_count)).fold(T::ONE, |mask, step| mask | (mask << (1 << step)))
}
177
/// Conversion of a narrow value into a byte by replicating its bits.
pub(crate) trait SpreadToByte {
    /// Returns a byte in which `self` is repeated in every slot of its own width.
    fn spread_to_byte(self) -> u8;
}
182
183impl SpreadToByte for U1 {
184 #[inline(always)]
185 fn spread_to_byte(self) -> u8 {
186 u8::fill_with_bit(self.val())
187 }
188}
189
190impl SpreadToByte for U2 {
191 #[inline(always)]
192 fn spread_to_byte(self) -> u8 {
193 let mut result = self.val();
194 result |= result << 2;
195 result |= result << 4;
196
197 result
198 }
199}
200
201impl SpreadToByte for U4 {
202 #[inline(always)]
203 fn spread_to_byte(self) -> u8 {
204 let mut result = self.val();
205 result |= result << 4;
206
207 result
208 }
209}
210
211#[allow(unused)]
216#[inline(always)]
217pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
218 value: U,
219 block_idx: usize,
220) -> [T; BLOCK_LEN]
221where
222 U: UnderlierWithBitOps + From<T> + Divisible<T>,
223 T: UnderlierType + NumCast<U>,
224{
225 std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
226}
227
228#[allow(unused)]
233#[inline(always)]
234pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
235 value: U,
236 block_idx: usize,
237) -> [u8; BLOCK_LEN]
238where
239 U: UnderlierWithBitOps + From<T> + Divisible<T>,
240 T: UnderlierType + SpreadToByte + NumCast<U>,
241{
242 unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
243 .map(SpreadToByte::spread_to_byte)
244}
245
#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, bits, proptest};

    use super::{
        super::small_uint::{U1, U2, U4},
        *,
    };

    #[test]
    fn test_from_fn() {
        assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
        assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

        assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
        assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
        assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

        assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
        assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
        assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
        assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

        assert_eq!(u32::from_fn(|_| 0u8), 0);
        assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
        assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
        assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
    }

    #[test]
    fn test_broadcast_subvalue() {
        assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
        assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
        assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
        assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

        assert_eq!(u32::broadcast_subvalue(0u8), 0);
        assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
        assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
    }

    #[test]
    fn test_get_subvalue() {
        let value = 0xab12cd34u32;

        unsafe {
            assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
            assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

            assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
            assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
            assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
            assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

            assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
            assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
            assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
            assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

            assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
            assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
            assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
            assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
        }
    }

    // The index ranges cover every slot of a `u32` (32, 16, 8 and 4 slots respectively),
    // including the most significant one. The value strategies cover every representable
    // value of the sub-value type.
    proptest! {
        #[test]
        fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
            unsafe {
                init_val.set_subvalue(i, U1::new(val));
                assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
            }
        }

        #[test]
        fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
            unsafe {
                init_val.set_subvalue(i, U2::new(val));
                assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
            }
        }

        #[test]
        fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
            unsafe {
                init_val.set_subvalue(i, U4::new(val));
                assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
            }
        }

        #[test]
        fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
            unsafe {
                init_val.set_subvalue(i, val);
                assert_eq!(init_val.get_subvalue::<u8>(i), val);
            }
        }
    }
}