1use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8 U1, U2, U4,
9 underlier_type::{NumCast, UnderlierType},
10};
11
12pub trait UnderlierWithBitOps:
14 UnderlierType
15 + BitAnd<Self, Output = Self>
16 + BitAndAssign<Self>
17 + BitOr<Self, Output = Self>
18 + BitOrAssign<Self>
19 + BitXor<Self, Output = Self>
20 + BitXorAssign<Self>
21 + Shr<usize, Output = Self>
22 + Shl<usize, Output = Self>
23 + Not<Output = Self>
24{
25 const ZERO: Self;
26 const ONE: Self;
27 const ONES: Self;
28
29 fn fill_with_bit(val: u8) -> Self;
32
33 #[inline]
34 fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
35 where
36 T: UnderlierType,
37 Self: From<T>,
38 {
39 let mut result = Self::default();
42 let width = checked_int_div(Self::BITS, T::BITS);
43 for i in 0..width {
44 result |= Self::from(f(i)) << (i * T::BITS);
45 }
46
47 result
48 }
49
50 #[inline]
53 fn broadcast_subvalue<T>(value: T) -> Self
54 where
55 T: UnderlierType,
56 Self: From<T>,
57 {
58 let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
61 let mut result = Self::from(value);
62 for i in 0..height {
63 result |= result << ((1 << i) * T::BITS);
64 }
65
66 result
67 }
68
69 #[inline]
75 unsafe fn get_subvalue<T>(&self, i: usize) -> T
76 where
77 T: UnderlierType + NumCast<Self>,
78 {
79 debug_assert!(
80 i < checked_int_div(Self::BITS, T::BITS),
81 "i: {} Self::BITS: {}, T::BITS: {}",
82 i,
83 Self::BITS,
84 T::BITS
85 );
86 T::num_cast_from(*self >> (i * T::BITS))
87 }
88
89 #[inline]
95 unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
96 where
97 T: UnderlierWithBitOps,
98 Self: From<T>,
99 {
100 debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
101 let mask = Self::from(single_element_mask::<T>());
102
103 *self &= !(mask << (i * T::BITS));
104 *self |= Self::from(val) << (i * T::BITS);
105 }
106
107 #[inline]
114 unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
115 where
116 T: UnderlierWithBitOps + NumCast<Self>,
117 Self: From<T>,
118 {
119 unsafe { spread_fallback(self, log_block_len, block_idx) }
120 }
121
122 fn shl_128b_lanes(self, shift: usize) -> Self;
125
126 fn shr_128b_lanes(self, shift: usize) -> Self;
129
130 fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
140 unpack_lo_128b_fallback(self, other, log_block_len)
141 }
142
143 fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
153 unpack_hi_128b_fallback(self, other, log_block_len)
154 }
155}
156
157fn single_element_mask<T>() -> T
161where
162 T: UnderlierWithBitOps,
163{
164 single_element_mask_bits(T::BITS)
165}
166
167#[allow(dead_code)]
169#[inline(always)]
170pub(crate) fn pair_unpack_lo_hi_128b_lanes<U: UnderlierWithBitOps>(
171 values: &mut impl AsMut<[U]>,
172 i: usize,
173 j: usize,
174 log_block_len: usize,
175) {
176 let values = values.as_mut();
177
178 (values[i], values[j]) = (
179 values[i].unpack_lo_128b_lanes(values[j], log_block_len),
180 values[i].unpack_hi_128b_lanes(values[j], log_block_len),
181 );
182}
183
184pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
190where
191 U: UnderlierWithBitOps + From<T>,
192 T: UnderlierWithBitOps + NumCast<U>,
193{
194 debug_assert!(
195 log_block_len + T::LOG_BITS <= U::LOG_BITS,
196 "log_block_len: {}, U::BITS: {}, T::BITS: {}",
197 log_block_len,
198 U::BITS,
199 T::BITS
200 );
201 debug_assert!(
202 block_idx < 1 << (U::LOG_BITS - log_block_len),
203 "block_idx: {}, U::BITS: {}, log_block_len: {}",
204 block_idx,
205 U::BITS,
206 log_block_len
207 );
208
209 let mut result = U::ZERO;
210 let block_offset = block_idx << log_block_len;
211 let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
212 for i in 0..1 << log_block_len {
213 unsafe {
214 result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
215 }
216 }
217
218 for i in 0..log_repeat {
219 result |= result << (1 << (T::LOG_BITS + i));
220 }
221
222 result
223}
224
225#[inline(always)]
226fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
227 let mut mask = single_element_mask_bits(1 << log_block_len);
228 for i in 1..T::BITS / 128 {
229 mask |= mask << (i * 128);
230 }
231
232 mask
233}
234
235pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
236 lhs: T,
237 rhs: T,
238 log_block_len: usize,
239) -> T {
240 assert!(log_block_len <= 6);
241
242 let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
243
244 let mut result = T::ZERO;
245 for i in 0..1 << (6 - log_block_len) {
246 result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
247 .shl_128b_lanes(i << (log_block_len + 1));
248 result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
249 .shl_128b_lanes((2 * i + 1) << log_block_len);
250 }
251
252 result
253}
254
255pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
256 lhs: T,
257 rhs: T,
258 log_block_len: usize,
259) -> T {
260 assert!(log_block_len <= 6);
261
262 let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
263 let mut result = T::ZERO;
264 for i in 0..1 << (6 - log_block_len) {
265 result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
266 .shl_128b_lanes(i << (log_block_len + 1));
267 result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
268 .shl_128b_lanes((2 * i + 1) << log_block_len);
269 }
270
271 result
272}
273
274pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
275 if bits_count == T::BITS {
276 !T::ZERO
277 } else {
278 let mut result = T::ONE;
279 for height in 0..checked_log_2(bits_count) {
280 result |= result << (1 << height)
281 }
282
283 result
284 }
285}
286
/// Expansion of a sub-byte value into a full byte.
///
/// The implementations in this module fill the byte with copies of the value: a 1-bit value
/// is passed to `u8::fill_with_bit`, a 2-bit value is repeated four times and a 4-bit value
/// twice.
pub(crate) trait SpreadToByte {
    /// Consumes the value and returns the byte built from it.
    fn spread_to_byte(self) -> u8;
}
291
292impl SpreadToByte for U1 {
293 #[inline(always)]
294 fn spread_to_byte(self) -> u8 {
295 u8::fill_with_bit(self.val())
296 }
297}
298
299impl SpreadToByte for U2 {
300 #[inline(always)]
301 fn spread_to_byte(self) -> u8 {
302 let mut result = self.val();
303 result |= result << 2;
304 result |= result << 4;
305
306 result
307 }
308}
309
310impl SpreadToByte for U4 {
311 #[inline(always)]
312 fn spread_to_byte(self) -> u8 {
313 let mut result = self.val();
314 result |= result << 4;
315
316 result
317 }
318}
319
320#[allow(unused)]
325#[inline(always)]
326pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
327 value: U,
328 block_idx: usize,
329) -> [T; BLOCK_LEN]
330where
331 U: UnderlierWithBitOps + From<T>,
332 T: UnderlierType + NumCast<U>,
333{
334 std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
335}
336
337#[allow(unused)]
342#[inline(always)]
343pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
344 value: U,
345 block_idx: usize,
346) -> [u8; BLOCK_LEN]
347where
348 U: UnderlierWithBitOps + From<T>,
349 T: UnderlierType + SpreadToByte + NumCast<U>,
350{
351 unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
352 .map(SpreadToByte::spread_to_byte)
353}
354
#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, bits, proptest};

    use super::{
        super::small_uint::{U1, U2, U4},
        *,
    };

    /// `from_fn` must place the value produced for index `i` into sub-value `i`.
    #[test]
    fn test_from_fn() {
        assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
        assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

        assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
        assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
        assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

        assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
        assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
        // NOTE(review): 31 does not fit into 4 bits; the expected value matches its low
        // nibble, so `U4::new` presumably truncates its argument.
        assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
        assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

        assert_eq!(u32::from_fn(|_| 0u8), 0);
        assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
        assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
        assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
    }

    /// `broadcast_subvalue` must repeat the sub-value across the whole word.
    #[test]
    fn test_broadcast_subvalue() {
        assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
        assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
        assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
        // NOTE(review): see `test_from_fn` — 31 is presumably truncated to 15 by `U4::new`.
        assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

        assert_eq!(u32::broadcast_subvalue(0u8), 0);
        assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
        assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
    }

    /// `get_subvalue` must read the sub-value at the given index for every supported width.
    #[test]
    fn test_get_subvalue() {
        let value = 0xab12cd34u32;

        unsafe {
            assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
            assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

            assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
            assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
            assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
            assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

            assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
            assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
            assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
            assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

            assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
            assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
            assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
            assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
        }
    }

    // Round-trip properties for `set_subvalue`. The index ranges cover every lane of a `u32`
    // (including the last one), the generated values cover every bit of the lane, and the
    // bits outside the written lane are checked to stay untouched.
    proptest! {
        #[test]
        fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
            let outside_lane = !(1u32 << i);
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, U1::new(val));
                assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
            }
            assert_eq!(init_val & outside_lane, before & outside_lane);
        }

        #[test]
        fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
            let outside_lane = !(0b11u32 << (2 * i));
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, U2::new(val));
                assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
            }
            assert_eq!(init_val & outside_lane, before & outside_lane);
        }

        #[test]
        fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
            let outside_lane = !(0xfu32 << (4 * i));
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, U4::new(val));
                assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
            }
            assert_eq!(init_val & outside_lane, before & outside_lane);
        }

        #[test]
        fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
            let outside_lane = !(0xffu32 << (8 * i));
            let before = init_val;
            unsafe {
                init_val.set_subvalue(i, val);
                assert_eq!(init_val.get_subvalue::<u8>(i), val);
            }
            assert_eq!(init_val & outside_lane, before & outside_lane);
        }
    }
}