binius_math/bit_reverse.rs

// Copyright 2025 Irreducible Inc.

use binius_field::{PackedField, square_transpose};
use binius_utils::{checked_arithmetics::log2_strict_usize, rayon::prelude::*};
use bytemuck::zeroed_vec;

use crate::field_buffer::FieldSliceMut;

/// Reverses the low `bits` bits of an unsigned integer.
///
/// # Arguments
///
/// * `x` - The value whose bits to reverse
/// * `bits` - The number of low-order bits to reverse
///
/// # Returns
///
/// The low `bits` bits of `x` in reversed order; any higher bits of `x` are discarded
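///
/// # Examples
///
/// A small sketch; the `binius_math::reverse_bits` doctest path is an assumption
/// about how this module is re-exported.
///
/// ```
/// # use binius_math::reverse_bits; // assumed re-export path
/// assert_eq!(reverse_bits(0b0011, 4), 0b1100);
/// assert_eq!(reverse_bits(0b1, 5), 0b10000);
/// ```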
pub fn reverse_bits(x: usize, bits: u32) -> usize {
	x.reverse_bits().unbounded_shr(usize::BITS - bits)
}

/// Applies a bit-reversal permutation to packed field elements in a buffer, in parallel.
///
/// This function permutes the field elements so that the element at index `i` moves to
/// index `reverse_bits(i, log_len)`. The permutation is performed in place and correctly
/// handles packed field representations.
///
/// # Arguments
///
/// * `buffer` - Mutable slice of packed field elements to permute
pub fn bit_reverse_packed<P: PackedField>(mut buffer: FieldSliceMut<P>) {
	// The algorithm has two parallelized phases:
	// 1. Transpose P::WIDTH x P::WIDTH submatrices in parallel
	// 2. Apply bit-reversal to independent chunks in parallel
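	//
	// Why two phases suffice: write a scalar index as (a, m, b), where a is the top
	// LOG_WIDTH bits, m the middle log_len - 2 * LOG_WIDTH bits, and b the bottom
	// LOG_WIDTH bits (the lane within a packed element). Full bit reversal maps
	// (a, m, b) to (rev(b), rev(m), rev(a)): phase 1's submatrix transposes exchange
	// the bit-reversed a and b parts, and phase 2 bit-reverses m at packed-element
	// granularity within each chunk.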

	let log_len = buffer.log_len();
	if log_len < 2 * P::LOG_WIDTH {
		return bit_reverse_packed_naive(buffer);
	}

	let bits = (log_len - P::LOG_WIDTH) as u32;
	let data = buffer.as_mut();

	// Phase 1: Transpose submatrices in parallel.
	// Each iteration accesses disjoint memory locations, so parallelization is safe.
	let data_ptr = data.as_mut_ptr() as usize;
	(0..1 << (log_len - 2 * P::LOG_WIDTH))
		.into_par_iter()
		.for_each_init(
			|| zeroed_vec::<P>(P::WIDTH),
			|tmp, i| {
				// SAFETY: Different values of i access non-overlapping submatrices.
				// The indexing pattern reverse_bits(j, bits) | i ensures that:
				// - reverse_bits(j, bits) places the reversed bits of j in the high
				//   LOG_WIDTH bits of the index
				// - | i places i in the low bits
				// Therefore, different i values access completely disjoint index sets.
				unsafe {
					let data = data_ptr as *mut P;
					for j in 0..P::WIDTH {
						tmp[j] = *data.add(reverse_bits(j, bits) | i);
					}
				}
				square_transpose(P::LOG_WIDTH, tmp).expect("pre-conditions satisfied");
				unsafe {
					let data = data_ptr as *mut P;
					for j in 0..P::WIDTH {
						*data.add(reverse_bits(j, bits) | i) = tmp[j];
					}
				}
			},
		);

	// Phase 2: Apply bit_reverse_indices to chunks in parallel.
	// Chunks are non-overlapping, so this is safe.
	data.par_chunks_mut(1 << (log_len - 2 * P::LOG_WIDTH))
		.for_each(|chunk| {
			bit_reverse_indices(chunk);
		});
}

/// Applies a bit-reversal permutation to packed field elements using a simple algorithm.
///
/// This is a straightforward reference implementation that directly swaps field elements
/// according to the bit-reversal permutation. It serves as a baseline for correctness
/// testing of optimized implementations.
///
/// # Arguments
///
/// * `buffer` - Mutable slice of packed field elements to permute
fn bit_reverse_packed_naive<P: PackedField>(mut buffer: FieldSliceMut<P>) {
	let bits = buffer.log_len() as u32;
	for i in 0..buffer.len() {
		let i_rev = reverse_bits(i, bits);
		if i < i_rev {
			let tmp = buffer.get(i);
			buffer.set(i, buffer.get(i_rev));
			buffer.set(i_rev, tmp);
		}
	}
}

/// Applies a bit-reversal permutation to elements in a slice using parallel iteration.
///
/// This function permutes the elements so that the element at index `i` moves to
/// index `reverse_bits(i, log2(length))`. The permutation is performed in place
/// by swapping elements in parallel.
///
/// # Arguments
///
/// * `buffer` - Mutable slice of elements to permute
///
/// # Panics
///
/// Panics if the buffer length is not a power of two.
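///
/// # Examples
///
/// A small sketch on plain integers; the `binius_math::bit_reverse_indices`
/// doctest path is an assumption about how this module is re-exported.
///
/// ```
/// # use binius_math::bit_reverse_indices; // assumed re-export path
/// let mut v: Vec<u32> = (0..8).collect();
/// bit_reverse_indices(&mut v);
/// assert_eq!(v, [0, 4, 2, 6, 1, 5, 3, 7]);
/// ```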
pub fn bit_reverse_indices<T>(buffer: &mut [T]) {
	let bits = log2_strict_usize(buffer.len()) as u32;

	// A `*mut T` is neither Send nor Sync, so the parallel closure cannot capture it
	// directly. Passing the address as a usize and reconstructing the pointer inside
	// the closure sidesteps this; the SAFETY argument below justifies the shared access.
	let buffer_ptr = buffer.as_mut_ptr() as usize;

	(0..buffer.len()).into_par_iter().for_each(|i| {
		let i_rev = reverse_bits(i, bits);
		if i < i_rev {
			// SAFETY: The i < i_rev condition guarantees that:
			// 1. Each (i, i_rev) pair is processed by exactly one thread (the one with i < i_rev)
			// 2. Since bit-reversal is bijective, no two threads access the same pair
			// 3. Therefore, ptr.add(i) and ptr.add(i_rev) point to disjoint memory locations
			// 4. No data races can occur
			// 5. buffer_ptr is valid for the lifetime of this closure
			unsafe {
				let ptr = buffer_ptr as *mut T;
				let ptr_i = ptr.add(i);
				let ptr_i_rev = ptr.add(i_rev);
				std::ptr::swap_nonoverlapping(ptr_i, ptr_i_rev, 1);
			}
		}
	});
}

#[cfg(test)]
mod tests {
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::{Packed128b, random_field_buffer};

	// For Packed128b (PackedBinaryGhash4x128b), LOG_WIDTH = 2, so 2 * LOG_WIDTH = 4.
	// Test three cases around the threshold where bit_reverse_packed switches between
	// the naive and optimized implementations.
	#[rstest::rstest]
	#[case::below_threshold(3)] // log_d < 2 * P::LOG_WIDTH
	#[case::at_threshold(4)] // log_d == 2 * P::LOG_WIDTH
	#[case::above_threshold(8)] // log_d > 2 * P::LOG_WIDTH
	fn test_bit_reverse_packed_equivalence(#[case] log_d: usize) {
		let mut rng = StdRng::seed_from_u64(0);

		let data_orig = random_field_buffer::<Packed128b>(&mut rng, log_d);

		let mut data_optimized = data_orig.clone();
		let mut data_naive = data_orig.clone();

		bit_reverse_packed(data_optimized.to_mut());
		bit_reverse_packed_naive(data_naive.to_mut());

		assert_eq!(data_optimized, data_naive, "Mismatch at log_d={}", log_d);
	}
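
	// Bit reversal is an involution, so applying bit_reverse_packed twice must
	// restore the original buffer. log_d = 8 is above the threshold, so this
	// round trip exercises both phases of the optimized implementation.
	#[test]
	fn test_bit_reverse_packed_involution() {
		let mut rng = StdRng::seed_from_u64(0);

		let data_orig = random_field_buffer::<Packed128b>(&mut rng, 8);
		let mut data = data_orig.clone();

		bit_reverse_packed(data.to_mut());
		bit_reverse_packed(data.to_mut());

		assert_eq!(data, data_orig);
	}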
}