binius_field/underlier/
divisible.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::mem::size_of;
4
/// Splits an underlier into equally sized, smaller underliers and exposes them in a fixed order.
///
/// Every accessor on [`Divisible`] addresses sub-elements starting at the least significant bits
/// and ending at the most significant bits. That ordering is part of the contract and does not
/// depend on the byte order of the target CPU.
///
/// # Endianness Handling
///
/// Implementations compensate for the platform's memory layout:
/// - little-endian targets store the least significant sub-element first, so memory can be walked
///   front to back;
/// - big-endian targets store the most significant sub-element first, so memory is walked back to
///   front.
///
/// Callers can therefore treat position `0` as the least significant portion of the value on
/// every platform.
pub trait Divisible<T>: Copy {
	/// Base-2 logarithm of how many `T` elements make up one `Self`.
	const LOG_N: usize;

	/// How many `T` elements make up one `Self`; always `2^LOG_N`.
	const N: usize = 1 << Self::LOG_N;

	/// Iterates over the `T` elements of an owned value, least significant element first.
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone;

	/// Iterates over the `T` elements of a borrowed value, least significant element first.
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_;

	/// Iterates over the `T` elements of every value in `slice`, least significant element first.
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_;

	/// Reads the element at position `index`, where position `0` is the least significant one.
	///
	/// # Panics
	///
	/// Panics if `index >= Self::N`.
	fn get(self, index: usize) -> T;

	/// Returns a copy of `self` in which the element at position `index` is replaced by `val`.
	/// Position `0` is the least significant one.
	///
	/// # Panics
	///
	/// Panics if `index >= Self::N`.
	fn set(self, index: usize, val: T) -> Self;

	/// Builds a value whose `N` positions all hold `val`.
	fn broadcast(val: T) -> Self;

