binius_field/underlier/underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8	U1, U2, U4,
9	underlier_type::{NumCast, UnderlierType},
10};
11use crate::tower_levels::TowerLevel;
12
/// Underlier type that supports bit arithmetic.
pub trait UnderlierWithBitOps:
	UnderlierType
	+ BitAnd<Self, Output = Self>
	+ BitAndAssign<Self>
	+ BitOr<Self, Output = Self>
	+ BitOrAssign<Self>
	+ BitXor<Self, Output = Self>
	+ BitXorAssign<Self>
	+ Shr<usize, Output = Self>
	+ Shl<usize, Output = Self>
	+ Not<Output = Self>
{
	/// Value with all bits unset.
	const ZERO: Self;
	/// Value with only the lowest bit set.
	const ONE: Self;
	/// Value with all bits set.
	const ONES: Self;

	/// Fill value with the given bit.
	/// `val` must be 0 or 1.
	fn fill_with_bit(val: u8) -> Self;

	/// Builds `Self` from `Self::BITS / T::BITS` subvalues produced by `f`,
	/// placing subvalue `i` at bit offset `i * T::BITS` (lowest bits first).
	#[inline]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		// This implementation is optimal for the case when `Self` is u8..u128.
		// For SIMD types/arrays specialization would be more performant.
		let mut result = Self::default();
		let width = checked_int_div(Self::BITS, T::BITS);
		for i in 0..width {
			result |= Self::from(f(i)) << (i * T::BITS);
		}

		result
	}

	/// Broadcast subvalue to fill `Self`.
	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
	#[inline]
	fn broadcast_subvalue<T>(value: T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		// This implementation is optimal for the case when `Self` is u8..u128.
		// For SIMD types/arrays specialization would be more performant.
		// Log-doubling: each iteration doubles the number of populated copies.
		let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
		let mut result = Self::from(value);
		for i in 0..height {
			result |= result << ((1 << i) * T::BITS);
		}

		result
	}

	/// Gets the subvalue from the given position.
	/// In debug builds an out-of-range index triggers an assertion.
	///
	/// # Safety
	/// `i` must be less than `Self::BITS/T::BITS`.
	#[inline]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		debug_assert!(
			i < checked_int_div(Self::BITS, T::BITS),
			"i: {} Self::BITS: {}, T::BITS: {}",
			i,
			Self::BITS,
			T::BITS
		);
		// Shift the requested element down to the lowest bits and truncate.
		T::num_cast_from(*self >> (i * T::BITS))
	}

	/// Sets the subvalue in the given position.
	/// In debug builds an out-of-range index triggers an assertion.
	///
	/// # Safety
	/// `i` must be less than `Self::BITS/T::BITS`.
	#[inline]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
		let mask = Self::from(single_element_mask::<T>());

		// Clear the target slot first, then OR the new value in.
		*self &= !(mask << (i * T::BITS));
		*self |= Self::from(val) << (i * T::BITS);
	}

	/// Spread takes a block of sub_elements of `T` type within the current value and
	/// repeats them to the full underlier width.
	///
	/// # Safety
	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
	#[inline]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		// SAFETY: forwards the caller's guarantees unchanged.
		unsafe { spread_fallback(self, log_block_len, block_idx) }
	}

	/// Left shift within 128-bit lanes.
	/// This can be more efficient than the full `Shl` implementation.
	fn shl_128b_lanes(self, shift: usize) -> Self;

	/// Right shift within 128-bit lanes.
	/// This can be more efficient than the full `Shr` implementation.
	fn shr_128b_lanes(self, shift: usize) -> Self;

	/// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit
	/// lanes.
	///
	/// Example:
	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
	///    log_block_len: 1
	///
	///    result: [a_0, a_1, b_0, b_1, a_2, a_3, b_2, b_3]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		unpack_lo_128b_fallback(self, other, log_block_len)
	}

	/// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit
	/// lanes.
	///
	/// Example:
	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
	///    log_block_len: 1
	///
	///    result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		unpack_hi_128b_fallback(self, other, log_block_len)
	}

	/// Transpose bytes from byte-sliced representation to a packed "normal one".
	///
	/// For example for tower level 1, having the following bytes:
	///     [a0, b0, c0, d0]
	///     [a1, b1, c1, d1]
	///
	/// The result will be:
	///     [a0, a1, b0, b1]
	///     [c0, c1, d0, d1]
	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		assert!(TL::LOG_WIDTH <= 4);

		let result = TL::from_fn(|row| {
			Self::from_fn(|col| {
				// Byte `col` of output row `row` is gathered round-robin from
				// the input underliers.
				let index = row * (Self::BITS / 8) + col;

				// Safety: `index` is always less than `N * byte_count`.
				unsafe { values[index % TL::WIDTH].get_subvalue::<u8>(index / TL::WIDTH) }
			})
		});

		*values = result;
	}

	/// Transpose bytes from `ordinal` packed representation to a byte-sliced one.
	///
	/// For example for tower level 1, having the following bytes:
	///    [a0, a1, b0, b1]
	///    [c0, c1, d0, d1]
	///
	/// The result will be:
	///   [a0, b0, c0, d0]
	///   [a1, b1, c1, d1]
	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		assert!(TL::LOG_WIDTH <= 4);

		let bytes = Self::BITS / 8;
		let result = TL::from_fn(|row| {
			Self::from_fn(|col| {
				// Inverse of the `transpose_bytes_from_byte_sliced` index map.
				let index = row + col * TL::WIDTH;

				// Safety: `index` is always less than `N * byte_count`.
				unsafe { values[index / bytes].get_subvalue::<u8>(index % bytes) }
			})
		});

		*values = result;
	}
}
214
215/// Returns a bit mask for a single `T` element inside underlier type.
216/// This function is completely optimized out by the compiler in release version
217/// because all the values are known at compile time.
218fn single_element_mask<T>() -> T
219where
220	T: UnderlierWithBitOps,
221{
222	single_element_mask_bits(T::BITS)
223}
224
225/// A helper function to apply unpack_lo/hi_128b_lanes for two values in an array
226#[allow(dead_code)]
227#[inline(always)]
228pub(crate) fn pair_unpack_lo_hi_128b_lanes<U: UnderlierWithBitOps>(
229	values: &mut impl AsMut<[U]>,
230	i: usize,
231	j: usize,
232	log_block_len: usize,
233) {
234	let values = values.as_mut();
235
236	(values[i], values[j]) = (
237		values[i].unpack_lo_128b_lanes(values[j], log_block_len),
238		values[i].unpack_hi_128b_lanes(values[j], log_block_len),
239	);
240}
241
242/// A helper function used as a building block for efficient SIMD types transposition
243/// implementation. This function actually may reorder the elements.
244#[allow(dead_code)]
245#[inline(always)]
246pub(crate) fn transpose_128b_blocks_low_to_high<U: UnderlierWithBitOps, TL: TowerLevel>(
247	values: &mut TL::Data<U>,
248	log_block_len: usize,
249) {
250	assert!(TL::WIDTH <= 16);
251
252	if TL::WIDTH == 1 {
253		return;
254	}
255
256	let (left, right) = TL::split_mut(values);
257	transpose_128b_blocks_low_to_high::<_, TL::Base>(left, log_block_len);
258	transpose_128b_blocks_low_to_high::<_, TL::Base>(right, log_block_len);
259
260	let log_block_len = log_block_len + TL::LOG_WIDTH + 2;
261	for i in 0..TL::WIDTH / 2 {
262		pair_unpack_lo_hi_128b_lanes(values, i, i + TL::WIDTH / 2, log_block_len);
263	}
264}
265
266/// Transposition implementation for 128-bit SIMD types.
267/// This implementations is used for NEON and SSE2.
268#[allow(dead_code)]
269#[inline(always)]
270pub(crate) fn transpose_128b_values<U: UnderlierWithBitOps, TL: TowerLevel>(
271	values: &mut TL::Data<U>,
272	log_block_len: usize,
273) {
274	assert!(U::BITS == 128);
275
276	transpose_128b_blocks_low_to_high::<U, TL>(values, log_block_len);
277
278	// Elements are transposed, but we need to reorder them
279	match TL::LOG_WIDTH {
280		0 | 1 => {}
281		2 => {
282			values.as_mut().swap(1, 2);
283		}
284		3 => {
285			values.as_mut().swap(1, 4);
286			values.as_mut().swap(3, 6);
287		}
288		4 => {
289			values.as_mut().swap(1, 8);
290			values.as_mut().swap(2, 4);
291			values.as_mut().swap(3, 12);
292			values.as_mut().swap(5, 10);
293			values.as_mut().swap(7, 14);
294			values.as_mut().swap(11, 13);
295		}
296		_ => panic!("unsupported tower level"),
297	}
298}
299
/// Fallback implementation of `spread` method.
///
/// Copies the `1 << log_block_len` `T`-subvalues of block `block_idx` into the
/// result at evenly spaced positions, then replicates each of them over its
/// repeat range by log-doubling.
///
/// # Safety
/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
where
	U: UnderlierWithBitOps + From<T>,
	T: UnderlierWithBitOps + NumCast<U>,
{
	debug_assert!(
		log_block_len + T::LOG_BITS <= U::LOG_BITS,
		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
		log_block_len,
		U::BITS,
		T::BITS
	);
	debug_assert!(
		block_idx < 1 << (U::LOG_BITS - log_block_len),
		"block_idx: {}, U::BITS: {}, log_block_len: {}",
		block_idx,
		U::BITS,
		log_block_len
	);

	let mut result = U::ZERO;
	let block_offset = block_idx << log_block_len;
	// Each source element ends up repeated `1 << log_repeat` times.
	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
	for i in 0..1 << log_block_len {
		unsafe {
			// SAFETY: the caller's preconditions keep both the source index
			// `block_offset + i` and the destination index `i << log_repeat`
			// within `U::BITS / T::BITS` subvalues.
			result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
		}
	}

	// Log-doubling broadcast: each pass doubles the populated copies of every
	// placed element within its repeat range.
	for i in 0..log_repeat {
		result |= result << (1 << (T::LOG_BITS + i));
	}

	result
}
340
341#[inline(always)]
342fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
343	let mut mask = single_element_mask_bits(1 << log_block_len);
344	for i in 1..T::BITS / 128 {
345		mask |= mask << (i * 128);
346	}
347
348	mask
349}
350
351pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
352	lhs: T,
353	rhs: T,
354	log_block_len: usize,
355) -> T {
356	assert!(log_block_len <= 6);
357
358	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
359
360	let mut result = T::ZERO;
361	for i in 0..1 << (6 - log_block_len) {
362		result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
363			.shl_128b_lanes(i << (log_block_len + 1));
364		result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
365			.shl_128b_lanes((2 * i + 1) << log_block_len);
366	}
367
368	result
369}
370
371pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
372	lhs: T,
373	rhs: T,
374	log_block_len: usize,
375) -> T {
376	assert!(log_block_len <= 6);
377
378	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
379	let mut result = T::ZERO;
380	for i in 0..1 << (6 - log_block_len) {
381		result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
382			.shl_128b_lanes(i << (log_block_len + 1));
383		result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
384			.shl_128b_lanes((2 * i + 1) << log_block_len);
385	}
386
387	result
388}
389
390pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
391	if bits_count == T::BITS {
392		!T::ZERO
393	} else {
394		let mut result = T::ONE;
395		for height in 0..checked_log_2(bits_count) {
396			result |= result << (1 << height)
397		}
398
399		result
400	}
401}
402
/// Value that can be spread to a single u8
/// (its bit pattern repeated until all 8 bits are filled).
pub(crate) trait SpreadToByte {
	// Returns the byte obtained by repeating `self`'s bits across a u8.
	fn spread_to_byte(self) -> u8;
}
407
408impl SpreadToByte for U1 {
409	#[inline(always)]
410	fn spread_to_byte(self) -> u8 {
411		u8::fill_with_bit(self.val())
412	}
413}
414
415impl SpreadToByte for U2 {
416	#[inline(always)]
417	fn spread_to_byte(self) -> u8 {
418		let mut result = self.val();
419		result |= result << 2;
420		result |= result << 4;
421
422		result
423	}
424}
425
426impl SpreadToByte for U4 {
427	#[inline(always)]
428	fn spread_to_byte(self) -> u8 {
429		let mut result = self.val();
430		result |= result << 4;
431
432		result
433	}
434}
435
/// A helper function for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
/// Reads `BLOCK_LEN` consecutive `T`-subvalues of `value` starting at subvalue
/// index `block_idx * BLOCK_LEN`.
///
/// # Safety
/// `(block_idx + 1) * BLOCK_LEN` must be less than or equal to `U::BITS / T::BITS`,
/// so that every read index is a valid subvalue position.
#[allow(unused)]
#[inline(always)]
pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
	value: U,
	block_idx: usize,
) -> [T; BLOCK_LEN]
where
	U: UnderlierWithBitOps + From<T>,
	T: UnderlierType + NumCast<U>,
{
	// SAFETY: the caller guarantees `block_idx * BLOCK_LEN + i` is in range
	// for every `i < BLOCK_LEN`.
	std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
}
452
/// A helper function for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
/// Reads a block of `T`-subvalues and spreads each of them to a full byte.
///
/// # Safety
/// `(block_idx + 1) * BLOCK_LEN` must be less than or equal to `U::BITS / T::BITS`,
/// so that every read index is a valid subvalue position.
#[allow(unused)]
#[inline(always)]
pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
	value: U,
	block_idx: usize,
) -> [u8; BLOCK_LEN]
where
	U: UnderlierWithBitOps + From<T>,
	T: UnderlierType + SpreadToByte + NumCast<U>,
{
	// SAFETY: same in-range requirement as `get_block_values`.
	unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
		.map(SpreadToByte::spread_to_byte)
}
470
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};
	use crate::tower_levels::{TowerLevel1, TowerLevel2};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		// 31 truncated to 4 bits is 15, so every nibble is expected to be 0xf.
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	proptest! {
		// The index ranges are exclusive at the top, so `..32`/`..16`/`..8`/`..4`
		// cover every valid subvalue position of a u32 (previously the last
		// position of each size was never exercised). The bit masks cover the
		// full value range of each subvalue type (previously U4 missed values
		// 8..=15 and u8 missed the high nibble).
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in bits::u8::masked(255)) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}

	#[test]
	fn test_transpose_from_byte_sliced() {
		let mut value = [0x01234567u32];
		u32::transpose_bytes_from_byte_sliced::<TowerLevel1>(&mut value);
		assert_eq!(value, [0x01234567u32]);

		let mut value = [0x67452301u32, 0xefcdab89u32];
		u32::transpose_bytes_from_byte_sliced::<TowerLevel2>(&mut value);
		assert_eq!(value, [0xab238901u32, 0xef67cd45u32]);
	}

	#[test]
	fn test_transpose_to_byte_sliced() {
		let mut value = [0x01234567u32];
		u32::transpose_bytes_to_byte_sliced::<TowerLevel1>(&mut value);
		assert_eq!(value, [0x01234567u32]);

		let mut value = [0x67452301u32, 0xefcdab89u32];
		u32::transpose_bytes_to_byte_sliced::<TowerLevel2>(&mut value);
		assert_eq!(value, [0xcd894501u32, 0xefab6723u32]);
	}
}