binius_field/underlier/
underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8	underlier_type::{NumCast, UnderlierType},
9	U1, U2, U4,
10};
11
/// Underlier type that supports bit arithmetic.
pub trait UnderlierWithBitOps:
	UnderlierType
	+ BitAnd<Self, Output = Self>
	+ BitAndAssign<Self>
	+ BitOr<Self, Output = Self>
	+ BitOrAssign<Self>
	+ BitXor<Self, Output = Self>
	+ BitXorAssign<Self>
	+ Shr<usize, Output = Self>
	+ Shl<usize, Output = Self>
	+ Not<Output = Self>
{
	/// Value with all bits cleared.
	const ZERO: Self;
	/// Value with only the least significant bit set.
	const ONE: Self;
	/// Value with all bits set.
	const ONES: Self;

	/// Fill value with the given bit
	/// `val` must be 0 or 1.
	fn fill_with_bit(val: u8) -> Self;

	/// Builds `Self` from `Self::BITS / T::BITS` subvalues, where the subvalue
	/// produced by `f(i)` is placed at bit offset `i * T::BITS`.
	#[inline]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		// This implementation is optimal for the case when `Self` is u8..u128.
		// For SIMD types/arrays specialization would be more performant.
		let mut result = Self::default();
		let width = checked_int_div(Self::BITS, T::BITS);
		for i in 0..width {
			result |= Self::from(f(i)) << (i * T::BITS);
		}

		result
	}

	/// Broadcast subvalue to fill `Self`.
	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
	#[inline]
	fn broadcast_subvalue<T>(value: T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		// This implementation is optimal for the case when `Self` is u8..u128.
		// For SIMD types/arrays specialization would be more performant.
		// Doubling trick: each iteration doubles the number of copies of `value`.
		let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
		let mut result = Self::from(value);
		for i in 0..height {
			result |= result << ((1 << i) * T::BITS);
		}

		result
	}

	/// Gets the subvalue from the given position.
	/// The index is checked with `debug_assert!` only, so an out-of-range index is
	/// not caught in release builds — see the safety contract below.
	///
	/// # Safety
	/// `i` must be less than `Self::BITS/T::BITS`.
	#[inline]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		debug_assert!(
			i < checked_int_div(Self::BITS, T::BITS),
			"i: {} Self::BITS: {}, T::BITS: {}",
			i,
			Self::BITS,
			T::BITS
		);
		// Shift the requested subvalue down to the low bits and truncate into `T`.
		T::num_cast_from(*self >> (i * T::BITS))
	}

	/// Sets the subvalue in the given position.
	/// The index is checked with `debug_assert!` only, so an out-of-range index is
	/// not caught in release builds — see the safety contract below.
	///
	/// # Safety
	/// `i` must be less than `Self::BITS/T::BITS`.
	#[inline]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
		let mask = Self::from(single_element_mask::<T>());

		// Clear the target slot, then OR the new subvalue into it.
		*self &= !(mask << (i * T::BITS));
		*self |= Self::from(val) << (i * T::BITS);
	}

	/// Spread takes a block of sub_elements of `T` type within the current value and
	/// repeats them to the full underlier width.
	///
	/// # Safety
	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
	#[inline]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		spread_fallback(self, log_block_len, block_idx)
	}

	/// Left shift within 128-bit lanes.
	/// This can be more efficient than the full `Shl` implementation.
	fn shl_128b_lanes(self, shift: usize) -> Self;

	/// Right shift within 128-bit lanes.
	/// This can be more efficient than the full `Shr` implementation.
	fn shr_128b_lanes(self, shift: usize) -> Self;

	/// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit lanes.
	///
	/// Example:
	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
	///    log_block_len: 1
	///
	///    result: [a_0, a_1, b_0, b_1, a_2, a_3, b_2, b_3]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		unpack_lo_128b_fallback(self, other, log_block_len)
	}

	/// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit lanes.
	///
	/// Example:
	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
	///    log_block_len: 1
	///
	///    result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		unpack_hi_128b_fallback(self, other, log_block_len)
	}
}
154
155/// Returns a bit mask for a single `T` element inside underlier type.
156/// This function is completely optimized out by the compiler in release version
157/// because all the values are known at compile time.
158fn single_element_mask<T>() -> T
159where
160	T: UnderlierWithBitOps,
161{
162	single_element_mask_bits(T::BITS)
163}
164
/// Fallback implementation of `spread` method.
///
/// Selects the `block_idx`-th block of `1 << log_block_len` subvalues of type `T` from
/// `value` and fills the whole width of `U` with it, each subvalue of the block being
/// repeated `1 << (U::LOG_BITS - T::LOG_BITS - log_block_len)` times consecutively.
///
/// # Safety
/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
where
	U: UnderlierWithBitOps + From<T>,
	T: UnderlierWithBitOps + NumCast<U>,
{
	debug_assert!(
		log_block_len + T::LOG_BITS <= U::LOG_BITS,
		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
		log_block_len,
		U::BITS,
		T::BITS
	);
	// NOTE(review): this bound looks loose — the number of addressable blocks is
	// `1 << (U::LOG_BITS - T::LOG_BITS - log_block_len)`; confirm the intended contract.
	debug_assert!(
		block_idx < 1 << (U::LOG_BITS - log_block_len),
		"block_idx: {}, U::BITS: {}, log_block_len: {}",
		block_idx,
		U::BITS,
		log_block_len
	);

	// Scatter: place subvalue `i` of the selected block at position `i << log_repeat`,
	// i.e. at the start of the region it will be duplicated into.
	let mut result = U::ZERO;
	let block_offset = block_idx << log_block_len;
	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
	for i in 0..1 << log_block_len {
		unsafe {
			result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
		}
	}

	// Fill each region with copies of its leading subvalue by repeated doubling.
	for i in 0..log_repeat {
		result |= result << (1 << (T::LOG_BITS + i));
	}

	result
}
205
206#[inline(always)]
207fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
208	let mut mask = single_element_mask_bits(1 << log_block_len);
209	for i in 1..T::BITS / 128 {
210		mask |= mask << (i * 128);
211	}
212
213	mask
214}
215
216pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
217	lhs: T,
218	rhs: T,
219	log_block_len: usize,
220) -> T {
221	assert!(log_block_len <= 6);
222
223	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
224
225	let mut result = T::ZERO;
226	for i in 0..1 << (6 - log_block_len) {
227		result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
228			.shl_128b_lanes(i << (log_block_len + 1));
229		result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
230			.shl_128b_lanes((2 * i + 1) << log_block_len);
231	}
232
233	result
234}
235
236pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
237	lhs: T,
238	rhs: T,
239	log_block_len: usize,
240) -> T {
241	assert!(log_block_len <= 6);
242
243	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
244	let mut result = T::ZERO;
245	for i in 0..1 << (6 - log_block_len) {
246		result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
247			.shl_128b_lanes(i << (log_block_len + 1));
248		result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
249			.shl_128b_lanes((2 * i + 1) << log_block_len);
250	}
251
252	result
253}
254
255pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
256	if bits_count == T::BITS {
257		!T::ZERO
258	} else {
259		let mut result = T::ONE;
260		for height in 0..checked_log_2(bits_count) {
261			result |= result << (1 << height)
262		}
263
264		result
265	}
266}
267
/// Value that can be spread to a single u8
pub(crate) trait SpreadToByte {
	/// Returns a `u8` filled with repeated copies of `self`'s bit pattern.
	fn spread_to_byte(self) -> u8;
}
272
273impl SpreadToByte for U1 {
274	#[inline(always)]
275	fn spread_to_byte(self) -> u8 {
276		u8::fill_with_bit(self.val())
277	}
278}
279
280impl SpreadToByte for U2 {
281	#[inline(always)]
282	fn spread_to_byte(self) -> u8 {
283		let mut result = self.val();
284		result |= result << 2;
285		result |= result << 4;
286
287		result
288	}
289}
290
291impl SpreadToByte for U4 {
292	#[inline(always)]
293	fn spread_to_byte(self) -> u8 {
294		let mut result = self.val();
295		result |= result << 4;
296
297		result
298	}
299}
300
/// A helper function for implementing `UnderlierWithBitOps::spread` for SIMD types.
/// Reads `BLOCK_LEN` consecutive `T` subvalues from `value`, starting at subvalue
/// index `block_idx * BLOCK_LEN`.
///
/// # Safety
/// `(block_idx + 1) * BLOCK_LEN` must be less than or equal to `U::BITS / T::BITS`
/// (this is the range required by the `get_subvalue` calls below).
#[allow(unused)]
#[inline(always)]
pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
	value: U,
	block_idx: usize,
) -> [T; BLOCK_LEN]
where
	U: UnderlierWithBitOps + From<T>,
	T: UnderlierType + NumCast<U>,
{
	std::array::from_fn(|i| value.get_subvalue::<T>(block_idx * BLOCK_LEN + i))
}
317
318/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
319///
320/// # Safety
321/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
322#[allow(unused)]
323#[inline(always)]
324pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
325	value: U,
326	block_idx: usize,
327) -> [u8; BLOCK_LEN]
328where
329	U: UnderlierWithBitOps + From<T>,
330	T: UnderlierType + SpreadToByte + NumCast<U>,
331{
332	get_block_values::<U, T, BLOCK_LEN>(value, block_idx).map(SpreadToByte::spread_to_byte)
333}
334
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	// Index ranges cover the full valid range `0..Self::BITS / T::BITS` (the previous
	// exclusive upper bounds skipped the last valid index), and value masks cover the
	// full subvalue range (the previous 4-bit/8-bit masks never set the top bit).
	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in bits::u8::masked(255)) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}
}