	/// Builds a value from the elements yielded by `iter`, filling positions in LSB-first order.
	///
	/// No more than `N` elements are taken from the iterator. When it yields fewer than `N`,
	/// the positions that were not supplied are zero.
	fn from_iter(iter: impl Iterator<Item = T>) -> Self;
}
61
62/// Helper functions for Divisible implementations using bytemuck memory casting.
63///
64/// These functions handle the endianness-aware iteration over subdivisions of an underlier type.
65pub mod memcast {
66	use bytemuck::{Pod, Zeroable};
67
68	/// Returns an iterator over subdivisions of a value, ordered from LSB to MSB.
69	#[cfg(target_endian = "little")]
70	#[inline]
71	pub fn value_iter<Big, Small, const N: usize>(
72		value: Big,
73	) -> impl ExactSizeIterator<Item = Small> + Send + Clone
74	where
75		Big: Pod,
76		Small: Pod + Send,
77	{
78		bytemuck::must_cast::<Big, [Small; N]>(value).into_iter()
79	}
80
81	/// Returns an iterator over subdivisions of a value, ordered from LSB to MSB.
82	#[cfg(target_endian = "big")]
83	#[inline]
84	pub fn value_iter<Big, Small, const N: usize>(
85		value: Big,
86	) -> impl ExactSizeIterator<Item = Small> + Send + Clone
87	where
88		Big: Pod,
89		Small: Pod + Send,
90	{
91		bytemuck::must_cast::<Big, [Small; N]>(value)
92			.into_iter()
93			.rev()
94	}
95
96	/// Returns an iterator over subdivisions of a reference, ordered from LSB to MSB.
97	#[cfg(target_endian = "little")]
98	#[inline]
99	pub fn ref_iter<Big, Small, const N: usize>(
100		value: &Big,
101	) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
102	where
103		Big: Pod,
104		Small: Pod + Send + Sync,
105	{
106		bytemuck::must_cast_ref::<Big, [Small; N]>(value)
107			.iter()
108			.copied()
109	}
110
111	/// Returns an iterator over subdivisions of a reference, ordered from LSB to MSB.
112	#[cfg(target_endian = "big")]
113	#[inline]
114	pub fn ref_iter<Big, Small, const N: usize>(
115		value: &Big,
116	) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
117	where
118		Big: Pod,
119		Small: Pod + Send + Sync,
120	{
121		bytemuck::must_cast_ref::<Big, [Small; N]>(value)
122			.iter()
123			.rev()
124			.copied()
125	}
126
127	/// Returns an iterator over subdivisions of a slice, ordered from LSB to MSB.
128	#[cfg(target_endian = "little")]
129	#[inline]
130	pub fn slice_iter<Big, Small>(
131		slice: &[Big],
132	) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
133	where
134		Big: Pod,
135		Small: Pod + Send + Sync,
136	{
137		bytemuck::must_cast_slice::<Big, Small>(slice)
138			.iter()
139			.copied()
140	}
141
142	/// Returns an iterator over subdivisions of a slice, ordered from LSB to MSB.
143	///
144	/// For big-endian: iterate through the raw slice, but for each element's
145	/// subdivisions, reverse the index to maintain LSB-first ordering.
146	#[cfg(target_endian = "big")]
147	#[inline]
148	pub fn slice_iter<Big, Small, const LOG_N: usize>(
149		slice: &[Big],
150	) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
151	where
152		Big: Pod,
153		Small: Pod + Send + Sync,
154	{
155		const N: usize = 1 << LOG_N;
156		let raw_slice = bytemuck::must_cast_slice::<Big, Small>(slice);
157		(0..raw_slice.len()).map(move |i| {
158			let element_idx = i >> LOG_N;
159			let sub_idx = i & (N - 1);
160			let reversed_sub_idx = N - 1 - sub_idx;
161			let raw_idx = element_idx * N + reversed_sub_idx;
162			raw_slice[raw_idx]
163		})
164	}
165
166	/// Get element at index (LSB-first ordering).
167	#[cfg(target_endian = "little")]
168	#[inline]
169	pub fn get<Big, Small, const N: usize>(value: &Big, index: usize) -> Small
170	where
171		Big: Pod,
172		Small: Pod,
173	{
174		bytemuck::must_cast_ref::<Big, [Small; N]>(value)[index]
175	}
176
177	/// Get element at index (LSB-first ordering).
178	#[cfg(target_endian = "big")]
179	#[inline]
180	pub fn get<Big, Small, const N: usize>(value: &Big, index: usize) -> Small
181	where
182		Big: Pod,
183		Small: Pod,
184	{
185		bytemuck::must_cast_ref::<Big, [Small; N]>(value)[N - 1 - index]
186	}
187
188	/// Set element at index (LSB-first ordering), returning modified value.
189	#[cfg(target_endian = "little")]
190	#[inline]
191	pub fn set<Big, Small, const N: usize>(value: &Big, index: usize, val: Small) -> Big
192	where
193		Big: Pod,
194		Small: Pod,
195	{
196		let mut arr = *bytemuck::must_cast_ref::<Big, [Small; N]>(value);
197		arr[index] = val;
198		bytemuck::must_cast(arr)
199	}
200
201	/// Set element at index (LSB-first ordering), returning modified value.
202	#[cfg(target_endian = "big")]
203	#[inline]
204	pub fn set<Big, Small, const N: usize>(value: &Big, index: usize, val: Small) -> Big
205	where
206		Big: Pod,
207		Small: Pod,
208	{
209		let mut arr = *bytemuck::must_cast_ref::<Big, [Small; N]>(value);
210		arr[N - 1 - index] = val;
211		bytemuck::must_cast(arr)
212	}
213
214	/// Broadcast a value to all positions.
215	#[inline]
216	pub fn broadcast<Big, Small, const N: usize>(val: Small) -> Big
217	where
218		Big: Pod,
219		Small: Pod + Copy,
220	{
221		bytemuck::must_cast::<[Small; N], Big>([val; N])
222	}
223
224	/// Construct a value from an iterator of elements.
225	#[cfg(target_endian = "little")]
226	#[inline]
227	pub fn from_iter<Big, Small, const N: usize>(iter: impl Iterator<Item = Small>) -> Big
228	where
229		Big: Pod,
230		Small: Pod,
231	{
232		let mut arr: [Small; N] = Zeroable::zeroed();
233		for (i, val) in iter.take(N).enumerate() {
234			arr[i] = val;
235		}
236		bytemuck::must_cast(arr)
237	}
238
239	/// Construct a value from an iterator of elements.
240	#[cfg(target_endian = "big")]
241	#[inline]
242	pub fn from_iter<Big, Small, const N: usize>(iter: impl Iterator<Item = Small>) -> Big
243	where
244		Big: Pod,
245		Small: Pod,
246	{
247		let mut arr: [Small; N] = Zeroable::zeroed();
248		for (i, val) in iter.take(N).enumerate() {
249			arr[N - 1 - i] = val;
250		}
251		bytemuck::must_cast(arr)
252	}
253}
254
255/// Helper functions for Divisible implementations using bitmask operations on sub-byte elements.
256///
257/// These functions work on any type that implements `Divisible<u8>` by extracting
258/// and modifying sub-byte elements through the byte interface.
259pub mod bitmask {
260	use super::{Divisible, SmallU};
261
262	/// Get a sub-byte element at index (LSB-first ordering).
263	#[inline]
264	pub fn get<Big, const BITS: usize>(value: Big, index: usize) -> SmallU<BITS>
265	where
266		Big: Divisible<u8>,
267	{
268		let elems_per_byte = 8 / BITS;
269		let byte_index = index / elems_per_byte;
270		let sub_index = index % elems_per_byte;
271		let byte = Divisible::<u8>::get(value, byte_index);
272		let shift = sub_index * BITS;
273		SmallU::<BITS>::new(byte >> shift)
274	}
275
276	/// Set a sub-byte element at index (LSB-first ordering), returning modified value.
277	#[inline]
278	pub fn set<Big, const BITS: usize>(value: Big, index: usize, val: SmallU<BITS>) -> Big
279	where
280		Big: Divisible<u8>,
281	{
282		let elems_per_byte = 8 / BITS;
283		let byte_index = index / elems_per_byte;
284		let sub_index = index % elems_per_byte;
285		let byte = Divisible::<u8>::get(value, byte_index);
286		let shift = sub_index * BITS;
287		let mask = (1u8 << BITS) - 1;
288		let new_byte = (byte & !(mask << shift)) | (val.val() << shift);
289		Divisible::<u8>::set(value, byte_index, new_byte)
290	}
291}
292
293/// Helper functions for Divisible implementations using the get method.
294///
295/// These functions create iterators by mapping indices through `Divisible::get`,
296/// useful for SIMD types where extract intrinsics provide efficient element access.
297pub mod mapget {
298	use binius_utils::iter::IterExtensions;
299
300	use super::Divisible;
301
302	/// Create an iterator over subdivisions by mapping get over indices.
303	#[inline]
304	pub fn value_iter<Big, Small>(value: Big) -> impl ExactSizeIterator<Item = Small> + Send + Clone
305	where
306		Big: Divisible<Small> + Send,
307		Small: Send,
308	{
309		(0..Big::N).map_skippable(move |i| Divisible::<Small>::get(value, i))
310	}
311
312	/// Create a slice iterator by computing global index and using get.
313	#[inline]
314	pub fn slice_iter<Big, Small>(
315		slice: &[Big],
316	) -> impl ExactSizeIterator<Item = Small> + Send + Clone + '_
317	where
318		Big: Divisible<Small> + Send + Sync,
319		Small: Send,
320	{
321		let total = slice.len() * Big::N;
322		(0..total).map_skippable(move |global_idx| {
323			let elem_idx = global_idx / Big::N;
324			let sub_idx = global_idx % Big::N;
325			Divisible::<Small>::get(slice[elem_idx], sub_idx)
326		})
327	}
328}
329
/// Iterator adapter that splits a stream of bytes into sub-byte [`SmallU`] elements.
///
/// `I` is the wrapped byte iterator and `N` is the width in bits of each produced element.
/// Within a byte, elements are produced starting from the least significant bits.
#[derive(Clone)]
pub struct SmallUDivisIter<I, const N: usize> {
	/// Source of the bytes that have not been loaded yet.
	byte_iter: I,
	/// Byte currently being split, or `None` once the source has run dry.
	current_byte: Option<u8>,
	/// Position, inside `current_byte`, of the element that will be produced next.
	sub_idx: usize,
}

