binius_field/underlier/underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::checked_int_div;
6
7use super::{
8	U1, U2, U4,
9	underlier_type::{NumCast, UnderlierType},
10};
11use crate::Divisible;
12
13/// Underlier type that supports bit arithmetic.
14pub trait UnderlierWithBitOps:
15	UnderlierType
16	+ BitAnd<Self, Output = Self>
17	+ BitAndAssign<Self>
18	+ BitOr<Self, Output = Self>
19	+ BitOrAssign<Self>
20	+ BitXor<Self, Output = Self>
21	+ BitXorAssign<Self>
22	+ Shr<usize, Output = Self>
23	+ Shl<usize, Output = Self>
24	+ Not<Output = Self>
25	+ Divisible<U1>
26{
27	const ZERO: Self;
28	const ONE: Self;
29	const ONES: Self;
30
31	/// Fill value with the given bit
32	/// `val` must be 0 or 1.
33	fn fill_with_bit(val: u8) -> Self {
34		Self::broadcast_subvalue(U1::new(val))
35	}
36
37	/// Interleave with the given bit size
38	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
39
40	/// Transpose with the given bit size
41	fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) {
42		assert!(log_block_len < Self::LOG_BITS);
43
44		for log_block_len in (log_block_len..Self::LOG_BITS).rev() {
45			(self, other) = self.interleave(other, log_block_len);
46		}
47
48		(self, other)
49	}
50
51	#[inline]
52	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self
53	where
54		T: UnderlierType,
55		Self: Divisible<T>,
56	{
57		Self::from_iter((0..<Self as Divisible<T>>::N).map(f))
58	}
59
60	/// Broadcast subvalue to fill `Self`.
61	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
62	#[inline]
63	fn broadcast_subvalue<T>(value: T) -> Self
64	where
65		T: UnderlierType,
66		Self: Divisible<T>,
67	{
68		Divisible::<T>::broadcast(value)
69	}
70
71	/// Gets the subvalue from the given position.
72	/// Function panics in case when index is out of range.
73	///
74	/// # Safety
75	/// `i` must be less than `Self::BITS/T::BITS`.
76	#[inline]
77	unsafe fn get_subvalue<T>(&self, i: usize) -> T
78	where
79		T: UnderlierType,
80		Self: Divisible<T>,
81	{
82		debug_assert!(
83			i < checked_int_div(Self::BITS, T::BITS),
84			"i: {} Self::BITS: {}, T::BITS: {}",
85			i,
86			Self::BITS,
87			T::BITS
88		);
89		Divisible::<T>::get(*self, i)
90	}
91
92	/// Sets the subvalue in the given position.
93	/// Function panics in case when index is out of range.
94	///
95	/// # Safety
96	/// `i` must be less than `Self::BITS/T::BITS`.
97	#[inline]
98	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
99	where
100		T: UnderlierWithBitOps,
101		Self: Divisible<T>,
102	{
103		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
104		*self = (*self).set(i, val);
105	}
106
107	/// Spread takes a block of sub_elements of `T` type within the current value and
108	/// repeats them to the full underlier width.
109	///
110	/// # Safety
111	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
112	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
113	#[inline]
114	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
115	where
116		T: UnderlierWithBitOps,
117		Self: Divisible<T>,
118	{
119		unsafe { spread_fallback::<Self, T>(self, log_block_len, block_idx) }
120	}
121}
122
123/// Fallback implementation of `spread` method.
124///
125/// # Safety
126/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
127/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
128pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
129where
130	U: UnderlierWithBitOps + Divisible<T>,
131	T: UnderlierWithBitOps,
132{
133	debug_assert!(
134		log_block_len + T::LOG_BITS <= U::LOG_BITS,
135		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
136		log_block_len,
137		U::BITS,
138		T::BITS
139	);
140	debug_assert!(
141		block_idx < 1 << (U::LOG_BITS - log_block_len),
142		"block_idx: {}, U::BITS: {}, log_block_len: {}",
143		block_idx,
144		U::BITS,
145		log_block_len
146	);
147
148	let mut result = U::ZERO;
149	let block_offset = block_idx << log_block_len;
150	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
151	for i in 0..1 << log_block_len {
152		unsafe {
153			result.set_subvalue(i << log_repeat, value.get_subvalue::<T>(block_offset + i));
154		}
155	}
156
157	for i in 0..log_repeat {
158		result |= result << (1 << (T::LOG_BITS + i));
159	}
160
161	result
162}
163
/// Test helper that builds a mask of the lowest `bits_count` bits of `T`.
///
/// It starts from `T::ONE` and doubles the set run `log2(bits_count)` times. When `bits_count`
/// equals `T::BITS`, it returns all ones directly. Otherwise `bits_count` goes through
/// `checked_log_2`, so it is expected to be a power of two.
#[cfg(test)]
#[allow(unused)] // test-only helper; not every test configuration calls it
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
	use binius_utils::checked_arithmetics::checked_log_2;

	if bits_count != T::BITS {
		(0..checked_log_2(bits_count)).fold(T::ONE, |mask, level| mask | (mask << (1 << level)))
	} else {
		!T::ZERO
	}
}
180
/// A value that can be replicated to fill a whole `u8`.
pub(crate) trait SpreadToByte {
	/// Returns a byte made of repeated copies of `self`.
	fn spread_to_byte(self) -> u8;
}
185
186impl SpreadToByte for U1 {
187	#[inline(always)]
188	fn spread_to_byte(self) -> u8 {
189		u8::fill_with_bit(self.val())
190	}
191}
192
193impl SpreadToByte for U2 {
194	#[inline(always)]
195	fn spread_to_byte(self) -> u8 {
196		let mut result = self.val();
197		result |= result << 2;
198		result |= result << 4;
199
200		result
201	}
202}
203
204impl SpreadToByte for U4 {
205	#[inline(always)]
206	fn spread_to_byte(self) -> u8 {
207		let mut result = self.val();
208		result |= result << 4;
209
210		result
211	}
212}
213
214/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
215///
216/// # Safety
217/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
218#[allow(unused)]
219#[inline(always)]
220pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
221	value: U,
222	block_idx: usize,
223) -> [T; BLOCK_LEN]
224where
225	U: UnderlierWithBitOps + From<T> + Divisible<T>,
226	T: UnderlierType + NumCast<U>,
227{
228	std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
229}
230
231/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
232///
233/// # Safety
234/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
235#[allow(unused)]
236#[inline(always)]
237pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
238	value: U,
239	block_idx: usize,
240) -> [u8; BLOCK_LEN]
241where
242	U: UnderlierWithBitOps + From<T> + Divisible<T>,
243	T: UnderlierType + SpreadToByte + NumCast<U>,
244{
245	unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
246		.map(SpreadToByte::spread_to_byte)
247}
248
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	#[test]
	fn test_spread_fallback() {
		unsafe {
			// Nibbles 2 and 3 (block 1 of size 2), each repeated four times.
			assert_eq!(spread_fallback::<u32, U4>(0x76543210, 1, 1), 0x33332222);
			// Byte 3 alone, repeated four times.
			assert_eq!(spread_fallback::<u32, u8>(0xab12cd34, 0, 3), 0xabababab);
			// The whole value as a single block is returned unchanged.
			assert_eq!(spread_fallback::<u32, U1>(0xab12cd34, 5, 0), 0xab12cd34);
		}
	}

	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			let before = init_val;
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
				for j in (0..32).filter(|&j| j != i) {
					assert_eq!(init_val.get_subvalue::<U1>(j), before.get_subvalue::<U1>(j));
				}
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			let before = init_val;
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
				for j in (0..16).filter(|&j| j != i) {
					assert_eq!(init_val.get_subvalue::<U2>(j), before.get_subvalue::<U2>(j));
				}
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			let before = init_val;
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
				for j in (0..8).filter(|&j| j != i) {
					assert_eq!(init_val.get_subvalue::<U4>(j), before.get_subvalue::<U4>(j));
				}
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			let before = init_val;
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
				for j in (0..4).filter(|&j| j != i) {
					assert_eq!(init_val.get_subvalue::<u8>(j), before.get_subvalue::<u8>(j));
				}
			}
		}
	}
}