// binius_field/underlier/underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8	U1, U2, U4,
9	underlier_type::{NumCast, UnderlierType},
10};
11
12/// Underlier type that supports bit arithmetic.
13pub trait UnderlierWithBitOps:
14	UnderlierType
15	+ BitAnd<Self, Output = Self>
16	+ BitAndAssign<Self>
17	+ BitOr<Self, Output = Self>
18	+ BitOrAssign<Self>
19	+ BitXor<Self, Output = Self>
20	+ BitXorAssign<Self>
21	+ Shr<usize, Output = Self>
22	+ Shl<usize, Output = Self>
23	+ Not<Output = Self>
24{
25	const ZERO: Self;
26	const ONE: Self;
27	const ONES: Self;
28
29	/// Fill value with the given bit
30	/// `val` must be 0 or 1.
31	fn fill_with_bit(val: u8) -> Self;
32
33	#[inline]
34	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
35	where
36		T: UnderlierType,
37		Self: From<T>,
38	{
39		// This implementation is optimal for the case when `Self` us u8..u128.
40		// For SIMD types/arrays specialization would be more performant.
41		let mut result = Self::default();
42		let width = checked_int_div(Self::BITS, T::BITS);
43		for i in 0..width {
44			result |= Self::from(f(i)) << (i * T::BITS);
45		}
46
47		result
48	}
49
50	/// Broadcast subvalue to fill `Self`.
51	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
52	#[inline]
53	fn broadcast_subvalue<T>(value: T) -> Self
54	where
55		T: UnderlierType,
56		Self: From<T>,
57	{
58		// This implementation is optimal for the case when `Self` us u8..u128.
59		// For SIMD types/arrays specialization would be more performant.
60		let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
61		let mut result = Self::from(value);
62		for i in 0..height {
63			result |= result << ((1 << i) * T::BITS);
64		}
65
66		result
67	}
68
69	/// Gets the subvalue from the given position.
70	/// Function panics in case when index is out of range.
71	///
72	/// # Safety
73	/// `i` must be less than `Self::BITS/T::BITS`.
74	#[inline]
75	unsafe fn get_subvalue<T>(&self, i: usize) -> T
76	where
77		T: UnderlierType + NumCast<Self>,
78	{
79		debug_assert!(
80			i < checked_int_div(Self::BITS, T::BITS),
81			"i: {} Self::BITS: {}, T::BITS: {}",
82			i,
83			Self::BITS,
84			T::BITS
85		);
86		T::num_cast_from(*self >> (i * T::BITS))
87	}
88
89	/// Sets the subvalue in the given position.
90	/// Function panics in case when index is out of range.
91	///
92	/// # Safety
93	/// `i` must be less than `Self::BITS/T::BITS`.
94	#[inline]
95	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
96	where
97		T: UnderlierWithBitOps,
98		Self: From<T>,
99	{
100		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
101		let mask = Self::from(single_element_mask::<T>());
102
103		*self &= !(mask << (i * T::BITS));
104		*self |= Self::from(val) << (i * T::BITS);
105	}
106
107	/// Spread takes a block of sub_elements of `T` type within the current value and
108	/// repeats them to the full underlier width.
109	///
110	/// # Safety
111	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
112	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
113	#[inline]
114	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
115	where
116		T: UnderlierWithBitOps + NumCast<Self>,
117		Self: From<T>,
118	{
119		unsafe { spread_fallback(self, log_block_len, block_idx) }
120	}
121
122	/// Left shift within 128-bit lanes.
123	/// This can be more efficient than the full `Shl` implementation.
124	fn shl_128b_lanes(self, shift: usize) -> Self;
125
126	/// Right shift within 128-bit lanes.
127	/// This can be more efficient than the full `Shr` implementation.
128	fn shr_128b_lanes(self, shift: usize) -> Self;
129
130	/// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit
131	/// lanes.
132	///
133	/// Example:
134	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
135	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
136	///    log_block_len: 1
137	///
138	///    result: [a_0, a_0, b_0, b_1, a_2, a_3, b_2, b_3]
139	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
140		unpack_lo_128b_fallback(self, other, log_block_len)
141	}
142
143	/// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit
144	/// lanes.
145	///
146	/// Example:
147	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
148	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
149	///    log_block_len: 1
150	///
151	///    result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7]
152	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
153		unpack_hi_128b_fallback(self, other, log_block_len)
154	}
155}
156
157/// Returns a bit mask for a single `T` element inside underlier type.
158/// This function is completely optimized out by the compiler in release version
159/// because all the values are known at compile time.
160fn single_element_mask<T>() -> T
161where
162	T: UnderlierWithBitOps,
163{
164	single_element_mask_bits(T::BITS)
165}
166
167/// A helper function to apply unpack_lo/hi_128b_lanes for two values in an array
168#[allow(dead_code)]
169#[inline(always)]
170pub(crate) fn pair_unpack_lo_hi_128b_lanes<U: UnderlierWithBitOps>(
171	values: &mut impl AsMut<[U]>,
172	i: usize,
173	j: usize,
174	log_block_len: usize,
175) {
176	let values = values.as_mut();
177
178	(values[i], values[j]) = (
179		values[i].unpack_lo_128b_lanes(values[j], log_block_len),
180		values[i].unpack_hi_128b_lanes(values[j], log_block_len),
181	);
182}
183
184/// Fallback implementation of `spread` method.
185///
186/// # Safety
187/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
188/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
189pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
190where
191	U: UnderlierWithBitOps + From<T>,
192	T: UnderlierWithBitOps + NumCast<U>,
193{
194	debug_assert!(
195		log_block_len + T::LOG_BITS <= U::LOG_BITS,
196		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
197		log_block_len,
198		U::BITS,
199		T::BITS
200	);
201	debug_assert!(
202		block_idx < 1 << (U::LOG_BITS - log_block_len),
203		"block_idx: {}, U::BITS: {}, log_block_len: {}",
204		block_idx,
205		U::BITS,
206		log_block_len
207	);
208
209	let mut result = U::ZERO;
210	let block_offset = block_idx << log_block_len;
211	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
212	for i in 0..1 << log_block_len {
213		unsafe {
214			result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
215		}
216	}
217
218	for i in 0..log_repeat {
219		result |= result << (1 << (T::LOG_BITS + i));
220	}
221
222	result
223}
224
225#[inline(always)]
226fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
227	let mut mask = single_element_mask_bits(1 << log_block_len);
228	for i in 1..T::BITS / 128 {
229		mask |= mask << (i * 128);
230	}
231
232	mask
233}
234
235pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
236	lhs: T,
237	rhs: T,
238	log_block_len: usize,
239) -> T {
240	assert!(log_block_len <= 6);
241
242	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
243
244	let mut result = T::ZERO;
245	for i in 0..1 << (6 - log_block_len) {
246		result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
247			.shl_128b_lanes(i << (log_block_len + 1));
248		result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
249			.shl_128b_lanes((2 * i + 1) << log_block_len);
250	}
251
252	result
253}
254
255pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
256	lhs: T,
257	rhs: T,
258	log_block_len: usize,
259) -> T {
260	assert!(log_block_len <= 6);
261
262	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
263	let mut result = T::ZERO;
264	for i in 0..1 << (6 - log_block_len) {
265		result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
266			.shl_128b_lanes(i << (log_block_len + 1));
267		result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
268			.shl_128b_lanes((2 * i + 1) << log_block_len);
269	}
270
271	result
272}
273
274pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
275	if bits_count == T::BITS {
276		!T::ZERO
277	} else {
278		let mut result = T::ONE;
279		for height in 0..checked_log_2(bits_count) {
280			result |= result << (1 << height)
281		}
282
283		result
284	}
285}
286
287/// Value that can be spread to a single u8
288pub(crate) trait SpreadToByte {
289	fn spread_to_byte(self) -> u8;
290}
291
292impl SpreadToByte for U1 {
293	#[inline(always)]
294	fn spread_to_byte(self) -> u8 {
295		u8::fill_with_bit(self.val())
296	}
297}
298
299impl SpreadToByte for U2 {
300	#[inline(always)]
301	fn spread_to_byte(self) -> u8 {
302		let mut result = self.val();
303		result |= result << 2;
304		result |= result << 4;
305
306		result
307	}
308}
309
310impl SpreadToByte for U4 {
311	#[inline(always)]
312	fn spread_to_byte(self) -> u8 {
313		let mut result = self.val();
314		result |= result << 4;
315
316		result
317	}
318}
319
320/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
321///
322/// # Safety
323/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
324#[allow(unused)]
325#[inline(always)]
326pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
327	value: U,
328	block_idx: usize,
329) -> [T; BLOCK_LEN]
330where
331	U: UnderlierWithBitOps + From<T>,
332	T: UnderlierType + NumCast<U>,
333{
334	std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
335}
336
337/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
338///
339/// # Safety
340/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
341#[allow(unused)]
342#[inline(always)]
343pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
344	value: U,
345	block_idx: usize,
346) -> [u8; BLOCK_LEN]
347where
348	U: UnderlierWithBitOps + From<T>,
349	T: UnderlierType + SpreadToByte + NumCast<U>,
350{
351	unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
352		.map(SpreadToByte::spread_to_byte)
353}
354
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	#[test]
	fn test_spread() {
		// Nibble `i` of `value` holds the number `i`.
		let value = 0x76543210u32;

		unsafe {
			// A single nibble repeated across the whole word.
			assert_eq!(value.spread::<U4>(0, 3), 0x33333333);
			// Nibbles 4 and 5, each repeated four times.
			assert_eq!(value.spread::<U4>(1, 2), 0x55554444);
			// A block as wide as the word is the identity.
			assert_eq!(value.spread::<U4>(3, 0), value);
			// Bytes 2 and 3, each repeated twice.
			assert_eq!(value.spread::<u8>(1, 1), 0x76765454);
		}
	}

	proptest! {
		// The index ranges cover every valid slot (including the last one) and the value
		// strategies cover the full width of each sub-value type.
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}
}