impl<I: Iterator<Item = u8>, const N: usize> SmallUDivisIter<I, N> {
	/// Number of `N`-bit elements packed into one byte.
	const ELEMS_PER_BYTE: usize = 8 / N;

	/// Wraps `byte_iter`, eagerly pulling its first byte so that iteration can start right away.
	pub fn new(mut byte_iter: I) -> Self {
		let first = byte_iter.next();
		Self {
			current_byte: first,
			sub_idx: 0,
			byte_iter,
		}
	}
}
353
354impl<I: ExactSizeIterator<Item = u8>, const N: usize> Iterator for SmallUDivisIter<I, N> {
355	type Item = SmallU<N>;
356
357	#[inline]
358	fn next(&mut self) -> Option<Self::Item> {
359		let byte = self.current_byte?;
360		let shift = self.sub_idx * N;
361		let result = SmallU::<N>::new(byte >> shift);
362
363		self.sub_idx += 1;
364		if self.sub_idx >= Self::ELEMS_PER_BYTE {
365			self.sub_idx = 0;
366			self.current_byte = self.byte_iter.next();
367		}
368
369		Some(result)
370	}
371
372	#[inline]
373	fn size_hint(&self) -> (usize, Option<usize>) {
374		let remaining_in_current = if self.current_byte.is_some() {
375			Self::ELEMS_PER_BYTE - self.sub_idx
376		} else {
377			0
378		};
379		let remaining_bytes = self.byte_iter.len();
380		let total = remaining_in_current + remaining_bytes * Self::ELEMS_PER_BYTE;
381		(total, Some(total))
382	}
383}
384
385impl<I: ExactSizeIterator<Item = u8>, const N: usize> ExactSizeIterator for SmallUDivisIter<I, N> {}
386
/// Generates memory-cast based `Divisible` implementations.
///
/// `impl_divisible_memcast!(Big, S1, S2, ...)` emits one `impl Divisible<Si> for Big` per listed
/// small type. Every method forwards to the helper of the same name in the `memcast` module,
/// passing along the element count (or its log2) computed from the sizes of the two types.
macro_rules! impl_divisible_memcast {
	($big:ty, $($small:ty),+) => {
		$(
			impl $crate::underlier::Divisible<$small> for $big {
				const LOG_N: usize = {
					let lanes = size_of::<$big>() / size_of::<$small>();
					lanes.ilog2() as usize
				};

				#[inline]
				fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $small> + Send + Clone {
					const LANES: usize = size_of::<$big>() / size_of::<$small>();
					$crate::underlier::memcast::value_iter::<$big, $small, LANES>(value)
				}

				#[inline]
				fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $small> + Send + Clone + '_ {
					const LANES: usize = size_of::<$big>() / size_of::<$small>();
					$crate::underlier::memcast::ref_iter::<$big, $small, LANES>(value)
				}

				// The little-endian helper needs no element count: the flattened slice is
				// already in LSB-first order.
				#[inline]
				#[cfg(target_endian = "little")]
				fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $small> + Send + Clone + '_ {
					$crate::underlier::memcast::slice_iter::<$big, $small>(slice)
				}

				// The big-endian helper mirrors positions inside each element and therefore
				// needs the log2 of the element count.
				#[inline]
				#[cfg(target_endian = "big")]
				fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $small> + Send + Clone + '_ {
					const LOG_LANES: usize = (size_of::<$big>() / size_of::<$small>()).ilog2() as usize;
					$crate::underlier::memcast::slice_iter::<$big, $small, LOG_LANES>(slice)
				}

				#[inline]
				fn get(self, index: usize) -> $small {
					const LANES: usize = size_of::<$big>() / size_of::<$small>();
					$crate::underlier::memcast::get::<$big, $small, LANES>(&self, index)
				}

				#[inline]
				fn set(self, index: usize, val: $small) -> Self {
					const LANES: usize = size_of::<$big>() / size_of::<$small>();
					$crate::underlier::memcast::set::<$big, $small, LANES>(&self, index, val)
				}

				#[inline]
				fn broadcast(val: $small) -> Self {
					const LANES: usize = size_of::<$big>() / size_of::<$small>();
					$crate::underlier::memcast::broadcast::<$big, $small, LANES>(val)
				}

				#[inline]
				fn from_iter(iter: impl Iterator<Item = $small>) -> Self {
					const LANES: usize = size_of::<$big>() / size_of::<$small>();
					$crate::underlier::memcast::from_iter::<$big, $small, LANES>(iter)
				}
			}
		)+
	};
}
449
450#[allow(unused)]
451pub(crate) use impl_divisible_memcast;
452
/// Generates `Divisible<SmallU<BITS>>` implementations based on shifting and masking.
///
/// `impl_divisible_bitmask!(Big, b1, b2, ...)` emits one implementation per listed bit width.
/// The `u8` arm works on the byte directly; the general arm layers sub-byte access on top of the
/// type's existing `Divisible<u8>` implementation.
macro_rules! impl_divisible_bitmask {
	// `u8` arm: the value is the byte itself, so no `Divisible::<u8>` indirection is needed.
	(u8, $($bits:expr),+) => {
		$(
			impl $crate::underlier::Divisible<$crate::underlier::SmallU<$bits>> for u8 {
				const LOG_N: usize = (8usize / $bits).ilog2() as usize;

				#[inline]
				fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone {
					let single = std::iter::once(value);
					$crate::underlier::SmallUDivisIter::new(single)
				}

				#[inline]
				fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
					let single = std::iter::once(*value);
					$crate::underlier::SmallUDivisIter::new(single)
				}

				#[inline]
				fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
					let bytes = slice.iter().copied();
					$crate::underlier::SmallUDivisIter::new(bytes)
				}

				#[inline]
				fn get(self, index: usize) -> $crate::underlier::SmallU<$bits> {
					$crate::underlier::SmallU::<$bits>::new(self >> (index * $bits))
				}

				#[inline]
				fn set(self, index: usize, val: $crate::underlier::SmallU<$bits>) -> Self {
					let shift = index * $bits;
					// Blank out the target lane, then merge in the new element.
					let lane = ((1u8 << $bits) - 1) << shift;
					let cleared = self & !lane;
					cleared | (val.val() << shift)
				}

				#[inline]
				fn broadcast(val: $crate::underlier::SmallU<$bits>) -> Self {
					// Keep doubling the populated width until the pattern spans the whole byte.
					let mut pattern = val.val();
					let mut width = $bits;
					while width < 8 {
						pattern |= pattern << width;
						width *= 2;
					}
					pattern
				}

				#[inline]
				fn from_iter(iter: impl Iterator<Item = $crate::underlier::SmallU<$bits>>) -> Self {
					const N: usize = 8 / $bits;
					iter.take(N).enumerate().fold(0, |acc: Self, (i, val)| {
						$crate::underlier::Divisible::<$crate::underlier::SmallU<$bits>>::set(acc, i, val)
					})
				}
			}
		)+
	};

	// General arm for types wider than a byte: sub-byte access goes through byte iteration.
	($big:ty, $($bits:expr),+) => {
		$(
			impl $crate::underlier::Divisible<$crate::underlier::SmallU<$bits>> for $big {
				const LOG_N: usize = (8 * size_of::<$big>() / $bits).ilog2() as usize;

				#[inline]
				fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone {
					let bytes = $crate::underlier::Divisible::<u8>::value_iter(value);
					$crate::underlier::SmallUDivisIter::new(bytes)
				}

				#[inline]
				fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
					let bytes = $crate::underlier::Divisible::<u8>::ref_iter(value);
					$crate::underlier::SmallUDivisIter::new(bytes)
				}

				#[inline]
				fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $crate::underlier::SmallU<$bits>> + Send + Clone + '_ {
					let bytes = $crate::underlier::Divisible::<u8>::slice_iter(slice);
					$crate::underlier::SmallUDivisIter::new(bytes)
				}

				#[inline]
				fn get(self, index: usize) -> $crate::underlier::SmallU<$bits> {
					$crate::underlier::bitmask::get::<Self, $bits>(self, index)
				}

				#[inline]
				fn set(self, index: usize, val: $crate::underlier::SmallU<$bits>) -> Self {
					$crate::underlier::bitmask::set::<Self, $bits>(self, index, val)
				}

				#[inline]
				fn broadcast(val: $crate::underlier::SmallU<$bits>) -> Self {
					// Two steps: fill one byte with the element, then fill `Self` with that byte.
					let byte: u8 =
						<u8 as $crate::underlier::Divisible<$crate::underlier::SmallU<$bits>>>::broadcast(val);
					<Self as $crate::underlier::Divisible<u8>>::broadcast(byte)
				}

				#[inline]
				fn from_iter(iter: impl Iterator<Item = $crate::underlier::SmallU<$bits>>) -> Self {
					const N: usize = 8 * size_of::<$big>() / $bits;
					let zero: Self = bytemuck::Zeroable::zeroed();
					iter.take(N).enumerate().fold(zero, |acc, (i, val)| {
						$crate::underlier::Divisible::<$crate::underlier::SmallU<$bits>>::set(acc, i, val)
					})
				}
			}
		)+
	};
}
574
575#[allow(unused)]
576pub(crate) use impl_divisible_bitmask;
577
578use super::small_uint::SmallU;
579
// Implement Divisible using memcast for primitive types.
// Each line covers one "big" primitive and lists every strictly narrower primitive it divides
// into; the reflexive case (a type divided into itself) is generated by `impl_divisible_self!`.
impl_divisible_memcast!(u128, u64, u32, u16, u8);
impl_divisible_memcast!(u64, u32, u16, u8);
impl_divisible_memcast!(u32, u16, u8);
impl_divisible_memcast!(u16, u8);

