binius_field/underlier/
underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::checked_int_div;
6
7use super::{
8	U1, U2, U4,
9	underlier_type::{NumCast, UnderlierType},
10};
11use crate::Divisible;
12
13/// Underlier type that supports bit arithmetic.
14pub trait UnderlierWithBitOps:
15	UnderlierType
16	+ BitAnd<Self, Output = Self>
17	+ BitAndAssign<Self>
18	+ BitOr<Self, Output = Self>
19	+ BitOrAssign<Self>
20	+ BitXor<Self, Output = Self>
21	+ BitXorAssign<Self>
22	+ Shr<usize, Output = Self>
23	+ Shl<usize, Output = Self>
24	+ Not<Output = Self>
25{
26	const ZERO: Self;
27	const ONE: Self;
28	const ONES: Self;
29
30	/// Fill value with the given bit
31	/// `val` must be 0 or 1.
32	fn fill_with_bit(val: u8) -> Self;
33
34	/// Interleave with the given bit size
35	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
36
37	/// Transpose with the given bit size
38	fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) {
39		assert!(log_block_len < Self::LOG_BITS);
40
41		for log_block_len in (log_block_len..Self::LOG_BITS).rev() {
42			(self, other) = self.interleave(other, log_block_len);
43		}
44
45		(self, other)
46	}
47
48	#[inline]
49	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self
50	where
51		T: UnderlierType,
52		Self: Divisible<T>,
53	{
54		Self::from_iter((0..Self::N).map(f))
55	}
56
57	/// Broadcast subvalue to fill `Self`.
58	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
59	#[inline]
60	fn broadcast_subvalue<T>(value: T) -> Self
61	where
62		T: UnderlierType,
63		Self: Divisible<T>,
64	{
65		Divisible::<T>::broadcast(value)
66	}
67
68	/// Gets the subvalue from the given position.
69	/// Function panics in case when index is out of range.
70	///
71	/// # Safety
72	/// `i` must be less than `Self::BITS/T::BITS`.
73	#[inline]
74	unsafe fn get_subvalue<T>(&self, i: usize) -> T
75	where
76		T: UnderlierType,
77		Self: Divisible<T>,
78	{
79		debug_assert!(
80			i < checked_int_div(Self::BITS, T::BITS),
81			"i: {} Self::BITS: {}, T::BITS: {}",
82			i,
83			Self::BITS,
84			T::BITS
85		);
86		Divisible::<T>::get(*self, i)
87	}
88
89	/// Sets the subvalue in the given position.
90	/// Function panics in case when index is out of range.
91	///
92	/// # Safety
93	/// `i` must be less than `Self::BITS/T::BITS`.
94	#[inline]
95	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
96	where
97		T: UnderlierWithBitOps,
98		Self: Divisible<T>,
99	{
100		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
101		*self = (*self).set(i, val);
102	}
103
104	/// Spread takes a block of sub_elements of `T` type within the current value and
105	/// repeats them to the full underlier width.
106	///
107	/// # Safety
108	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
109	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
110	#[inline]
111	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
112	where
113		T: UnderlierWithBitOps,
114		Self: Divisible<T>,
115	{
116		unsafe { spread_fallback(self, log_block_len, block_idx) }
117	}
118}
119
120/// Fallback implementation of `spread` method.
121///
122/// # Safety
123/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
124/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
125pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
126where
127	U: UnderlierWithBitOps + Divisible<T>,
128	T: UnderlierWithBitOps,
129{
130	debug_assert!(
131		log_block_len + T::LOG_BITS <= U::LOG_BITS,
132		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
133		log_block_len,
134		U::BITS,
135		T::BITS
136	);
137	debug_assert!(
138		block_idx < 1 << (U::LOG_BITS - log_block_len),
139		"block_idx: {}, U::BITS: {}, log_block_len: {}",
140		block_idx,
141		U::BITS,
142		log_block_len
143	);
144
145	let mut result = U::ZERO;
146	let block_offset = block_idx << log_block_len;
147	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
148	for i in 0..1 << log_block_len {
149		unsafe {
150			result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
151		}
152	}
153
154	for i in 0..log_repeat {
155		result |= result << (1 << (T::LOG_BITS + i));
156	}
157
158	result
159}
160
#[cfg(test)]
#[allow(unused)]
/// Returns a mask of the lowest `bits_count` bits of `T`.
///
/// Assumes `T::ONE` has only its least significant bit set. `bits_count` is expected to be a
/// power of two (it goes through `checked_log_2`).
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
	use binius_utils::checked_arithmetics::checked_log_2;

	// A mask spanning the entire underlier is simply every bit set.
	if bits_count == T::BITS {
		return !T::ZERO;
	}

	// Starting from a single set bit, double the width of the run of ones at every step.
	(0..checked_log_2(bits_count)).fold(T::ONE, |mask, step| mask | (mask << (1 << step)))
}
177
/// A sub-byte value that can be replicated side by side until it fills a whole `u8`.
pub(crate) trait SpreadToByte {
	/// Returns the byte made of repeated copies of `self`.
	fn spread_to_byte(self) -> u8;
}
182
183impl SpreadToByte for U1 {
184	#[inline(always)]
185	fn spread_to_byte(self) -> u8 {
186		u8::fill_with_bit(self.val())
187	}
188}
189
190impl SpreadToByte for U2 {
191	#[inline(always)]
192	fn spread_to_byte(self) -> u8 {
193		let mut result = self.val();
194		result |= result << 2;
195		result |= result << 4;
196
197		result
198	}
199}
200
201impl SpreadToByte for U4 {
202	#[inline(always)]
203	fn spread_to_byte(self) -> u8 {
204		let mut result = self.val();
205		result |= result << 4;
206
207		result
208	}
209}
210
211/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
212///
213/// # Safety
214/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
215#[allow(unused)]
216#[inline(always)]
217pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
218	value: U,
219	block_idx: usize,
220) -> [T; BLOCK_LEN]
221where
222	U: UnderlierWithBitOps + From<T> + Divisible<T>,
223	T: UnderlierType + NumCast<U>,
224{
225	std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
226}
227
228/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
229///
230/// # Safety
231/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
232#[allow(unused)]
233#[inline(always)]
234pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
235	value: U,
236	block_idx: usize,
237) -> [u8; BLOCK_LEN]
238where
239	U: UnderlierWithBitOps + From<T> + Divisible<T>,
240	T: UnderlierType + SpreadToByte + NumCast<U>,
241{
242	unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
243		.map(SpreadToByte::spread_to_byte)
244}
245
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	// The index ranges are exclusive, so each upper bound equals the number of sub-values in a
	// `u32`; that way the last position is exercised too. The value strategies cover every
	// representable value of the sub-value type (1, 2, 4 and 8 bits respectively).
	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}
}