1use std::mem::size_of;
4
/// A `Copy` value that can be viewed as a sequence of `2^LOG_N` parts of type `T`.
///
/// Part index `0` is the first part yielded by the iterator methods.
pub trait Divisible<T>: Copy {
    /// Base-2 logarithm of the number of `T` parts contained in one `Self`.
    const LOG_N: usize;

    /// Number of `T` parts contained in one `Self`, derived as `2^LOG_N`.
    const N: usize = 1usize << Self::LOG_N;

    /// Yields every part of `value`, taking it by value.
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone;

    /// Yields every part of the borrowed `value`.
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_;

    /// Yields the parts of every element of `slice`, one element after another.
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_;

    /// Returns the part at `index` (callers are expected to pass `index < N`).
    fn get(self, index: usize) -> T;

    /// Returns a copy of `self` whose part at `index` is replaced by `val`.
    fn set(self, index: usize, val: T) -> Self;
}
52
53pub mod memcast {
57 use bytemuck::Pod;
58
59 #[cfg(target_endian = "little")]
61 #[inline]
62 pub fn value_iter<Big, Small, const N: usize>(
63 value: Big,
64 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone
65 where
66 Big: Pod,
67 Small: Pod + Send,
68 {
69 bytemuck::must_cast::<Big, [Small; N]>(value).into_iter()
70 }
71
72 #[cfg(target_endian = "big")]
74 #[inline]
75 pub fn value_iter<Big, Small, const N: usize>(
76 value: Big,
77 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone
78 where
79 Big: Pod,
80 Small: Pod + Send,
81 {
82 bytemuck::must_cast::<Big, [Small; N]>(value)
83 .into_iter()
84 .rev()
85 }
86
87 #[cfg(target_endian = "little")]
89 #[inline]
90 pub fn ref_iter<Big, Small, const N: usize>(
91 value: &Big,
92 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
93 where
94 Big: Pod,
95 Small: Pod + Send + Sync,
96 {
97 bytemuck::must_cast_ref::<Big, [Small; N]>(value)
98 .iter()
99 .copied()
100 }
101
102 #[cfg(target_endian = "big")]
104 #[inline]
105 pub fn ref_iter<Big, Small, const N: usize>(
106 value: &Big,
107 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
108 where
109 Big: Pod,
110 Small: Pod + Send + Sync,
111 {
112 bytemuck::must_cast_ref::<Big, [Small; N]>(value)
113 .iter()
114 .rev()
115 .copied()
116 }
117
118 #[cfg(target_endian = "little")]
120 #[inline]
121 pub fn slice_iter<Big, Small>(
122 slice: &[Big],
123 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
124 where
125 Big: Pod,
126 Small: Pod + Send + Sync,
127 {
128 bytemuck::must_cast_slice::<Big, Small>(slice)
129 .iter()
130 .copied()
131 }
132
133 #[cfg(target_endian = "big")]
138 #[inline]
139 pub fn slice_iter<Big, Small, const LOG_N: usize>(
140 slice: &[Big],
141 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
142 where
143 Big: Pod,
144 Small: Pod + Send + Sync,
145 {
146 const N: usize = 1 << LOG_N;
147 let raw_slice = bytemuck::must_cast_slice::<Big, Small>(slice);
148 (0..raw_slice.len()).map(move |i| {
149 let element_idx = i >> LOG_N;
150 let sub_idx = i & (N - 1);
151 let reversed_sub_idx = N - 1 - sub_idx;
152 let raw_idx = element_idx * N + reversed_sub_idx;
153 raw_slice[raw_idx]
154 })
155 }
156
157 #[cfg(target_endian = "little")]
159 #[inline]
160 pub fn get<Big, Small, const N: usize>(value: &Big, index: usize) -> Small
161 where
162 Big: Pod,
163 Small: Pod,
164 {
165 bytemuck::must_cast_ref::<Big, [Small; N]>(value)[index]
166 }
167
168 #[cfg(target_endian = "big")]
170 #[inline]
171 pub fn get<Big, Small, const N: usize>(value: &Big, index: usize) -> Small
172 where
173 Big: Pod,
174 Small: Pod,
175 {
176 bytemuck::must_cast_ref::<Big, [Small; N]>(value)[N - 1 - index]
177 }
178
179 #[cfg(target_endian = "little")]
181 #[inline]
182 pub fn set<Big, Small, const N: usize>(value: &Big, index: usize, val: Small) -> Big
183 where
184 Big: Pod,
185 Small: Pod,
186 {
187 let mut arr = *bytemuck::must_cast_ref::<Big, [Small; N]>(value);
188 arr[index] = val;
189 bytemuck::must_cast(arr)
190 }
191
192 #[cfg(target_endian = "big")]
194 #[inline]
195 pub fn set<Big, Small, const N: usize>(value: &Big, index: usize, val: Small) -> Big
196 where
197 Big: Pod,
198 Small: Pod,
199 {
200 let mut arr = *bytemuck::must_cast_ref::<Big, [Small; N]>(value);
201 arr[N - 1 - index] = val;
202 bytemuck::must_cast(arr)
203 }
204}
205
206pub mod bitmask {
211 use super::{Divisible, SmallU};
212
213 #[inline]
215 pub fn get<Big, const BITS: usize>(value: Big, index: usize) -> SmallU<BITS>
216 where
217 Big: Divisible<u8>,
218 {
219 let elems_per_byte = 8 / BITS;
220 let byte_index = index / elems_per_byte;
221 let sub_index = index % elems_per_byte;
222 let byte = Divisible::<u8>::get(value, byte_index);
223 let shift = sub_index * BITS;
224 SmallU::<BITS>::new(byte >> shift)
225 }
226
227 #[inline]
229 pub fn set<Big, const BITS: usize>(value: Big, index: usize, val: SmallU<BITS>) -> Big
230 where
231 Big: Divisible<u8>,
232 {
233 let elems_per_byte = 8 / BITS;
234 let byte_index = index / elems_per_byte;
235 let sub_index = index % elems_per_byte;
236 let byte = Divisible::<u8>::get(value, byte_index);
237 let shift = sub_index * BITS;
238 let mask = (1u8 << BITS) - 1;
239 let new_byte = (byte & !(mask << shift)) | (val.val() << shift);
240 Divisible::<u8>::set(value, byte_index, new_byte)
241 }
242}
243
244pub mod mapget {
249 use binius_utils::iter::IterExtensions;
250
251 use super::Divisible;
252
253 #[inline]
255 pub fn value_iter<Big, Small>(value: Big) -> impl ExactSizeIterator<Item = Small> + Send + Clone
256 where
257 Big: Divisible<Small> + Send,
258 Small: Send,
259 {
260 (0..Big::N).map_skippable(move |i| Divisible::<Small>::get(value, i))
261 }
262
263 #[inline]
265 pub fn slice_iter<Big, Small>(
266 slice: &[Big],
267 ) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
268 where
269 Big: Divisible<Small> + Send + Sync,
270 Small: Send,
271 {
272 let total = slice.len() * Big::N;
273 (0..total).map_skippable(move |global_idx| {
274 let elem_idx = global_idx / Big::N;
275 let sub_idx = global_idx % Big::N;
276 Divisible::<Small>::get(slice[elem_idx], sub_idx)
277 })
278 }
279}
280
281#[derive(Clone)]
286pub struct SmallUDivisIter<I, const N: usize> {
287 byte_iter: I,
288 current_byte: Option<u8>,
289 sub_idx: usize,
290}
291
292impl<I: Iterator<Item = u8>, const N: usize> SmallUDivisIter<I, N> {
293 const ELEMS_PER_BYTE: usize = 8 / N;
294
295 pub fn new(mut byte_iter: I) -> Self {
296 let current_byte = byte_iter.next();
297 Self {
298 byte_iter,
299 current_byte,
300 sub_idx: 0,
301 }
302 }
303}
304
305impl<I: ExactSizeIterator<Item = u8>, const N: usize> Iterator for SmallUDivisIter<I, N> {
306 type Item = SmallU<N>;
307
308 #[inline]
309 fn next(&mut self) -> Option<Self::Item> {
310 let byte = self.current_byte?;
311 let shift = self.sub_idx * N;
312 let result = SmallU::<N>::new(byte >> shift);
313
314 self.sub_idx += 1;
315 if self.sub_idx >= Self::ELEMS_PER_BYTE {
316 self.sub_idx = 0;
317 self.current_byte = self.byte_iter.next();
318 }
319
320 Some(result)
321 }
322
323 #[inline]
324 fn size_hint(&self) -> (usize, Option<usize>) {
325 let remaining_in_current = if self.current_byte.is_some() {
326 Self::ELEMS_PER_BYTE - self.sub_idx
327 } else {
328 0
329 };
330 let remaining_bytes = self.byte_iter.len();
331 let total = remaining_in_current + remaining_bytes * Self::ELEMS_PER_BYTE;
332 (total, Some(total))
333 }
334}
335
336impl<I: ExactSizeIterator<Item = u8>, const N: usize> ExactSizeIterator for SmallUDivisIter<I, N> {}
337
/// Implements `Divisible<$small>` for `$big`, for every listed `$small`. Each impl
/// reinterprets the memory of a `$big` as `[$small; N]` through the [`memcast`] helpers, where
/// `N = size_of::<$big>() / size_of::<$small>()`.
///
/// `N` is assumed to be a power of two: `LOG_N` is computed with `ilog2`, which rounds down.
///
/// `size_of` is spelled `::std::mem::size_of` because `macro_rules!` paths resolve at the
/// invocation site. A bare `size_of` would only compile where the caller happens to import it.
macro_rules! impl_divisible_memcast {
    ($big:ty, $($small:ty),+) => {
        $(
            impl $crate::underlier::Divisible<$small> for $big {
                const LOG_N: usize = (::std::mem::size_of::<$big>()
                    / ::std::mem::size_of::<$small>())
                .ilog2() as usize;

                #[inline]
                fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $small> + Send + Clone {
                    const N: usize =
                        ::std::mem::size_of::<$big>() / ::std::mem::size_of::<$small>();
                    $crate::underlier::memcast::value_iter::<$big, $small, N>(value)
                }

                #[inline]
                fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $small> + Send + Clone + '_ {
                    const N: usize =
                        ::std::mem::size_of::<$big>() / ::std::mem::size_of::<$small>();
                    $crate::underlier::memcast::ref_iter::<$big, $small, N>(value)
                }

                #[inline]
                #[cfg(target_endian = "little")]
                fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $small> + Send + Clone + '_ {
                    $crate::underlier::memcast::slice_iter::<$big, $small>(slice)
                }

                #[inline]
                #[cfg(target_endian = "big")]
                fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $small> + Send + Clone + '_ {
                    const LOG_N: usize = (::std::mem::size_of::<$big>()
                        / ::std::mem::size_of::<$small>())
                    .ilog2() as usize;
                    $crate::underlier::memcast::slice_iter::<$big, $small, LOG_N>(slice)
                }

                #[inline]
                fn get(self, index: usize) -> $small {
                    const N: usize =
                        ::std::mem::size_of::<$big>() / ::std::mem::size_of::<$small>();
                    $crate::underlier::memcast::get::<$big, $small, N>(&self, index)
                }

                #[inline]
                fn set(self, index: usize, val: $small) -> Self {
                    const N: usize =
                        ::std::mem::size_of::<$big>() / ::std::mem::size_of::<$small>();
                    $crate::underlier::memcast::set::<$big, $small, N>(&self, index, val)
                }
            }
        )+
    };
}

// Crate-wide re-export so other modules can invoke the macro. `unused` is allowed because the
// re-export may have no users outside this file in some configurations.
#[allow(unused)]
pub(crate) use impl_divisible_memcast;
391
/// Implements `Divisible<SmallU<$bits>>` for `$big`, for each listed bit width.
///
/// The `u8` arm extracts fields directly with shifts and masks. The generic arm splits `$big`
/// into bytes through its `Divisible<u8>` impl. It then delegates field access to the
/// [`bitmask`] helpers and iteration to [`SmallUDivisIter`]. In both arms, fields inside a byte
/// are ordered least-significant bits first.
///
/// `size_of` is spelled `::std::mem::size_of` because `macro_rules!` paths resolve at the
/// invocation site. A bare `size_of` would only compile where the caller happens to import it.
macro_rules! impl_divisible_bitmask {
    (u8, $($bits:expr),+) => {
        $(
            impl $crate::underlier::Divisible<$crate::underlier::SmallU<$bits>> for u8 {
                const LOG_N: usize = (8usize / $bits).ilog2() as usize;

                #[inline]
                fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone {
                    $crate::underlier::SmallUDivisIter::new(std::iter::once(value))
                }

                #[inline]
                fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
                    $crate::underlier::SmallUDivisIter::new(std::iter::once(*value))
                }

                #[inline]
                fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
                    $crate::underlier::SmallUDivisIter::new(slice.iter().copied())
                }

                #[inline]
                fn get(self, index: usize) -> $crate::underlier::SmallU<$bits> {
                    let shift = index * $bits;
                    $crate::underlier::SmallU::<$bits>::new(self >> shift)
                }

                #[inline]
                fn set(self, index: usize, val: $crate::underlier::SmallU<$bits>) -> Self {
                    let shift = index * $bits;
                    let mask = (1u8 << $bits) - 1;
                    (self & !(mask << shift)) | (val.val() << shift)
                }
            }
        )+
    };

    ($big:ty, $($bits:expr),+) => {
        $(
            impl $crate::underlier::Divisible<$crate::underlier::SmallU<$bits>> for $big {
                const LOG_N: usize = (8 * ::std::mem::size_of::<$big>() / $bits).ilog2() as usize;

                #[inline]
                fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone {
                    $crate::underlier::SmallUDivisIter::new(
                        $crate::underlier::Divisible::<u8>::value_iter(value)
                    )
                }

                #[inline]
                fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
                    $crate::underlier::SmallUDivisIter::new(
                        $crate::underlier::Divisible::<u8>::ref_iter(value)
                    )
                }

                #[inline]
                fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
                    $crate::underlier::SmallUDivisIter::new(
                        $crate::underlier::Divisible::<u8>::slice_iter(slice)
                    )
                }

                #[inline]
                fn get(self, index: usize) -> $crate::underlier::SmallU<$bits> {
                    $crate::underlier::bitmask::get::<Self, $bits>(self, index)
                }

                #[inline]
                fn set(self, index: usize, val: $crate::underlier::SmallU<$bits>) -> Self {
                    $crate::underlier::bitmask::set::<Self, $bits>(self, index, val)
                }
            }
        )+
    };
}

