binius_field/underlier/
underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8	U1, U2, U4,
9	underlier_type::{NumCast, UnderlierType},
10};
11use crate::Divisible;
12
13/// Underlier type that supports bit arithmetic.
14pub trait UnderlierWithBitOps:
15	UnderlierType
16	+ BitAnd<Self, Output = Self>
17	+ BitAndAssign<Self>
18	+ BitOr<Self, Output = Self>
19	+ BitOrAssign<Self>
20	+ BitXor<Self, Output = Self>
21	+ BitXorAssign<Self>
22	+ Shr<usize, Output = Self>
23	+ Shl<usize, Output = Self>
24	+ Not<Output = Self>
25{
26	const ZERO: Self;
27	const ONE: Self;
28	const ONES: Self;
29
30	/// Fill value with the given bit
31	/// `val` must be 0 or 1.
32	fn fill_with_bit(val: u8) -> Self;
33
34	#[inline]
35	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
36	where
37		T: UnderlierType,
38		Self: From<T>,
39	{
40		// This implementation is optimal for the case when `Self` us u8..u128.
41		// For SIMD types/arrays specialization would be more performant.
42		let mut result = Self::default();
43		let width = checked_int_div(Self::BITS, T::BITS);
44		for i in 0..width {
45			result |= Self::from(f(i)) << (i * T::BITS);
46		}
47
48		result
49	}
50
51	/// Broadcast subvalue to fill `Self`.
52	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
53	#[inline]
54	fn broadcast_subvalue<T>(value: T) -> Self
55	where
56		T: UnderlierType,
57		Self: From<T>,
58	{
59		// This implementation is optimal for the case when `Self` us u8..u128.
60		// For SIMD types/arrays specialization would be more performant.
61		let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
62		let mut result = Self::from(value);
63		for i in 0..height {
64			result |= result << ((1 << i) * T::BITS);
65		}
66
67		result
68	}
69
70	/// Gets the subvalue from the given position.
71	/// Function panics in case when index is out of range.
72	///
73	/// # Safety
74	/// `i` must be less than `Self::BITS/T::BITS`.
75	#[inline]
76	unsafe fn get_subvalue<T>(&self, i: usize) -> T
77	where
78		T: UnderlierType,
79		Self: Divisible<T>,
80	{
81		debug_assert!(
82			i < checked_int_div(Self::BITS, T::BITS),
83			"i: {} Self::BITS: {}, T::BITS: {}",
84			i,
85			Self::BITS,
86			T::BITS
87		);
88		Divisible::<T>::get(*self, i)
89	}
90
91	/// Sets the subvalue in the given position.
92	/// Function panics in case when index is out of range.
93	///
94	/// # Safety
95	/// `i` must be less than `Self::BITS/T::BITS`.
96	#[inline]
97	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
98	where
99		T: UnderlierWithBitOps,
100		Self: Divisible<T>,
101	{
102		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
103		*self = (*self).set(i, val);
104	}
105
106	/// Spread takes a block of sub_elements of `T` type within the current value and
107	/// repeats them to the full underlier width.
108	///
109	/// # Safety
110	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
111	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
112	#[inline]
113	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
114	where
115		T: UnderlierWithBitOps + NumCast<Self>,
116		Self: Divisible<T> + From<T>,
117	{
118		unsafe { spread_fallback(self, log_block_len, block_idx) }
119	}
120}
121
122/// Fallback implementation of `spread` method.
123///
124/// # Safety
125/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
126/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
127pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
128where
129	U: UnderlierWithBitOps + From<T> + Divisible<T>,
130	T: UnderlierWithBitOps + NumCast<U>,
131{
132	debug_assert!(
133		log_block_len + T::LOG_BITS <= U::LOG_BITS,
134		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
135		log_block_len,
136		U::BITS,
137		T::BITS
138	);
139	debug_assert!(
140		block_idx < 1 << (U::LOG_BITS - log_block_len),
141		"block_idx: {}, U::BITS: {}, log_block_len: {}",
142		block_idx,
143		U::BITS,
144		log_block_len
145	);
146
147	let mut result = U::ZERO;
148	let block_offset = block_idx << log_block_len;
149	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
150	for i in 0..1 << log_block_len {
151		unsafe {
152			result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
153		}
154	}
155
156	for i in 0..log_repeat {
157		result |= result << (1 << (T::LOG_BITS + i));
158	}
159
160	result
161}
162
/// Test helper returning a mask whose lowest `bits_count` bits are set.
///
/// `bits_count == T::BITS` yields `!T::ZERO`. Any other `bits_count` is passed through
/// `checked_log_2`, so it is expected to be a power of two.
#[cfg(test)]
#[allow(unused)]
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
	if bits_count == T::BITS {
		return !T::ZERO;
	}

	// Start with a single set bit and double the run of ones on every step.
	let mut mask = T::ONE;
	let mut run_len = 1usize;
	for _ in 0..checked_log_2(bits_count) {
		mask |= mask << run_len;
		run_len <<= 1;
	}

	mask
}
177
/// Value that can be spread to a single `u8`.
///
/// Among the implementations in this file, `U2` and `U4` repeat their bit pattern across the
/// byte, while `U1` delegates to `u8::fill_with_bit`.
pub(crate) trait SpreadToByte {
	/// Consumes the value and returns its byte-wide expansion.
	fn spread_to_byte(self) -> u8;
}
182
183impl SpreadToByte for U1 {
184	#[inline(always)]
185	fn spread_to_byte(self) -> u8 {
186		u8::fill_with_bit(self.val())
187	}
188}
189
190impl SpreadToByte for U2 {
191	#[inline(always)]
192	fn spread_to_byte(self) -> u8 {
193		let mut result = self.val();
194		result |= result << 2;
195		result |= result << 4;
196
197		result
198	}
199}
200
201impl SpreadToByte for U4 {
202	#[inline(always)]
203	fn spread_to_byte(self) -> u8 {
204		let mut result = self.val();
205		result |= result << 4;
206
207		result
208	}
209}
210
211/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
212///
213/// # Safety
214/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
215#[allow(unused)]
216#[inline(always)]
217pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
218	value: U,
219	block_idx: usize,
220) -> [T; BLOCK_LEN]
221where
222	U: UnderlierWithBitOps + From<T> + Divisible<T>,
223	T: UnderlierType + NumCast<U>,
224{
225	std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
226}
227
228/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
229///
230/// # Safety
231/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
232#[allow(unused)]
233#[inline(always)]
234pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
235	value: U,
236	block_idx: usize,
237) -> [u8; BLOCK_LEN]
238where
239	U: UnderlierWithBitOps + From<T> + Divisible<T>,
240	T: UnderlierType + SpreadToByte + NumCast<U>,
241{
242	unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
243		.map(SpreadToByte::spread_to_byte)
244}
245
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	/// `from_fn` places the sub-element for index `i` at bit offset `i * T::BITS`.
	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		// NOTE(review): 31 does not fit into 4 bits; this presumably relies on `U4::new`
		// keeping only the low nibble - confirm against `small_uint`.
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	/// `broadcast_subvalue` repeats the sub-element across the whole `u32`.
	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		// NOTE(review): see the remark on `U4::new(31)` in `test_from_fn`.
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	/// `get_subvalue` reads sub-elements at the first few and the last valid index.
	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		// SAFETY: every index below is less than `u32::BITS / T::BITS` for the respective `T`.
		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	// Round-trip checks: a sub-element written with `set_subvalue` is read back unchanged.
	// The index ranges are half-open, so their upper bound is the number of sub-elements
	// (`32 / T::BITS`); this way the last valid index is exercised too. The value strategies
	// cover every value representable in the sub-element width.
	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			// SAFETY: `i < 32 == u32::BITS / U1::BITS`.
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			// SAFETY: `i < 16 == u32::BITS / U2::BITS`.
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(0xf)) {
			// SAFETY: `i < 8 == u32::BITS / U4::BITS`.
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			// SAFETY: `i < 4 == u32::BITS / u8::BITS`.
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}
}