use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};

use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};

use super::{
    underlier_type::{NumCast, UnderlierType},
    U1, U2, U4,
};
use crate::tower_levels::TowerLevel;

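/// An underlier type that supports the full set of bit operations (AND, OR, XOR, NOT and
/// shifts) together with helpers for building, broadcasting, reading and writing sub-values
/// and for byte-level transposition.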
pub trait UnderlierWithBitOps:
    UnderlierType
    + BitAnd<Self, Output = Self>
    + BitAndAssign<Self>
    + BitOr<Self, Output = Self>
    + BitOrAssign<Self>
    + BitXor<Self, Output = Self>
    + BitXorAssign<Self>
    + Shr<usize, Output = Self>
    + Shl<usize, Output = Self>
    + Not<Output = Self>
{
    const ZERO: Self;
    const ONE: Self;
    const ONES: Self;

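    /// Returns a value with every bit set to `val`, which is expected to be `0` or `1`.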
    fn fill_with_bit(val: u8) -> Self;

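    /// Builds `Self` by calling `f` once per `T`-sized slot and packing `f(i)` at bit
    /// offset `i * T::BITS`.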
    #[inline]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        let mut result = Self::default();
        let width = checked_int_div(Self::BITS, T::BITS);
        for i in 0..width {
            result |= Self::from(f(i)) << (i * T::BITS);
        }

        result
    }

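    /// Broadcasts a single `T`-sized value into every `T`-sized slot of `Self` using
    /// logarithmically many shift-and-OR steps.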
    #[inline]
    fn broadcast_subvalue<T>(value: T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
        let mut result = Self::from(value);
        for i in 0..height {
            result |= result << ((1 << i) * T::BITS);
        }

        result
    }

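    /// Returns the `i`-th `T`-sized sub-value of `self`.
    ///
    /// # Safety
    /// `i` must be less than `Self::BITS / T::BITS`.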
    #[inline]
    unsafe fn get_subvalue<T>(&self, i: usize) -> T
    where
        T: UnderlierType + NumCast<Self>,
    {
        debug_assert!(
            i < checked_int_div(Self::BITS, T::BITS),
            "i: {} Self::BITS: {}, T::BITS: {}",
            i,
            Self::BITS,
            T::BITS
        );
        T::num_cast_from(*self >> (i * T::BITS))
    }

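    /// Sets the `i`-th `T`-sized sub-value of `self` to `val`.
    ///
    /// # Safety
    /// `i` must be less than `Self::BITS / T::BITS`.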
    #[inline]
    unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
    where
        T: UnderlierWithBitOps,
        Self: From<T>,
    {
        debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
        let mask = Self::from(single_element_mask::<T>());

        *self &= !(mask << (i * T::BITS));
        *self |= Self::from(val) << (i * T::BITS);
    }

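    /// Takes the block of `1 << log_block_len` `T`-sub-values at `block_idx` and spreads it
    /// over the whole underlier, repeating each sub-value an equal number of times
    /// (see [`spread_fallback`] for the reference behavior).
    ///
    /// # Safety
    /// `log_block_len + T::LOG_BITS` must not exceed `Self::LOG_BITS`, and `block_idx` must
    /// index an existing block.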
    #[inline]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: From<T>,
    {
        spread_fallback(self, log_block_len, block_idx)
    }

    /// Left shift applied independently within each 128-bit lane of `self`.
    fn shl_128b_lanes(self, shift: usize) -> Self;

    /// Right shift applied independently within each 128-bit lane of `self`.
    fn shr_128b_lanes(self, shift: usize) -> Self;

    /// Interleaves blocks of `1 << log_block_len` bits taken from the low 64-bit halves of
    /// each 128-bit lane of `self` and `other`.
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        unpack_lo_128b_fallback(self, other, log_block_len)
    }

    /// Interleaves blocks of `1 << log_block_len` bits taken from the high 64-bit halves of
    /// each 128-bit lane of `self` and `other`.
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        unpack_hi_128b_fallback(self, other, log_block_len)
    }

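    /// Interprets `values` as a matrix of bytes and transposes it out of the byte-sliced
    /// layout back into the natural per-underlier byte order. Only tower levels of width
    /// at most 16 are supported.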
    fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        assert!(TL::LOG_WIDTH <= 4);

        let result = TL::from_fn(|row| {
            Self::from_fn(|col| {
                let index = row * (Self::BITS / 8) + col;

                unsafe { values[index % TL::WIDTH].get_subvalue::<u8>(index / TL::WIDTH) }
            })
        });

        *values = result;
    }

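    /// Interprets `values` as a matrix of bytes and transposes it into the byte-sliced
    /// layout (the inverse of [`Self::transpose_bytes_from_byte_sliced`]). Only tower
    /// levels of width at most 16 are supported.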
    fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        assert!(TL::LOG_WIDTH <= 4);

        let bytes = Self::BITS / 8;
        let result = TL::from_fn(|row| {
            Self::from_fn(|col| {
                let index = row + col * TL::WIDTH;

                unsafe { values[index / bytes].get_subvalue::<u8>(index % bytes) }
            })
        });

        *values = result;
    }
}

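/// Returns a mask covering the bits of a single `T` element.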
fn single_element_mask<T>() -> T
where
    T: UnderlierWithBitOps,
{
    single_element_mask_bits(T::BITS)
}

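/// Replaces `values[i]` and `values[j]` with the unpacked-low and unpacked-high
/// interleavings of the pair, using blocks of `1 << log_block_len` bits.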
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn pair_unpack_lo_hi_128b_lanes<U: UnderlierWithBitOps>(
    values: &mut impl AsMut<[U]>,
    i: usize,
    j: usize,
    log_block_len: usize,
) {
    let values = values.as_mut();

    (values[i], values[j]) = (
        values[i].unpack_lo_128b_lanes(values[j], log_block_len),
        values[i].unpack_hi_128b_lanes(values[j], log_block_len),
    );
}

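/// Recursively interleaves the two halves of `values` with `unpack_lo`/`unpack_hi`,
/// doubling the block length at each level of the tower. This is the in-place butterfly
/// step used by [`transpose_128b_values`].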
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn transpose_128b_blocks_low_to_high<U: UnderlierWithBitOps, TL: TowerLevel>(
    values: &mut TL::Data<U>,
    log_block_len: usize,
) {
    assert!(TL::WIDTH <= 16);

    if TL::WIDTH == 1 {
        return;
    }

    let (left, right) = TL::split_mut(values);
    transpose_128b_blocks_low_to_high::<_, TL::Base>(left, log_block_len);
    transpose_128b_blocks_low_to_high::<_, TL::Base>(right, log_block_len);

    let log_block_len = log_block_len + TL::LOG_WIDTH + 2;
    for i in 0..TL::WIDTH / 2 {
        pair_unpack_lo_hi_128b_lanes(values, i, i + TL::WIDTH / 2, log_block_len);
    }
}

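/// Block-wise transposition of `TL::WIDTH` 128-bit underliers: first the recursive
/// interleave pass, then a fixed set of swaps that undoes the bit-reversed row order it
/// produces.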
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn transpose_128b_values<U: UnderlierWithBitOps, TL: TowerLevel>(
    values: &mut TL::Data<U>,
    log_block_len: usize,
) {
    assert!(U::BITS == 128);

    transpose_128b_blocks_low_to_high::<U, TL>(values, log_block_len);

    // The recursive interleave leaves the rows in bit-reversed order; swap them back.
    match TL::LOG_WIDTH {
        0 | 1 => {}
        2 => {
            values.as_mut().swap(1, 2);
        }
        3 => {
            values.as_mut().swap(1, 4);
            values.as_mut().swap(3, 6);
        }
        4 => {
            values.as_mut().swap(1, 8);
            values.as_mut().swap(2, 4);
            values.as_mut().swap(3, 12);
            values.as_mut().swap(5, 10);
            values.as_mut().swap(7, 14);
            values.as_mut().swap(11, 13);
        }
        _ => panic!("unsupported tower level"),
    }
}

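/// Fallback implementation of [`UnderlierWithBitOps::spread`]: copies the selected block of
/// `T`-sub-values into evenly spaced slots of the result and then duplicates them to fill
/// the gaps.
///
/// # Safety
/// The caller must ensure that `block_idx` addresses a valid block of `value`.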
pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
where
    U: UnderlierWithBitOps + From<T>,
    T: UnderlierWithBitOps + NumCast<U>,
{
    debug_assert!(
        log_block_len + T::LOG_BITS <= U::LOG_BITS,
        "log_block_len: {}, U::BITS: {}, T::BITS: {}",
        log_block_len,
        U::BITS,
        T::BITS
    );
    debug_assert!(
        block_idx < 1 << (U::LOG_BITS - log_block_len),
        "block_idx: {}, U::BITS: {}, log_block_len: {}",
        block_idx,
        U::BITS,
        log_block_len
    );

    // Place each sub-value of the selected block at the start of its destination range.
    let mut result = U::ZERO;
    let block_offset = block_idx << log_block_len;
    let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
    for i in 0..1 << log_block_len {
        unsafe {
            result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
        }
    }

    // Duplicate each placed sub-value to fill the `1 << log_repeat` slots of its range.
    for i in 0..log_repeat {
        result |= result << (1 << (T::LOG_BITS + i));
    }

    result
}

#[inline(always)]
fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
    let mut mask = single_element_mask_bits(1 << log_block_len);
    for i in 1..T::BITS / 128 {
        mask |= mask << (i * 128);
    }

    mask
}

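/// Generic fallback for [`UnderlierWithBitOps::unpack_lo_128b_lanes`]: interleaves blocks of
/// `1 << log_block_len` bits from the low 64-bit halves of each 128-bit lane of `lhs` and `rhs`.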
pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
    lhs: T,
    rhs: T,
    log_block_len: usize,
) -> T {
    assert!(log_block_len <= 6);

    let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);

    let mut result = T::ZERO;
    for i in 0..1 << (6 - log_block_len) {
        result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
            .shl_128b_lanes(i << (log_block_len + 1));
        result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
            .shl_128b_lanes((2 * i + 1) << log_block_len);
    }

    result
}

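/// Generic fallback for [`UnderlierWithBitOps::unpack_hi_128b_lanes`]: interleaves blocks of
/// `1 << log_block_len` bits from the high 64-bit halves of each 128-bit lane of `lhs` and `rhs`.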
pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
    lhs: T,
    rhs: T,
    log_block_len: usize,
) -> T {
    assert!(log_block_len <= 6);

    let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
    let mut result = T::ZERO;
    for i in 0..1 << (6 - log_block_len) {
        result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
            .shl_128b_lanes(i << (log_block_len + 1));
        result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
            .shl_128b_lanes((2 * i + 1) << log_block_len);
    }

    result
}

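/// Returns a mask with the lowest `bits_count` bits set; `bits_count` must be a power of two
/// no larger than `T::BITS`.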
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
    if bits_count == T::BITS {
        !T::ZERO
    } else {
        let mut result = T::ONE;
        for height in 0..checked_log_2(bits_count) {
            result |= result << (1 << height);
        }

        result
    }
}

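/// Spreads a small unsigned value (1, 2 or 4 bits) to a full byte by repeating its bits.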
pub(crate) trait SpreadToByte {
    fn spread_to_byte(self) -> u8;
}

impl SpreadToByte for U1 {
    #[inline(always)]
    fn spread_to_byte(self) -> u8 {
        u8::fill_with_bit(self.val())
    }
}

impl SpreadToByte for U2 {
    #[inline(always)]
    fn spread_to_byte(self) -> u8 {
        let mut result = self.val();
        result |= result << 2;
        result |= result << 4;

        result
    }
}

impl SpreadToByte for U4 {
    #[inline(always)]
    fn spread_to_byte(self) -> u8 {
        let mut result = self.val();
        result |= result << 4;

        result
    }
}

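/// Reads the block of `BLOCK_LEN` consecutive `T`-sub-values starting at block `block_idx`.
///
/// # Safety
/// The whole block must lie within `value`.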
#[allow(unused)]
#[inline(always)]
pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
    value: U,
    block_idx: usize,
) -> [T; BLOCK_LEN]
where
    U: UnderlierWithBitOps + From<T>,
    T: UnderlierType + NumCast<U>,
{
    std::array::from_fn(|i| value.get_subvalue::<T>(block_idx * BLOCK_LEN + i))
}

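/// Reads a block of `BLOCK_LEN` sub-values and spreads each one to a full byte.
///
/// # Safety
/// The whole block must lie within `value`.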
#[allow(unused)]
#[inline(always)]
pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
    value: U,
    block_idx: usize,
) -> [u8; BLOCK_LEN]
where
    U: UnderlierWithBitOps + From<T>,
    T: UnderlierType + SpreadToByte + NumCast<U>,
{
    get_block_values::<U, T, BLOCK_LEN>(value, block_idx).map(SpreadToByte::spread_to_byte)
}

#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, bits, proptest};

    use super::{
        super::small_uint::{U1, U2, U4},
        *,
    };
    use crate::tower_levels::{TowerLevel1, TowerLevel2};

    #[test]
    fn test_from_fn() {
        assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
        assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

        assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
        assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
        assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

        assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
        assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
        assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
        assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
        assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

        assert_eq!(u32::from_fn(|_| 0u8), 0);
        assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
        assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
        assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
    }

    #[test]
    fn test_broadcast_subvalue() {
        assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
        assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
        assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

        assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
        assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
        assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
        assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

        assert_eq!(u32::broadcast_subvalue(0u8), 0);
        assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
        assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
    }

    #[test]
    fn test_get_subvalue() {
        let value = 0xab12cd34u32;

        unsafe {
            assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
            assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
            assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

            assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
            assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
            assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
            assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

            assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
            assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
            assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
            assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

            assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
            assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
            assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
            assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
        }
    }

    proptest! {
        #[test]
        fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
            unsafe {
                init_val.set_subvalue(i, U1::new(val));
                assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
            }
        }

        #[test]
        fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
            unsafe {
                init_val.set_subvalue(i, U2::new(val));
                assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
            }
        }

        #[test]
        fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
            unsafe {
                init_val.set_subvalue(i, U4::new(val));
                assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
            }
        }

        #[test]
        fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in bits::u8::masked(255)) {
            unsafe {
                init_val.set_subvalue(i, val);
                assert_eq!(init_val.get_subvalue::<u8>(i), val);
            }
        }
    }

    #[test]
    fn test_transpose_from_byte_sliced() {
        let mut value = [0x01234567u32];
        u32::transpose_bytes_from_byte_sliced::<TowerLevel1>(&mut value);
        assert_eq!(value, [0x01234567u32]);

        let mut value = [0x67452301u32, 0xefcdab89u32];
        u32::transpose_bytes_from_byte_sliced::<TowerLevel2>(&mut value);
        assert_eq!(value, [0xab238901u32, 0xef67cd45u32]);
    }

    #[test]
    fn test_transpose_to_byte_sliced() {
        let mut value = [0x01234567u32];
        u32::transpose_bytes_to_byte_sliced::<TowerLevel1>(&mut value);
        assert_eq!(value, [0x01234567u32]);

        let mut value = [0x67452301u32, 0xefcdab89u32];
        u32::transpose_bytes_to_byte_sliced::<TowerLevel2>(&mut value);
        assert_eq!(value, [0xcd894501u32, 0xefab6723u32]);
    }
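
    // Supplementary sanity checks for the fallback helpers defined in this module.
    // The expected values below are derived by hand from the loop-based definitions above,
    // using `u32` as the underlier and the small uint types as sub-values.
    #[test]
    fn test_single_element_mask_bits() {
        assert_eq!(single_element_mask_bits::<u32>(1), 1);
        assert_eq!(single_element_mask_bits::<u32>(8), 0xff);
        assert_eq!(single_element_mask_bits::<u32>(32), u32::MAX);
    }

    #[test]
    fn test_spread_to_byte() {
        assert_eq!(U1::new(1).spread_to_byte(), 0xff);
        assert_eq!(U2::new(2).spread_to_byte(), 0xaa);
        assert_eq!(U4::new(3).spread_to_byte(), 0x33);
    }

    #[test]
    fn test_spread_fallback() {
        unsafe {
            // The block [0x33, 0x44] (bytes 2 and 3) is spread so that each byte repeats twice.
            assert_eq!(spread_fallback::<u32, u8>(0x4433_2211, 1, 1), 0x4444_3333);
            // A single-byte block is repeated across the whole underlier.
            assert_eq!(spread_fallback::<u32, u8>(0x4433_2211, 0, 2), 0x3333_3333);
        }
    }

    #[test]
    fn test_get_spread_bytes() {
        unsafe {
            // The four 2-bit values of 0xe4 are 0, 1, 2 and 3; each one is spread to a byte.
            assert_eq!(get_spread_bytes::<u32, U2, 4>(0xe4, 0), [0x00, 0x55, 0xaa, 0xff]);
        }
    }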
}