1use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8 underlier_type::{NumCast, UnderlierType},
9 U1, U2, U4,
10};
11
12pub trait UnderlierWithBitOps:
14 UnderlierType
15 + BitAnd<Self, Output = Self>
16 + BitAndAssign<Self>
17 + BitOr<Self, Output = Self>
18 + BitOrAssign<Self>
19 + BitXor<Self, Output = Self>
20 + BitXorAssign<Self>
21 + Shr<usize, Output = Self>
22 + Shl<usize, Output = Self>
23 + Not<Output = Self>
24{
25 const ZERO: Self;
26 const ONE: Self;
27 const ONES: Self;
28
29 fn fill_with_bit(val: u8) -> Self;
32
33 #[inline]
34 fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
35 where
36 T: UnderlierType,
37 Self: From<T>,
38 {
39 let mut result = Self::default();
42 let width = checked_int_div(Self::BITS, T::BITS);
43 for i in 0..width {
44 result |= Self::from(f(i)) << (i * T::BITS);
45 }
46
47 result
48 }
49
50 #[inline]
53 fn broadcast_subvalue<T>(value: T) -> Self
54 where
55 T: UnderlierType,
56 Self: From<T>,
57 {
58 let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
61 let mut result = Self::from(value);
62 for i in 0..height {
63 result |= result << ((1 << i) * T::BITS);
64 }
65
66 result
67 }
68
69 #[inline]
75 unsafe fn get_subvalue<T>(&self, i: usize) -> T
76 where
77 T: UnderlierType + NumCast<Self>,
78 {
79 debug_assert!(
80 i < checked_int_div(Self::BITS, T::BITS),
81 "i: {} Self::BITS: {}, T::BITS: {}",
82 i,
83 Self::BITS,
84 T::BITS
85 );
86 T::num_cast_from(*self >> (i * T::BITS))
87 }
88
89 #[inline]
95 unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
96 where
97 T: UnderlierWithBitOps,
98 Self: From<T>,
99 {
100 debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
101 let mask = Self::from(single_element_mask::<T>());
102
103 *self &= !(mask << (i * T::BITS));
104 *self |= Self::from(val) << (i * T::BITS);
105 }
106
107 #[inline]
114 unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
115 where
116 T: UnderlierWithBitOps + NumCast<Self>,
117 Self: From<T>,
118 {
119 spread_fallback(self, log_block_len, block_idx)
120 }
121
122 fn shl_128b_lanes(self, shift: usize) -> Self;
125
126 fn shr_128b_lanes(self, shift: usize) -> Self;
129
130 fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
139 unpack_lo_128b_fallback(self, other, log_block_len)
140 }
141
142 fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
151 unpack_hi_128b_fallback(self, other, log_block_len)
152 }
153}
154
155fn single_element_mask<T>() -> T
159where
160 T: UnderlierWithBitOps,
161{
162 single_element_mask_bits(T::BITS)
163}
164
165pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
171where
172 U: UnderlierWithBitOps + From<T>,
173 T: UnderlierWithBitOps + NumCast<U>,
174{
175 debug_assert!(
176 log_block_len + T::LOG_BITS <= U::LOG_BITS,
177 "log_block_len: {}, U::BITS: {}, T::BITS: {}",
178 log_block_len,
179 U::BITS,
180 T::BITS
181 );
182 debug_assert!(
183 block_idx < 1 << (U::LOG_BITS - log_block_len),
184 "block_idx: {}, U::BITS: {}, log_block_len: {}",
185 block_idx,
186 U::BITS,
187 log_block_len
188 );
189
190 let mut result = U::ZERO;
191 let block_offset = block_idx << log_block_len;
192 let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
193 for i in 0..1 << log_block_len {
194 unsafe {
195 result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
196 }
197 }
198
199 for i in 0..log_repeat {
200 result |= result << (1 << (T::LOG_BITS + i));
201 }
202
203 result
204}
205
206#[inline(always)]
207fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
208 let mut mask = single_element_mask_bits(1 << log_block_len);
209 for i in 1..T::BITS / 128 {
210 mask |= mask << (i * 128);
211 }
212
213 mask
214}
215
216pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
217 lhs: T,
218 rhs: T,
219 log_block_len: usize,
220) -> T {
221 assert!(log_block_len <= 6);
222
223 let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
224
225 let mut result = T::ZERO;
226 for i in 0..1 << (6 - log_block_len) {
227 result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
228 .shl_128b_lanes(i << (log_block_len + 1));
229 result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
230 .shl_128b_lanes((2 * i + 1) << log_block_len);
231 }
232
233 result
234}
235
236pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
237 lhs: T,
238 rhs: T,
239 log_block_len: usize,
240) -> T {
241 assert!(log_block_len <= 6);
242
243 let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
244 let mut result = T::ZERO;
245 for i in 0..1 << (6 - log_block_len) {
246 result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
247 .shl_128b_lanes(i << (log_block_len + 1));
248 result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
249 .shl_128b_lanes((2 * i + 1) << log_block_len);
250 }
251
252 result
253}
254
255pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
256 if bits_count == T::BITS {
257 !T::ZERO
258 } else {
259 let mut result = T::ONE;
260 for height in 0..checked_log_2(bits_count) {
261 result |= result << (1 << height)
262 }
263
264 result
265 }
266}
267
/// Conversion of a sub-byte value into a full byte.
///
/// The implementations in this file (for `U1`, `U2` and `U4`) repeat the bits of the
/// small value until all eight bits of the byte are filled.
pub(crate) trait SpreadToByte {
    /// Consumes the value and returns the byte produced from it.
    fn spread_to_byte(self) -> u8;
}
272
273impl SpreadToByte for U1 {
274 #[inline(always)]
275 fn spread_to_byte(self) -> u8 {
276 u8::fill_with_bit(self.val())
277 }
278}
279
280impl SpreadToByte for U2 {
281 #[inline(always)]
282 fn spread_to_byte(self) -> u8 {
283 let mut result = self.val();
284 result |= result << 2;
285 result |= result << 4;
286
287 result
288 }
289}
290
291impl SpreadToByte for U4 {
292 #[inline(always)]
293 fn spread_to_byte(self) -> u8 {
294 let mut result = self.val();
295 result |= result << 4;
296
297 result
298 }
299}
300
301#[allow(unused)]
306#[inline(always)]
307pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
308 value: U,
309 block_idx: usize,
310) -> [T; BLOCK_LEN]
311where
312 U: UnderlierWithBitOps + From<T>,
313 T: UnderlierType + NumCast<U>,
314{
315 std::array::from_fn(|i| value.get_subvalue::<T>(block_idx * BLOCK_LEN + i))
316}
317
318#[allow(unused)]
323#[inline(always)]
324pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
325 value: U,
326 block_idx: usize,
327) -> [u8; BLOCK_LEN]
328where
329 U: UnderlierWithBitOps + From<T>,
330 T: UnderlierType + SpreadToByte + NumCast<U>,
331{
332 get_block_values::<U, T, BLOCK_LEN>(value, block_idx).map(SpreadToByte::spread_to_byte)
333}
334
#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, bits, proptest};

    use super::{
        super::small_uint::{U1, U2, U4},
        *,
    };

    /// `from_fn` must place the value produced for index `i` into lane `i`.
    #[test]
    fn test_from_fn() {
        assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
        assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

        assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
        assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
        assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

        assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
        assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
        assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
        assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

        assert_eq!(u32::from_fn(|_| 0u8), 0);
        assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
        assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
        assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
    }

    /// `broadcast_subvalue` must repeat the sub-value into every lane.
    #[test]
    fn test_broadcast_subvalue() {
        assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
        assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
        assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
        assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

        assert_eq!(u32::broadcast_subvalue(0u8), 0);
        assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
        assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
    }

    /// `get_subvalue` must read lanes counted from the least significant bits.
    #[test]
    fn test_get_subvalue() {
        let value = 0xab12cd34u32;

        unsafe {
            assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
            assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

            assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
            assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
            assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
            assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

            assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
            assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
            assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
            assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

            assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
            assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
            assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
            assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
        }
    }

    // The index ranges below are half-open, so their upper bound is the lane count itself
    // (32, 16, 8 and 4 lanes in a `u32`); this way the topmost lane is exercised too.
    // The value strategies cover the full range of each sub-value type.
    proptest! {
        #[test]
        fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, U1::new(val));
                assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
            }

            // All other lanes must be left untouched.
            let lane = 1u32 << i;
            assert_eq!(init_val & !lane, before & !lane);
        }

        #[test]
        fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, U2::new(val));
                assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
            }

            // All other lanes must be left untouched.
            let lane = 0x3u32 << (i * 2);
            assert_eq!(init_val & !lane, before & !lane);
        }

        #[test]
        fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(0xf)) {
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, U4::new(val));
                assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
            }

            // All other lanes must be left untouched.
            let lane = 0xfu32 << (i * 4);
            assert_eq!(init_val & !lane, before & !lane);
        }

        #[test]
        fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, val);
                assert_eq!(init_val.get_subvalue::<u8>(i), val);
            }

            // All other lanes must be left untouched.
            let lane = 0xffu32 << (i * 8);
            assert_eq!(init_val & !lane, before & !lane);
        }
    }
}