1use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8 U1, U2, U4,
9 underlier_type::{NumCast, UnderlierType},
10};
11use crate::tower_levels::TowerLevel;
12
13pub trait UnderlierWithBitOps:
15 UnderlierType
16 + BitAnd<Self, Output = Self>
17 + BitAndAssign<Self>
18 + BitOr<Self, Output = Self>
19 + BitOrAssign<Self>
20 + BitXor<Self, Output = Self>
21 + BitXorAssign<Self>
22 + Shr<usize, Output = Self>
23 + Shl<usize, Output = Self>
24 + Not<Output = Self>
25{
26 const ZERO: Self;
27 const ONE: Self;
28 const ONES: Self;
29
30 fn fill_with_bit(val: u8) -> Self;
33
34 #[inline]
35 fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
36 where
37 T: UnderlierType,
38 Self: From<T>,
39 {
40 let mut result = Self::default();
43 let width = checked_int_div(Self::BITS, T::BITS);
44 for i in 0..width {
45 result |= Self::from(f(i)) << (i * T::BITS);
46 }
47
48 result
49 }
50
51 #[inline]
54 fn broadcast_subvalue<T>(value: T) -> Self
55 where
56 T: UnderlierType,
57 Self: From<T>,
58 {
59 let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
62 let mut result = Self::from(value);
63 for i in 0..height {
64 result |= result << ((1 << i) * T::BITS);
65 }
66
67 result
68 }
69
70 #[inline]
76 unsafe fn get_subvalue<T>(&self, i: usize) -> T
77 where
78 T: UnderlierType + NumCast<Self>,
79 {
80 debug_assert!(
81 i < checked_int_div(Self::BITS, T::BITS),
82 "i: {} Self::BITS: {}, T::BITS: {}",
83 i,
84 Self::BITS,
85 T::BITS
86 );
87 T::num_cast_from(*self >> (i * T::BITS))
88 }
89
90 #[inline]
96 unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
97 where
98 T: UnderlierWithBitOps,
99 Self: From<T>,
100 {
101 debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
102 let mask = Self::from(single_element_mask::<T>());
103
104 *self &= !(mask << (i * T::BITS));
105 *self |= Self::from(val) << (i * T::BITS);
106 }
107
108 #[inline]
115 unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
116 where
117 T: UnderlierWithBitOps + NumCast<Self>,
118 Self: From<T>,
119 {
120 unsafe { spread_fallback(self, log_block_len, block_idx) }
121 }
122
123 fn shl_128b_lanes(self, shift: usize) -> Self;
126
127 fn shr_128b_lanes(self, shift: usize) -> Self;
130
131 fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
141 unpack_lo_128b_fallback(self, other, log_block_len)
142 }
143
144 fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
154 unpack_hi_128b_fallback(self, other, log_block_len)
155 }
156
157 fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
167 where
168 u8: NumCast<Self>,
169 Self: From<u8>,
170 {
171 assert!(TL::LOG_WIDTH <= 4);
172
173 let result = TL::from_fn(|row| {
174 Self::from_fn(|col| {
175 let index = row * (Self::BITS / 8) + col;
176
177 unsafe { values[index % TL::WIDTH].get_subvalue::<u8>(index / TL::WIDTH) }
179 })
180 });
181
182 *values = result;
183 }
184
185 fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
195 where
196 u8: NumCast<Self>,
197 Self: From<u8>,
198 {
199 assert!(TL::LOG_WIDTH <= 4);
200
201 let bytes = Self::BITS / 8;
202 let result = TL::from_fn(|row| {
203 Self::from_fn(|col| {
204 let index = row + col * TL::WIDTH;
205
206 unsafe { values[index / bytes].get_subvalue::<u8>(index % bytes) }
208 })
209 });
210
211 *values = result;
212 }
213}
214
215fn single_element_mask<T>() -> T
219where
220 T: UnderlierWithBitOps,
221{
222 single_element_mask_bits(T::BITS)
223}
224
225#[allow(dead_code)]
227#[inline(always)]
228pub(crate) fn pair_unpack_lo_hi_128b_lanes<U: UnderlierWithBitOps>(
229 values: &mut impl AsMut<[U]>,
230 i: usize,
231 j: usize,
232 log_block_len: usize,
233) {
234 let values = values.as_mut();
235
236 (values[i], values[j]) = (
237 values[i].unpack_lo_128b_lanes(values[j], log_block_len),
238 values[i].unpack_hi_128b_lanes(values[j], log_block_len),
239 );
240}
241
242#[allow(dead_code)]
245#[inline(always)]
246pub(crate) fn transpose_128b_blocks_low_to_high<U: UnderlierWithBitOps, TL: TowerLevel>(
247 values: &mut TL::Data<U>,
248 log_block_len: usize,
249) {
250 assert!(TL::WIDTH <= 16);
251
252 if TL::WIDTH == 1 {
253 return;
254 }
255
256 let (left, right) = TL::split_mut(values);
257 transpose_128b_blocks_low_to_high::<_, TL::Base>(left, log_block_len);
258 transpose_128b_blocks_low_to_high::<_, TL::Base>(right, log_block_len);
259
260 let log_block_len = log_block_len + TL::LOG_WIDTH + 2;
261 for i in 0..TL::WIDTH / 2 {
262 pair_unpack_lo_hi_128b_lanes(values, i, i + TL::WIDTH / 2, log_block_len);
263 }
264}
265
266#[allow(dead_code)]
269#[inline(always)]
270pub(crate) fn transpose_128b_values<U: UnderlierWithBitOps, TL: TowerLevel>(
271 values: &mut TL::Data<U>,
272 log_block_len: usize,
273) {
274 assert!(U::BITS == 128);
275
276 transpose_128b_blocks_low_to_high::<U, TL>(values, log_block_len);
277
278 match TL::LOG_WIDTH {
280 0 | 1 => {}
281 2 => {
282 values.as_mut().swap(1, 2);
283 }
284 3 => {
285 values.as_mut().swap(1, 4);
286 values.as_mut().swap(3, 6);
287 }
288 4 => {
289 values.as_mut().swap(1, 8);
290 values.as_mut().swap(2, 4);
291 values.as_mut().swap(3, 12);
292 values.as_mut().swap(5, 10);
293 values.as_mut().swap(7, 14);
294 values.as_mut().swap(11, 13);
295 }
296 _ => panic!("unsupported tower level"),
297 }
298}
299
300pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
306where
307 U: UnderlierWithBitOps + From<T>,
308 T: UnderlierWithBitOps + NumCast<U>,
309{
310 debug_assert!(
311 log_block_len + T::LOG_BITS <= U::LOG_BITS,
312 "log_block_len: {}, U::BITS: {}, T::BITS: {}",
313 log_block_len,
314 U::BITS,
315 T::BITS
316 );
317 debug_assert!(
318 block_idx < 1 << (U::LOG_BITS - log_block_len),
319 "block_idx: {}, U::BITS: {}, log_block_len: {}",
320 block_idx,
321 U::BITS,
322 log_block_len
323 );
324
325 let mut result = U::ZERO;
326 let block_offset = block_idx << log_block_len;
327 let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
328 for i in 0..1 << log_block_len {
329 unsafe {
330 result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
331 }
332 }
333
334 for i in 0..log_repeat {
335 result |= result << (1 << (T::LOG_BITS + i));
336 }
337
338 result
339}
340
341#[inline(always)]
342fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
343 let mut mask = single_element_mask_bits(1 << log_block_len);
344 for i in 1..T::BITS / 128 {
345 mask |= mask << (i * 128);
346 }
347
348 mask
349}
350
351pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
352 lhs: T,
353 rhs: T,
354 log_block_len: usize,
355) -> T {
356 assert!(log_block_len <= 6);
357
358 let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
359
360 let mut result = T::ZERO;
361 for i in 0..1 << (6 - log_block_len) {
362 result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
363 .shl_128b_lanes(i << (log_block_len + 1));
364 result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
365 .shl_128b_lanes((2 * i + 1) << log_block_len);
366 }
367
368 result
369}
370
371pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
372 lhs: T,
373 rhs: T,
374 log_block_len: usize,
375) -> T {
376 assert!(log_block_len <= 6);
377
378 let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
379 let mut result = T::ZERO;
380 for i in 0..1 << (6 - log_block_len) {
381 result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
382 .shl_128b_lanes(i << (log_block_len + 1));
383 result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
384 .shl_128b_lanes((2 * i + 1) << log_block_len);
385 }
386
387 result
388}
389
390pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
391 if bits_count == T::BITS {
392 !T::ZERO
393 } else {
394 let mut result = T::ONE;
395 for height in 0..checked_log_2(bits_count) {
396 result |= result << (1 << height)
397 }
398
399 result
400 }
401}
402
403pub(crate) trait SpreadToByte {
405 fn spread_to_byte(self) -> u8;
406}
407
408impl SpreadToByte for U1 {
409 #[inline(always)]
410 fn spread_to_byte(self) -> u8 {
411 u8::fill_with_bit(self.val())
412 }
413}
414
415impl SpreadToByte for U2 {
416 #[inline(always)]
417 fn spread_to_byte(self) -> u8 {
418 let mut result = self.val();
419 result |= result << 2;
420 result |= result << 4;
421
422 result
423 }
424}
425
426impl SpreadToByte for U4 {
427 #[inline(always)]
428 fn spread_to_byte(self) -> u8 {
429 let mut result = self.val();
430 result |= result << 4;
431
432 result
433 }
434}
435
436#[allow(unused)]
441#[inline(always)]
442pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
443 value: U,
444 block_idx: usize,
445) -> [T; BLOCK_LEN]
446where
447 U: UnderlierWithBitOps + From<T>,
448 T: UnderlierType + NumCast<U>,
449{
450 std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
451}
452
453#[allow(unused)]
458#[inline(always)]
459pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
460 value: U,
461 block_idx: usize,
462) -> [u8; BLOCK_LEN]
463where
464 U: UnderlierWithBitOps + From<T>,
465 T: UnderlierType + SpreadToByte + NumCast<U>,
466{
467 unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
468 .map(SpreadToByte::spread_to_byte)
469}
470
#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, bits, proptest};

    use super::{
        super::small_uint::{U1, U2, U4},
        *,
    };
    use crate::tower_levels::{TowerLevel1, TowerLevel2};

    #[test]
    fn test_from_fn() {
        assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
        assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

        assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
        assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
        assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

        assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
        assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
        assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
        assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

        assert_eq!(u32::from_fn(|_| 0u8), 0);
        assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
        assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
        assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
    }

    #[test]
    fn test_broadcast_subvalue() {
        assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
        assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
        assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
        assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

        assert_eq!(u32::broadcast_subvalue(0u8), 0);
        assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
        assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
    }

    #[test]
    fn test_get_subvalue() {
        let value = 0xab12cd34u32;

        unsafe {
            assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
            assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

            assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
            assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
            assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
            assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

            assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
            assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
            assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
            assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

            assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
            assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
            assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
            assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
        }
    }

    // Index ranges cover every lane of a u32, including the last one. Value strategies
    // cover every bit of the sub-value width. Each case also checks that the other lanes
    // are left untouched.
    proptest! {
        #[test]
        fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
            unsafe {
                let before = init_val;
                init_val.set_subvalue(i, U1::new(val));
                assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
                assert_eq!((init_val ^ before) & !(1u32 << i), 0);
            }
        }

        #[test]
        fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
            unsafe {
                let before = init_val;
                init_val.set_subvalue(i, U2::new(val));
                assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
                assert_eq!((init_val ^ before) & !(0b11u32 << (2 * i)), 0);
            }
        }

        #[test]
        fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
            unsafe {
                let before = init_val;
                init_val.set_subvalue(i, U4::new(val));
                assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
                assert_eq!((init_val ^ before) & !(0xfu32 << (4 * i)), 0);
            }
        }

        #[test]
        fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
            unsafe {
                let before = init_val;
                init_val.set_subvalue(i, val);
                assert_eq!(init_val.get_subvalue::<u8>(i), val);
                assert_eq!((init_val ^ before) & !(0xffu32 << (8 * i)), 0);
            }
        }
    }

    #[test]
    fn test_single_element_mask_bits() {
        assert_eq!(single_element_mask_bits::<u32>(1), 0x1);
        assert_eq!(single_element_mask_bits::<u32>(4), 0xf);
        assert_eq!(single_element_mask_bits::<u32>(16), 0xffff);
        assert_eq!(single_element_mask_bits::<u32>(32), u32::MAX);
    }

    #[test]
    fn test_spread_fallback() {
        unsafe {
            // Single-element block: the element is broadcast to every slot.
            assert_eq!(spread_fallback::<u32, U4>(0x76543210, 0, 3), 0x33333333);
            // Two-byte block 1 (bytes 2 and 3): each byte is repeated twice.
            assert_eq!(spread_fallback::<u32, u8>(0x04030201, 1, 1), 0x04040303);
            // Block covering the whole value: identity.
            assert_eq!(spread_fallback::<u32, u8>(0x04030201, 2, 0), 0x04030201);
        }
    }

    #[test]
    fn test_transpose_from_byte_sliced() {
        let mut value = [0x01234567u32];
        u32::transpose_bytes_from_byte_sliced::<TowerLevel1>(&mut value);
        assert_eq!(value, [0x01234567u32]);

        let mut value = [0x67452301u32, 0xefcdab89u32];
        u32::transpose_bytes_from_byte_sliced::<TowerLevel2>(&mut value);
        assert_eq!(value, [0xab238901u32, 0xef67cd45u32]);
    }

    #[test]
    fn test_transpose_to_byte_sliced() {
        let mut value = [0x01234567u32];
        u32::transpose_bytes_to_byte_sliced::<TowerLevel1>(&mut value);
        assert_eq!(value, [0x01234567u32]);

        let mut value = [0x67452301u32, 0xefcdab89u32];
        u32::transpose_bytes_to_byte_sliced::<TowerLevel2>(&mut value);
        assert_eq!(value, [0xcd894501u32, 0xefab6723u32]);
    }
}