// Implement Divisible using bitmask for SmallU types.
// The listed numbers are the sub-byte element widths in bits (1, 2 and 4). The `u8` line hits
// the macro's dedicated `u8` arm; the wider types go through its general arm.
impl_divisible_bitmask!(u8, 1, 2, 4);
impl_divisible_bitmask!(u16, 1, 2, 4);
impl_divisible_bitmask!(u32, 1, 2, 4);
impl_divisible_bitmask!(u64, 1, 2, 4);
impl_divisible_bitmask!(u128, 1, 2, 4);
592
593// Divisible for SmallU types that subdivide into smaller SmallU types
594impl Divisible<SmallU<1>> for SmallU<2> {
595	const LOG_N: usize = 1;
596
597	#[inline]
598	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone {
599		mapget::value_iter(value)
600	}
601
602	#[inline]
603	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
604		mapget::value_iter(*value)
605	}
606
607	#[inline]
608	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
609		mapget::slice_iter(slice)
610	}
611
612	#[inline]
613	fn get(self, index: usize) -> SmallU<1> {
614		SmallU::<1>::new(self.val() >> index)
615	}
616
617	#[inline]
618	fn set(self, index: usize, val: SmallU<1>) -> Self {
619		let mask = 1u8 << index;
620		SmallU::<2>::new((self.val() & !mask) | (val.val() << index))
621	}
622
623	#[inline]
624	fn broadcast(val: SmallU<1>) -> Self {
625		// 0b0 -> 0b00, 0b1 -> 0b11
626		let v = val.val();
627		SmallU::<2>::new(v | (v << 1))
628	}
629
630	#[inline]
631	fn from_iter(iter: impl Iterator<Item = SmallU<1>>) -> Self {
632		iter.chain(std::iter::repeat(SmallU::<1>::new(0)))
633			.take(2)
634			.enumerate()
635			.fold(SmallU::<2>::new(0), |acc, (i, val)| acc.set(i, val))
636	}
637}
638
639impl Divisible<SmallU<1>> for SmallU<4> {
640	const LOG_N: usize = 2;
641
642	#[inline]
643	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone {
644		mapget::value_iter(value)
645	}
646
647	#[inline]
648	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
649		mapget::value_iter(*value)
650	}
651
652	#[inline]
653	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = SmallU<1>> + Send + Clone + '_ {
654		mapget::slice_iter(slice)
655	}
656
657	#[inline]
658	fn get(self, index: usize) -> SmallU<1> {
659		SmallU::<1>::new(self.val() >> index)
660	}
661
662	#[inline]
663	fn set(self, index: usize, val: SmallU<1>) -> Self {
664		let mask = 1u8 << index;
665		SmallU::<4>::new((self.val() & !mask) | (val.val() << index))
666	}
667
668	#[inline]
669	fn broadcast(val: SmallU<1>) -> Self {
670		// 0b0 -> 0b0000, 0b1 -> 0b1111
671		let mut v = val.val();
672		v |= v << 1;
673		v |= v << 2;
674		SmallU::<4>::new(v)
675	}
676
677	#[inline]
678	fn from_iter(iter: impl Iterator<Item = SmallU<1>>) -> Self {
679		iter.chain(std::iter::repeat(SmallU::<1>::new(0)))
680			.take(4)
681			.enumerate()
682			.fold(SmallU::<4>::new(0), |acc, (i, val)| acc.set(i, val))
683	}
684}
685
686impl Divisible<SmallU<2>> for SmallU<4> {
687	const LOG_N: usize = 1;
688
689	#[inline]
690	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = SmallU<2>> + Send + Clone {
691		mapget::value_iter(value)
692	}
693
694	#[inline]
695	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = SmallU<2>> + Send + Clone + '_ {
696		mapget::value_iter(*value)
697	}
698
699	#[inline]
700	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = SmallU<2>> + Send + Clone + '_ {
701		mapget::slice_iter(slice)
702	}
703
704	#[inline]
705	fn get(self, index: usize) -> SmallU<2> {
706		SmallU::<2>::new(self.val() >> (index * 2))
707	}
708
709	#[inline]
710	fn set(self, index: usize, val: SmallU<2>) -> Self {
711		let shift = index * 2;
712		let mask = 0b11u8 << shift;
713		SmallU::<4>::new((self.val() & !mask) | (val.val() << shift))
714	}
715
716	#[inline]
717	fn broadcast(val: SmallU<2>) -> Self {
718		// 0bXX -> 0bXXXX
719		let v = val.val();
720		SmallU::<4>::new(v | (v << 2))
721	}
722
723	#[inline]
724	fn from_iter(iter: impl Iterator<Item = SmallU<2>>) -> Self {
725		iter.chain(std::iter::repeat(SmallU::<2>::new(0)))
726			.take(2)
727			.enumerate()
728			.fold(SmallU::<4>::new(0), |acc, (i, val)| acc.set(i, val))
729	}
730}
731
732/// Implements reflexive `Divisible<Self>` for a type (dividing into itself once).
733macro_rules! impl_divisible_self {
734	($($ty:ty),+) => {
735		$(
736			impl Divisible<$ty> for $ty {
737				const LOG_N: usize = 0;
738
739				#[inline]
740				fn value_iter(value: Self) -> impl ExactSizeIterator<Item = $ty> + Send + Clone {
741					std::iter::once(value)
742				}
743
744				#[inline]
745				fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = $ty> + Send + Clone + '_ {
746					std::iter::once(*value)
747				}
748
749				#[inline]
750				fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = $ty> + Send + Clone + '_ {
751					slice.iter().copied()
752				}
753
754				#[inline]
755				fn get(self, index: usize) -> $ty {
756					debug_assert_eq!(index, 0);
757					self
758				}
759
760				#[inline]
761				fn set(self, index: usize, val: $ty) -> Self {
762					debug_assert_eq!(index, 0);
763					val
764				}
765
766				#[inline]
767				fn broadcast(val: $ty) -> Self {
768					val
769				}
770
771				#[inline]
772				fn from_iter(mut iter: impl Iterator<Item = $ty>) -> Self {
773					iter.next().unwrap_or_else(bytemuck::Zeroable::zeroed)
774				}
775			}
776		)+
777	};
778}
779
780impl_divisible_self!(u8, u16, u32, u64, u128, SmallU<1>, SmallU<2>, SmallU<4>);
781
#[cfg(test)]
mod tests {
	use super::*;
	use crate::underlier::small_uint::{U1, U2, U4};