// Crate-wide re-export so other modules can invoke the macro. `unused` is allowed because the
// re-export may have no users outside this file in some configurations.
#[allow(unused)]
pub(crate) use impl_divisible_bitmask;
477
478use super::small_uint::SmallU;
479
480impl_divisible_memcast!(u128, u64, u32, u16, u8);
482impl_divisible_memcast!(u64, u32, u16, u8);
483impl_divisible_memcast!(u32, u16, u8);
484impl_divisible_memcast!(u16, u8);
485
486impl_divisible_bitmask!(u8, 1, 2, 4);
488impl_divisible_bitmask!(u16, 1, 2, 4);
489impl_divisible_bitmask!(u32, 1, 2, 4);
490impl_divisible_bitmask!(u64, 1, 2, 4);
491impl_divisible_bitmask!(u128, 1, 2, 4);
492
493impl Divisible<SmallU<1>> for SmallU<2> {
495 const LOG_N: usize = 1;
496
497 #[inline]
498 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone {
499 mapget::value_iter(value)
500 }
501
502 #[inline]
503 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
504 mapget::value_iter(*value)
505 }
506
507 #[inline]
508 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
509 mapget::slice_iter(slice)
510 }
511
512 #[inline]
513 fn get(self, index: usize) -> SmallU<1> {
514 SmallU::<1>::new(self.val() >> index)
515 }
516
517 #[inline]
518 fn set(self, index: usize, val: SmallU<1>) -> Self {
519 let mask = 1u8 << index;
520 SmallU::<2>::new((self.val() & !mask) | (val.val() << index))
521 }
522}
523
524impl Divisible<SmallU<1>> for SmallU<4> {
525 const LOG_N: usize = 2;
526
527 #[inline]
528 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone {
529 mapget::value_iter(value)
530 }
531
532 #[inline]
533 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
534 mapget::value_iter(*value)
535 }
536
537 #[inline]
538 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
539 mapget::slice_iter(slice)
540 }
541
542 #[inline]
543 fn get(self, index: usize) -> SmallU<1> {
544 SmallU::<1>::new(self.val() >> index)
545 }
546
547 #[inline]
548 fn set(self, index: usize, val: SmallU<1>) -> Self {
549 let mask = 1u8 << index;
550 SmallU::<4>::new((self.val() & !mask) | (val.val() << index))
551 }
552}
553
554impl Divisible<SmallU<2>> for SmallU<4> {
555 const LOG_N: usize = 1;
556
557 #[inline]
558 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = SmallU<2>> + Send + Clone {
559 mapget::value_iter(value)
560 }
561
562 #[inline]
563 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = SmallU<2>> + Send + Clone + '_ {
564 mapget::value_iter(*value)
565 }
566
567 #[inline]
568 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = SmallU<2>> + Send + Clone + '_ {
569 mapget::slice_iter(slice)
570 }
571
572 #[inline]
573 fn get(self, index: usize) -> SmallU<2> {
574 SmallU::<2>::new(self.val() >> (index * 2))
575 }
576
577 #[inline]
578 fn set(self, index: usize, val: SmallU<2>) -> Self {
579 let shift = index * 2;
580 let mask = 0b11u8 << shift;
581 SmallU::<4>::new((self.val() & !mask) | (val.val() << shift))
582 }
583}
584
585impl<T: Copy + Send + Sync> Divisible<T> for T {
587 const LOG_N: usize = 0;
588
589 #[inline]
590 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone {
591 std::iter::once(value)
592 }
593
594 #[inline]
595 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
596 std::iter::once(*value)
597 }
598
599 #[inline]
600 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
601 slice.iter().copied()
602 }
603
604 #[inline]
605 fn get(self, index: usize) -> T {
606 debug_assert_eq!(index, 0);
607 self
608 }
609
610 #[inline]
611 fn set(self, index: usize, val: T) -> Self {
612 debug_assert_eq!(index, 0);
613 val
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;
    use crate::underlier::small_uint::{U1, U2, U4};

    #[test]
    fn test_divisible_u8_u4() {
        let val: u8 = 0x34;

        // Index 0 is the low nibble.
        assert_eq!(Divisible::<U4>::get(val, 0), U4::new(0x4));
        assert_eq!(Divisible::<U4>::get(val, 1), U4::new(0x3));

        let modified = Divisible::<U4>::set(val, 0, U4::new(0xF));
        assert_eq!(modified, 0x3F);
        let modified = Divisible::<U4>::set(val, 1, U4::new(0xA));
        assert_eq!(modified, 0xA4);

        let parts: Vec<U4> = Divisible::<U4>::ref_iter(&val).collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0], U4::new(0x4));
        assert_eq!(parts[1], U4::new(0x3));

        let parts: Vec<U4> = Divisible::<U4>::value_iter(val).collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0], U4::new(0x4));
        assert_eq!(parts[1], U4::new(0x3));

        // Slices are flattened element by element.
        let vals = [0x34u8, 0x56u8];
        let parts: Vec<U4> = Divisible::<U4>::slice_iter(&vals).collect();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0], U4::new(0x4));
        assert_eq!(parts[1], U4::new(0x3));
        assert_eq!(parts[2], U4::new(0x6));
        assert_eq!(parts[3], U4::new(0x5));
    }

    #[test]
    fn test_divisible_u16_u4() {
        let val: u16 = 0x1234;

        assert_eq!(Divisible::<U4>::get(val, 0), U4::new(0x4));
        assert_eq!(Divisible::<U4>::get(val, 1), U4::new(0x3));
        assert_eq!(Divisible::<U4>::get(val, 2), U4::new(0x2));
        assert_eq!(Divisible::<U4>::get(val, 3), U4::new(0x1));

        let modified = Divisible::<U4>::set(val, 1, U4::new(0xF));
        assert_eq!(modified, 0x12F4);

        let parts: Vec<U4> = Divisible::<U4>::ref_iter(&val).collect();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0], U4::new(0x4));
        assert_eq!(parts[3], U4::new(0x1));
    }

    #[test]
    fn test_divisible_u16_u2() {
        let val: u16 = 0b1011001011010011;

        assert_eq!(Divisible::<U2>::get(val, 0), U2::new(0b11));
        assert_eq!(Divisible::<U2>::get(val, 1), U2::new(0b00));
        assert_eq!(Divisible::<U2>::get(val, 7), U2::new(0b10));

        let parts: Vec<U2> = Divisible::<U2>::ref_iter(&val).collect();
        assert_eq!(parts.len(), 8);
        assert_eq!(parts[0], U2::new(0b11));
        assert_eq!(parts[7], U2::new(0b10));
    }

    #[test]
    fn test_divisible_u16_u1() {
        let val: u16 = 0b1010110000110101;

        assert_eq!(Divisible::<U1>::get(val, 0), U1::new(1));
        assert_eq!(Divisible::<U1>::get(val, 1), U1::new(0));
        assert_eq!(Divisible::<U1>::get(val, 15), U1::new(1));

        let modified = Divisible::<U1>::set(val, 0, U1::new(0));
        assert_eq!(modified, 0b1010110000110100);

        let parts: Vec<U1> = Divisible::<U1>::ref_iter(&val).collect();
        assert_eq!(parts.len(), 16);
        assert_eq!(parts[0], U1::new(1));
        assert_eq!(parts[15], U1::new(1));
    }

    #[test]
    fn test_divisible_u64_u4() {
        let val: u64 = 0x123456789ABCDEF0;

        assert_eq!(Divisible::<U4>::get(val, 0), U4::new(0x0));
        assert_eq!(Divisible::<U4>::get(val, 1), U4::new(0xF));
        assert_eq!(Divisible::<U4>::get(val, 15), U4::new(0x1));

        let parts: Vec<U4> = Divisible::<U4>::ref_iter(&val).collect();
        assert_eq!(parts.len(), 16);
    }

    #[test]
    fn test_divisible_u32_u8_slice() {
        let vals: [u32; 2] = [0x04030201, 0x08070605];

        // Bytes come out least-significant first regardless of target endianness.
        let parts: Vec<u8> = Divisible::<u8>::slice_iter(&vals).collect();
        assert_eq!(parts.len(), 8);
        assert_eq!(parts[0], 0x01);
        assert_eq!(parts[1], 0x02);
        assert_eq!(parts[2], 0x03);
        assert_eq!(parts[3], 0x04);
        assert_eq!(parts[4], 0x05);
        assert_eq!(parts[5], 0x06);
        assert_eq!(parts[6], 0x07);
        assert_eq!(parts[7], 0x08);
    }

    #[test]
    fn test_divisible_u128_u64_order() {
        let low: u64 = 0x1111_2222_3333_4444;
        let high: u64 = 0xAAAA_BBBB_CCCC_DDDD;
        let val: u128 = ((high as u128) << 64) | low as u128;

        assert_eq!(Divisible::<u64>::get(val, 0), low);
        assert_eq!(Divisible::<u64>::get(val, 1), high);

        let parts: Vec<u64> = Divisible::<u64>::value_iter(val).collect();
        assert_eq!(parts, vec![low, high]);
        let parts: Vec<u64> = Divisible::<u64>::ref_iter(&val).collect();
        assert_eq!(parts, vec![low, high]);

        let modified = Divisible::<u64>::set(val, 1, 0);
        assert_eq!(modified, low as u128);
    }

    #[test]
    fn test_divisible_smallu_in_smallu() {
        let val = U4::new(0b1011);

        assert_eq!(Divisible::<U1>::get(val, 0), U1::new(1));
        assert_eq!(Divisible::<U1>::get(val, 2), U1::new(0));
        assert_eq!(Divisible::<U1>::get(val, 3), U1::new(1));
        assert_eq!(Divisible::<U1>::set(val, 2, U1::new(1)), U4::new(0b1111));

        assert_eq!(Divisible::<U2>::get(val, 0), U2::new(0b11));
        assert_eq!(Divisible::<U2>::get(val, 1), U2::new(0b10));
        assert_eq!(Divisible::<U2>::set(val, 1, U2::new(0b01)), U4::new(0b0111));

        let bits: Vec<U1> = Divisible::<U1>::value_iter(val).collect();
        assert_eq!(bits, vec![U1::new(1), U1::new(1), U1::new(0), U1::new(1)]);

        let pair = U2::new(0b10);
        let bits: Vec<U1> = Divisible::<U1>::ref_iter(&pair).collect();
        assert_eq!(bits, vec![U1::new(0), U1::new(1)]);
    }

    #[test]
    fn test_divisible_part_counts() {
        assert_eq!(<u32 as Divisible<u8>>::N, 4);
        assert_eq!(<u128 as Divisible<u64>>::N, 2);
        assert_eq!(<u128 as Divisible<U1>>::N, 128);
        assert_eq!(<u16 as Divisible<U2>>::N, 8);
        assert_eq!(<U4 as Divisible<U1>>::N, 4);
        assert_eq!(<U4 as Divisible<U2>>::N, 2);
        assert_eq!(<u64 as Divisible<u64>>::N, 1);

        // Iterator lengths agree with `N`.
        assert_eq!(Divisible::<U1>::value_iter(0u128).len(), 128);
        assert_eq!(Divisible::<U4>::value_iter(0u32).len(), 8);
    }

    #[test]
    fn test_smallu_divis_iter_len_tracks_progress() {
        let vals = [0x34u8, 0x56u8];
        let mut iter = Divisible::<U2>::slice_iter(&vals[..]);
        assert_eq!(iter.len(), 8);
        assert!(iter.next().is_some());
        assert_eq!(iter.len(), 7);
        let rest: Vec<U2> = iter.collect();
        assert_eq!(rest.len(), 7);
    }
}