binius_field/underlier/underlier_with_bit_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::{checked_int_div, checked_log_2};
6
7use super::{
8	underlier_type::{NumCast, UnderlierType},
9	U1, U2, U4,
10};
11use crate::tower_levels::TowerLevel;
12
13/// Underlier type that supports bit arithmetic.
14pub trait UnderlierWithBitOps:
15	UnderlierType
16	+ BitAnd<Self, Output = Self>
17	+ BitAndAssign<Self>
18	+ BitOr<Self, Output = Self>
19	+ BitOrAssign<Self>
20	+ BitXor<Self, Output = Self>
21	+ BitXorAssign<Self>
22	+ Shr<usize, Output = Self>
23	+ Shl<usize, Output = Self>
24	+ Not<Output = Self>
25{
26	const ZERO: Self;
27	const ONE: Self;
28	const ONES: Self;
29
30	/// Fill value with the given bit
31	/// `val` must be 0 or 1.
32	fn fill_with_bit(val: u8) -> Self;
33
34	#[inline]
35	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
36	where
37		T: UnderlierType,
38		Self: From<T>,
39	{
40		// This implementation is optimal for the case when `Self` us u8..u128.
41		// For SIMD types/arrays specialization would be more performant.
42		let mut result = Self::default();
43		let width = checked_int_div(Self::BITS, T::BITS);
44		for i in 0..width {
45			result |= Self::from(f(i)) << (i * T::BITS);
46		}
47
48		result
49	}
50
51	/// Broadcast subvalue to fill `Self`.
52	/// `Self::BITS/T::BITS` is supposed to be a power of 2.
53	#[inline]
54	fn broadcast_subvalue<T>(value: T) -> Self
55	where
56		T: UnderlierType,
57		Self: From<T>,
58	{
59		// This implementation is optimal for the case when `Self` us u8..u128.
60		// For SIMD types/arrays specialization would be more performant.
61		let height = checked_log_2(checked_int_div(Self::BITS, T::BITS));
62		let mut result = Self::from(value);
63		for i in 0..height {
64			result |= result << ((1 << i) * T::BITS);
65		}
66
67		result
68	}
69
70	/// Gets the subvalue from the given position.
71	/// Function panics in case when index is out of range.
72	///
73	/// # Safety
74	/// `i` must be less than `Self::BITS/T::BITS`.
75	#[inline]
76	unsafe fn get_subvalue<T>(&self, i: usize) -> T
77	where
78		T: UnderlierType + NumCast<Self>,
79	{
80		debug_assert!(
81			i < checked_int_div(Self::BITS, T::BITS),
82			"i: {} Self::BITS: {}, T::BITS: {}",
83			i,
84			Self::BITS,
85			T::BITS
86		);
87		T::num_cast_from(*self >> (i * T::BITS))
88	}
89
90	/// Sets the subvalue in the given position.
91	/// Function panics in case when index is out of range.
92	///
93	/// # Safety
94	/// `i` must be less than `Self::BITS/T::BITS`.
95	#[inline]
96	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
97	where
98		T: UnderlierWithBitOps,
99		Self: From<T>,
100	{
101		debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
102		let mask = Self::from(single_element_mask::<T>());
103
104		*self &= !(mask << (i * T::BITS));
105		*self |= Self::from(val) << (i * T::BITS);
106	}
107
108	/// Spread takes a block of sub_elements of `T` type within the current value and
109	/// repeats them to the full underlier width.
110	///
111	/// # Safety
112	/// `log_block_len + T::LOG_BITS` must be less than or equal to `Self::LOG_BITS`.
113	/// `block_idx` must be less than `1 << (Self::LOG_BITS - log_block_len)`.
114	#[inline]
115	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
116	where
117		T: UnderlierWithBitOps + NumCast<Self>,
118		Self: From<T>,
119	{
120		spread_fallback(self, log_block_len, block_idx)
121	}
122
123	/// Left shift within 128-bit lanes.
124	/// This can be more efficient than the full `Shl` implementation.
125	fn shl_128b_lanes(self, shift: usize) -> Self;
126
127	/// Right shift within 128-bit lanes.
128	/// This can be more efficient than the full `Shr` implementation.
129	fn shr_128b_lanes(self, shift: usize) -> Self;
130
131	/// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit lanes.
132	///
133	/// Example:
134	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
135	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
136	///    log_block_len: 1
137	///
138	///    result: [a_0, a_0, b_0, b_1, a_2, a_3, b_2, b_3]
139	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
140		unpack_lo_128b_fallback(self, other, log_block_len)
141	}
142
143	/// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit lanes.
144	///
145	/// Example:
146	///    self:  [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
147	///    other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7]
148	///    log_block_len: 1
149	///
150	///    result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7]
151	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
152		unpack_hi_128b_fallback(self, other, log_block_len)
153	}
154
155	/// Transpose bytes from byte-sliced representation to a packed "normal one".
156	///
157	/// For example for tower level 1, having the following bytes:
158	///     [a0, b0, c1, d1]
159	///     [a1, b1, c2, d2]
160	///
161	/// The result will be:
162	///     [a0, a1, b0, b1]
163	///     [c1, c2, d1, d2]
164	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
165	where
166		u8: NumCast<Self>,
167		Self: From<u8>,
168	{
169		assert!(TL::LOG_WIDTH <= 4);
170
171		let result = TL::from_fn(|row| {
172			Self::from_fn(|col| {
173				let index = row * (Self::BITS / 8) + col;
174
175				// Safety: `index` is always less than `N * byte_count`.
176				unsafe { values[index % TL::WIDTH].get_subvalue::<u8>(index / TL::WIDTH) }
177			})
178		});
179
180		*values = result;
181	}
182
183	/// Transpose bytes from `ordinal` packed representation to a byte-sliced one.
184	///
185	/// For example for tower level 1, having the following bytes:
186	///    [a0, a1, b0, b1]
187	///    [c0, c1, d0, d1]
188	///
189	/// The result will be:
190	///   [a0, b0, c0, d0]
191	///   [a1, b1, c1, d1]
192	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
193	where
194		u8: NumCast<Self>,
195		Self: From<u8>,
196	{
197		assert!(TL::LOG_WIDTH <= 4);
198
199		let bytes = Self::BITS / 8;
200		let result = TL::from_fn(|row| {
201			Self::from_fn(|col| {
202				let index = row + col * TL::WIDTH;
203
204				// Safety: `index` is always less than `N * byte_count`.
205				unsafe { values[index / bytes].get_subvalue::<u8>(index % bytes) }
206			})
207		});
208
209		*values = result;
210	}
211}
212
213/// Returns a bit mask for a single `T` element inside underlier type.
214/// This function is completely optimized out by the compiler in release version
215/// because all the values are known at compile time.
216fn single_element_mask<T>() -> T
217where
218	T: UnderlierWithBitOps,
219{
220	single_element_mask_bits(T::BITS)
221}
222
223/// A helper function to apply unpack_lo/hi_128b_lanes for two values in an array
224#[allow(dead_code)]
225#[inline(always)]
226pub(crate) fn pair_unpack_lo_hi_128b_lanes<U: UnderlierWithBitOps>(
227	values: &mut impl AsMut<[U]>,
228	i: usize,
229	j: usize,
230	log_block_len: usize,
231) {
232	let values = values.as_mut();
233
234	(values[i], values[j]) = (
235		values[i].unpack_lo_128b_lanes(values[j], log_block_len),
236		values[i].unpack_hi_128b_lanes(values[j], log_block_len),
237	);
238}
239
240/// A helper function used as a building block for efficient SIMD types transposition implementation.
241/// This function actually may reorder the elements.
242#[allow(dead_code)]
243#[inline(always)]
244pub(crate) fn transpose_128b_blocks_low_to_high<U: UnderlierWithBitOps, TL: TowerLevel>(
245	values: &mut TL::Data<U>,
246	log_block_len: usize,
247) {
248	assert!(TL::WIDTH <= 16);
249
250	if TL::WIDTH == 1 {
251		return;
252	}
253
254	let (left, right) = TL::split_mut(values);
255	transpose_128b_blocks_low_to_high::<_, TL::Base>(left, log_block_len);
256	transpose_128b_blocks_low_to_high::<_, TL::Base>(right, log_block_len);
257
258	let log_block_len = log_block_len + TL::LOG_WIDTH + 2;
259	for i in 0..TL::WIDTH / 2 {
260		pair_unpack_lo_hi_128b_lanes(values, i, i + TL::WIDTH / 2, log_block_len);
261	}
262}
263
264/// Transposition implementation for 128-bit SIMD types.
265/// This implementations is used for NEON and SSE2.
266#[allow(dead_code)]
267#[inline(always)]
268pub(crate) fn transpose_128b_values<U: UnderlierWithBitOps, TL: TowerLevel>(
269	values: &mut TL::Data<U>,
270	log_block_len: usize,
271) {
272	assert!(U::BITS == 128);
273
274	transpose_128b_blocks_low_to_high::<U, TL>(values, log_block_len);
275
276	// Elements are transposed, but we need to reorder them
277	match TL::LOG_WIDTH {
278		0 | 1 => {}
279		2 => {
280			values.as_mut().swap(1, 2);
281		}
282		3 => {
283			values.as_mut().swap(1, 4);
284			values.as_mut().swap(3, 6);
285		}
286		4 => {
287			values.as_mut().swap(1, 8);
288			values.as_mut().swap(2, 4);
289			values.as_mut().swap(3, 12);
290			values.as_mut().swap(5, 10);
291			values.as_mut().swap(7, 14);
292			values.as_mut().swap(11, 13);
293		}
294		_ => panic!("unsupported tower level"),
295	}
296}
297
298/// Fallback implementation of `spread` method.
299///
300/// # Safety
301/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
302/// `block_idx` must be less than `1 << (U::LOG_BITS - log_block_len)`.
303pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
304where
305	U: UnderlierWithBitOps + From<T>,
306	T: UnderlierWithBitOps + NumCast<U>,
307{
308	debug_assert!(
309		log_block_len + T::LOG_BITS <= U::LOG_BITS,
310		"log_block_len: {}, U::BITS: {}, T::BITS: {}",
311		log_block_len,
312		U::BITS,
313		T::BITS
314	);
315	debug_assert!(
316		block_idx < 1 << (U::LOG_BITS - log_block_len),
317		"block_idx: {}, U::BITS: {}, log_block_len: {}",
318		block_idx,
319		U::BITS,
320		log_block_len
321	);
322
323	let mut result = U::ZERO;
324	let block_offset = block_idx << log_block_len;
325	let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
326	for i in 0..1 << log_block_len {
327		unsafe {
328			result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
329		}
330	}
331
332	for i in 0..log_repeat {
333		result |= result << (1 << (T::LOG_BITS + i));
334	}
335
336	result
337}
338
339#[inline(always)]
340fn single_element_mask_bits_128b_lanes<T: UnderlierWithBitOps>(log_block_len: usize) -> T {
341	let mut mask = single_element_mask_bits(1 << log_block_len);
342	for i in 1..T::BITS / 128 {
343		mask |= mask << (i * 128);
344	}
345
346	mask
347}
348
349pub(crate) fn unpack_lo_128b_fallback<T: UnderlierWithBitOps>(
350	lhs: T,
351	rhs: T,
352	log_block_len: usize,
353) -> T {
354	assert!(log_block_len <= 6);
355
356	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
357
358	let mut result = T::ZERO;
359	for i in 0..1 << (6 - log_block_len) {
360		result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask)
361			.shl_128b_lanes(i << (log_block_len + 1));
362		result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask)
363			.shl_128b_lanes((2 * i + 1) << log_block_len);
364	}
365
366	result
367}
368
369pub(crate) fn unpack_hi_128b_fallback<T: UnderlierWithBitOps>(
370	lhs: T,
371	rhs: T,
372	log_block_len: usize,
373) -> T {
374	assert!(log_block_len <= 6);
375
376	let mask = single_element_mask_bits_128b_lanes::<T>(log_block_len);
377	let mut result = T::ZERO;
378	for i in 0..1 << (6 - log_block_len) {
379		result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
380			.shl_128b_lanes(i << (log_block_len + 1));
381		result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask)
382			.shl_128b_lanes((2 * i + 1) << log_block_len);
383	}
384
385	result
386}
387
388pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
389	if bits_count == T::BITS {
390		!T::ZERO
391	} else {
392		let mut result = T::ONE;
393		for height in 0..checked_log_2(bits_count) {
394			result |= result << (1 << height)
395		}
396
397		result
398	}
399}
400
401/// Value that can be spread to a single u8
402pub(crate) trait SpreadToByte {
403	fn spread_to_byte(self) -> u8;
404}
405
406impl SpreadToByte for U1 {
407	#[inline(always)]
408	fn spread_to_byte(self) -> u8 {
409		u8::fill_with_bit(self.val())
410	}
411}
412
413impl SpreadToByte for U2 {
414	#[inline(always)]
415	fn spread_to_byte(self) -> u8 {
416		let mut result = self.val();
417		result |= result << 2;
418		result |= result << 4;
419
420		result
421	}
422}
423
424impl SpreadToByte for U4 {
425	#[inline(always)]
426	fn spread_to_byte(self) -> u8 {
427		let mut result = self.val();
428		result |= result << 4;
429
430		result
431	}
432}
433
434/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
435///
436/// # Safety
437/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
438#[allow(unused)]
439#[inline(always)]
440pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
441	value: U,
442	block_idx: usize,
443) -> [T; BLOCK_LEN]
444where
445	U: UnderlierWithBitOps + From<T>,
446	T: UnderlierType + NumCast<U>,
447{
448	std::array::from_fn(|i| value.get_subvalue::<T>(block_idx * BLOCK_LEN + i))
449}
450
451/// A helper functions for implementing `UnderlierWithBitOps::spread_unchecked` for SIMD types.
452///
453/// # Safety
454/// `log_block_len + T::LOG_BITS` must be less than or equal to `U::LOG_BITS`.
455#[allow(unused)]
456#[inline(always)]
457pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
458	value: U,
459	block_idx: usize,
460) -> [u8; BLOCK_LEN]
461where
462	U: UnderlierWithBitOps + From<T>,
463	T: UnderlierType + SpreadToByte + NumCast<U>,
464{
465	get_block_values::<U, T, BLOCK_LEN>(value, block_idx).map(SpreadToByte::spread_to_byte)
466}
467
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};
	use crate::tower_levels::{TowerLevel1, TowerLevel2};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	// Index ranges are exclusive, so they cover every slot of a `u32`. Value strategies
	// cover every bit of the sub-element type.
	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(1)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(3)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(15)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}

	#[test]
	fn test_single_element_mask_bits() {
		assert_eq!(single_element_mask_bits::<u32>(1), 0x1);
		assert_eq!(single_element_mask_bits::<u32>(4), 0xf);
		assert_eq!(single_element_mask_bits::<u32>(8), 0xff);
		assert_eq!(single_element_mask_bits::<u32>(32), u32::MAX);
	}

	#[test]
	fn test_spread_fallback() {
		let value = 0x44332211u32;

		unsafe {
			// One block spanning the whole value: identity.
			assert_eq!(spread_fallback::<u32, u8>(value, 2, 0), value);
			// Two-byte blocks: each byte of the block is repeated twice.
			assert_eq!(spread_fallback::<u32, u8>(value, 1, 0), 0x22221111);
			assert_eq!(spread_fallback::<u32, u8>(value, 1, 1), 0x44443333);
			// Single-byte blocks: the byte is broadcast to the whole value.
			assert_eq!(spread_fallback::<u32, u8>(value, 0, 3), 0x44444444);
		}
	}

	#[test]
	fn test_unpack_128b_fallback() {
		let lhs = u128::from_le_bytes(std::array::from_fn(|k| k as u8));
		let rhs = u128::from_le_bytes(std::array::from_fn(|k| 0x10 + k as u8));

		// Byte-sized blocks: bytes of the selected 64-bit halves are interleaved.
		let expected_lo: [u8; 16] =
			std::array::from_fn(|k| if k % 2 == 0 { (k / 2) as u8 } else { 0x10 + (k / 2) as u8 });
		let expected_hi: [u8; 16] = std::array::from_fn(|k| expected_lo[k] + 8);
		assert_eq!(unpack_lo_128b_fallback(lhs, rhs, 3).to_le_bytes(), expected_lo);
		assert_eq!(unpack_hi_128b_fallback(lhs, rhs, 3).to_le_bytes(), expected_hi);

		// 64-bit blocks: the selected halves of `lhs` and `rhs` are concatenated.
		let expected_lo: [u8; 16] =
			std::array::from_fn(|k| if k < 8 { k as u8 } else { 0x08 + k as u8 });
		let expected_hi: [u8; 16] = std::array::from_fn(|k| expected_lo[k] + 8);
		assert_eq!(unpack_lo_128b_fallback(lhs, rhs, 6).to_le_bytes(), expected_lo);
		assert_eq!(unpack_hi_128b_fallback(lhs, rhs, 6).to_le_bytes(), expected_hi);
	}

	#[test]
	fn test_transpose_from_byte_sliced() {
		let mut value = [0x01234567u32];
		u32::transpose_bytes_from_byte_sliced::<TowerLevel1>(&mut value);
		assert_eq!(value, [0x01234567u32]);

		let mut value = [0x67452301u32, 0xefcdab89u32];
		u32::transpose_bytes_from_byte_sliced::<TowerLevel2>(&mut value);
		assert_eq!(value, [0xab238901u32, 0xef67cd45u32]);
	}

	#[test]
	fn test_transpose_to_byte_sliced() {
		let mut value = [0x01234567u32];
		u32::transpose_bytes_to_byte_sliced::<TowerLevel1>(&mut value);
		assert_eq!(value, [0x01234567u32]);

		let mut value = [0x67452301u32, 0xefcdab89u32];
		u32::transpose_bytes_to_byte_sliced::<TowerLevel2>(&mut value);
		assert_eq!(value, [0xcd894501u32, 0xefab6723u32]);
	}
}