	#[test]
	fn test_divisible_u8_u4() {
		let byte: u8 = 0x34;
		let (lo, hi) = (U4::new(0x4), U4::new(0x3));

		// get: nibbles come out least significant first
		assert_eq!(Divisible::<U4>::get(byte, 0), lo);
		assert_eq!(Divisible::<U4>::get(byte, 1), hi);

		// set
		assert_eq!(Divisible::<U4>::set(byte, 0, U4::new(0xF)), 0x3F);
		assert_eq!(Divisible::<U4>::set(byte, 1, U4::new(0xA)), 0xA4);

		// ref_iter
		let from_ref: Vec<U4> = Divisible::<U4>::ref_iter(&byte).collect();
		assert_eq!(from_ref, vec![lo, hi]);

		// value_iter
		let from_value: Vec<U4> = Divisible::<U4>::value_iter(byte).collect();
		assert_eq!(from_value, vec![lo, hi]);

		// slice_iter
		let bytes = [0x34u8, 0x56u8];
		let from_slice: Vec<U4> = Divisible::<U4>::slice_iter(&bytes).collect();
		assert_eq!(from_slice, [0x4, 0x3, 0x6, 0x5].map(U4::new));
	}

	#[test]
	fn test_divisible_u16_u4() {
		let word: u16 = 0x1234;

		// get: nibbles come out least significant first
		for (index, expected) in [0x4, 0x3, 0x2, 0x1].into_iter().enumerate() {
			assert_eq!(Divisible::<U4>::get(word, index), U4::new(expected));
		}

		// set
		assert_eq!(Divisible::<U4>::set(word, 1, U4::new(0xF)), 0x12F4);

		// ref_iter
		let nibbles: Vec<U4> = Divisible::<U4>::ref_iter(&word).collect();
		assert_eq!(nibbles.len(), 4);
		assert_eq!(nibbles.first(), Some(&U4::new(0x4)));
		assert_eq!(nibbles.last(), Some(&U4::new(0x1)));
	}

	#[test]
	fn test_divisible_u16_u2() {
		// 0b1011_0010_1101_0011 = 0xB2D3
		let word: u16 = 0b1011001011010011;

		// get: 2-bit chunks come out least significant first
		assert_eq!(Divisible::<U2>::get(word, 0), U2::new(0b11)); // bits 0-1
		assert_eq!(Divisible::<U2>::get(word, 1), U2::new(0b00)); // bits 2-3
		assert_eq!(Divisible::<U2>::get(word, 7), U2::new(0b10)); // bits 14-15

		// ref_iter
		let chunks: Vec<U2> = Divisible::<U2>::ref_iter(&word).collect();
		assert_eq!(chunks.len(), 8);
		assert_eq!(chunks.first(), Some(&U2::new(0b11)));
		assert_eq!(chunks.last(), Some(&U2::new(0b10)));
	}

	#[test]
	fn test_divisible_u16_u1() {
		// 0b1010_1100_0011_0101 = 0xAC35
		let word: u16 = 0b1010110000110101;

		// get: individual bits come out least significant first
		assert_eq!(Divisible::<U1>::get(word, 0), U1::new(1)); // bit 0
		assert_eq!(Divisible::<U1>::get(word, 1), U1::new(0)); // bit 1
		assert_eq!(Divisible::<U1>::get(word, 15), U1::new(1)); // bit 15

		// set
		assert_eq!(Divisible::<U1>::set(word, 0, U1::new(0)), 0b1010110000110100);

		// ref_iter
		let bits: Vec<U1> = Divisible::<U1>::ref_iter(&word).collect();
		assert_eq!(bits.len(), 16);
		assert_eq!(bits.first(), Some(&U1::new(1)));
		assert_eq!(bits.last(), Some(&U1::new(1)));
	}

	#[test]
	fn test_divisible_u64_u4() {
		let word: u64 = 0x123456789ABCDEF0;

		// get: nibbles come out least significant first
		assert_eq!(Divisible::<U4>::get(word, 0), U4::new(0x0));
		assert_eq!(Divisible::<U4>::get(word, 1), U4::new(0xF));
		assert_eq!(Divisible::<U4>::get(word, 15), U4::new(0x1));

		// ref_iter
		assert_eq!(Divisible::<U4>::ref_iter(&word).count(), 16);
	}

	#[test]
	fn test_divisible_u32_u8_slice() {
		let words: [u32; 2] = [0x04030201, 0x08070605];

		// slice_iter: LSB-first ordering within each u32
		let bytes: Vec<u8> = Divisible::<u8>::slice_iter(&words).collect();
		assert_eq!(bytes, [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
	}

	#[test]
	fn test_broadcast_u32_u8() {
		let filled = <u32 as Divisible<u8>>::broadcast(0xAB);
		assert_eq!(filled, 0xABABABAB);
	}

	#[test]
	fn test_broadcast_u64_u16() {
		let filled = <u64 as Divisible<u16>>::broadcast(0x1234);
		assert_eq!(filled, 0x1234123412341234);
	}

	#[test]
	fn test_broadcast_u128_u32() {
		let filled = <u128 as Divisible<u32>>::broadcast(0xDEADBEEF);
		assert_eq!(filled, 0xDEADBEEFDEADBEEFDEADBEEFDEADBEEF);
	}

	#[test]
	fn test_broadcast_u8_u4() {
		let filled = <u8 as Divisible<U4>>::broadcast(U4::new(0x5));
		assert_eq!(filled, 0x55);
	}

	#[test]
	fn test_broadcast_u16_u4() {
		let filled = <u16 as Divisible<U4>>::broadcast(U4::new(0xA));
		assert_eq!(filled, 0xAAAA);
	}

	#[test]
	fn test_broadcast_u8_u2() {
		for (input, expected) in [(0b11u8, 0xFFu8), (0b01, 0x55)] {
			assert_eq!(<u8 as Divisible<U2>>::broadcast(U2::new(input)), expected);
		}
	}

	#[test]
	fn test_broadcast_u8_u1() {
		for (input, expected) in [(0u8, 0x00u8), (1, 0xFF)] {
			assert_eq!(<u8 as Divisible<U1>>::broadcast(U1::new(input)), expected);
		}
	}

	#[test]
	fn test_broadcast_smallu2_from_smallu1() {
		for (input, expected) in [(0u8, 0b00u8), (1, 0b11)] {
			let filled = <SmallU<2> as Divisible<SmallU<1>>>::broadcast(SmallU::<1>::new(input));
			assert_eq!(filled.val(), expected);
		}
	}

	#[test]
	fn test_broadcast_smallu4_from_smallu1() {
		for (input, expected) in [(0u8, 0b0000u8), (1, 0b1111)] {
			let filled = <SmallU<4> as Divisible<SmallU<1>>>::broadcast(SmallU::<1>::new(input));
			assert_eq!(filled.val(), expected);
		}
	}

	#[test]
	fn test_broadcast_smallu4_from_smallu2() {
		let filled = <SmallU<4> as Divisible<SmallU<2>>>::broadcast(SmallU::<2>::new(0b10));
		assert_eq!(filled.val(), 0b1010);
	}

	#[test]
	fn test_broadcast_reflexive() {
		let same = <u64 as Divisible<u64>>::broadcast(0x123456789ABCDEF0);
		assert_eq!(same, 0x123456789ABCDEF0);
	}

	#[test]
	fn test_from_iter_full() {
		let packed = <u32 as Divisible<u8>>::from_iter([0x01, 0x02, 0x03, 0x04].into_iter());
		assert_eq!(packed, 0x04030201);
	}

	#[test]
	fn test_from_iter_partial() {
		// Two elements supplied; the remaining positions must be zero.
		let packed = <u32 as Divisible<u8>>::from_iter([0xAB, 0xCD].into_iter());
		assert_eq!(packed, 0x0000CDAB);
	}

	#[test]
	fn test_from_iter_empty() {
		let packed = <u32 as Divisible<u8>>::from_iter(std::iter::empty());
		assert_eq!(packed, 0);
	}

	#[test]
	fn test_from_iter_excess() {
		// Six elements supplied; only the first four may be used.
		let source = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
		let packed = <u32 as Divisible<u8>>::from_iter(source.into_iter());
		assert_eq!(packed, 0x04030201);
	}

	#[test]
	fn test_from_iter_u64_u16() {
		// Three elements supplied; the fourth position must be zero.
		let packed = <u64 as Divisible<u16>>::from_iter([0x1234, 0x5678, 0x9ABC].into_iter());
		assert_eq!(packed, 0x0000_9ABC_5678_1234);
	}

	#[test]
	fn test_from_iter_smallu() {
		let nibbles = [U4::new(0xA), U4::new(0xB)];
		let packed = <u8 as Divisible<U4>>::from_iter(nibbles.into_iter());
		assert_eq!(packed, 0xBA);
	}
}