binius_math/bit_reverse.rs
// Copyright 2025 Irreducible Inc.

use binius_field::{PackedField, square_transpose};
use binius_utils::{checked_arithmetics::log2_strict_usize, rayon::prelude::*};
use bytemuck::zeroed_vec;

use crate::field_buffer::FieldSliceMut;

/// Reverses the low `bits` bits of an unsigned integer.
///
/// # Arguments
///
/// * `x` - The value whose bits to reverse
/// * `bits` - The number of low-order bits to reverse
///
/// # Returns
///
/// The value with its low `bits` bits reversed
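/// # Examples
///
/// A minimal sketch of the behavior (left as an `ignore` doctest, since no public re-export
/// path for this function is assumed here):
///
/// ```ignore
/// // Reversing the low 4 bits of 0b0011 yields 0b1100.
/// assert_eq!(reverse_bits(0b0011, 4), 0b1100);
/// // Bits above the reversed range are discarded: only the low 4 bits of 0b1_0001 survive.
/// assert_eq!(reverse_bits(0b1_0001, 4), 0b1000);
/// ```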
pub fn reverse_bits(x: usize, bits: u32) -> usize {
    x.reverse_bits().unbounded_shr(usize::BITS - bits)
}

/// Applies a bit-reversal permutation to packed field elements in a buffer using parallelization.
///
/// This function permutes the field elements so that the element at index `i` is moved to
/// index `reverse_bits(i, log_len)`. The permutation is performed in place and correctly
/// handles packed field representations.
///
/// # Arguments
///
/// * `buffer` - Mutable slice of packed field elements to permute
pub fn bit_reverse_packed<P: PackedField>(mut buffer: FieldSliceMut<P>) {
    // The algorithm has two parallelized phases:
    // 1. Process P::WIDTH x P::WIDTH submatrices in parallel
    // 2. Apply bit-reversal to independent chunks in parallel
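    //
    // Why the two phases compose to a full bit reversal: write a scalar index as T || M || L,
    // where L is the P::LOG_WIDTH-bit lane index within a packed word, T is the top
    // P::LOG_WIDTH bits of the packed-word index, and M is the remaining middle bits. Full
    // reversal maps T || M || L to rev(L) || rev(M) || rev(T). Phase 1 exchanges and reverses
    // T and L via WIDTH x WIDTH scalar transposes; phase 2 reverses M by permuting whole
    // packed words within each chunk.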

    let log_len = buffer.log_len();
    if log_len < 2 * P::LOG_WIDTH {
        return bit_reverse_packed_naive(buffer);
    }

    let bits = (log_len - P::LOG_WIDTH) as u32;
    let data = buffer.as_mut();

    // Phase 1: Process submatrices in parallel
    // Each iteration accesses disjoint memory locations, so parallelization is safe
    let data_ptr = data.as_mut_ptr() as usize;
    (0..1 << (log_len - 2 * P::LOG_WIDTH))
        .into_par_iter()
        .for_each_init(
            || zeroed_vec::<P>(P::WIDTH),
            |tmp, i| {
                // SAFETY: Different values of i access non-overlapping submatrices.
                // The indexing pattern reverse_bits(j, bits) | i ensures that:
                // - reverse_bits(j, bits) places j in the high bits
                // - | i places i in the low bits
                // Therefore, different i values access completely disjoint index sets.
                unsafe {
                    let data = data_ptr as *mut P;
                    for j in 0..P::WIDTH {
                        tmp[j] = *data.add(reverse_bits(j, bits) | i);
                    }
                }
                square_transpose(P::LOG_WIDTH, tmp).expect("pre-conditions satisfied");
                unsafe {
                    let data = data_ptr as *mut P;
                    for j in 0..P::WIDTH {
                        *data.add(reverse_bits(j, bits) | i) = tmp[j];
                    }
                }
            },
        );

    // Phase 2: Apply bit_reverse_indices to chunks in parallel
    // Chunks are non-overlapping, so this is safe
    data.par_chunks_mut(1 << (log_len - 2 * P::LOG_WIDTH))
        .for_each(|chunk| {
            bit_reverse_indices(chunk);
        });
}

/// Applies a bit-reversal permutation to packed field elements using a simple algorithm.
///
/// This is a straightforward reference implementation that directly swaps field elements
/// according to the bit-reversal permutation. It serves as a baseline for correctness
/// testing of optimized implementations.
///
/// # Arguments
///
/// * `buffer` - Mutable slice of packed field elements to permute
fn bit_reverse_packed_naive<P: PackedField>(mut buffer: FieldSliceMut<P>) {
    let bits = buffer.log_len() as u32;
    for i in 0..buffer.len() {
        let i_rev = reverse_bits(i, bits);
        if i < i_rev {
            let tmp = buffer.get(i);
            buffer.set(i, buffer.get(i_rev));
            buffer.set(i_rev, tmp);
        }
    }
}

/// Applies a bit-reversal permutation to elements in a slice using parallel iteration.
///
/// This function permutes the elements so that the element at index `i` is moved to
/// index `reverse_bits(i, log2(length))`. The permutation is performed in place
/// by swapping elements in parallel.
///
/// # Arguments
///
/// * `buffer` - Mutable slice of elements to permute
///
/// # Panics
///
/// Panics if the buffer length is not a power of two.
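///
/// # Examples
///
/// A minimal sketch on a plain integer slice (left as an `ignore` doctest, since no public
/// re-export path for this function is assumed here):
///
/// ```ignore
/// let mut values: Vec<u32> = (0..8).collect();
/// bit_reverse_indices(&mut values);
/// // Each index i now holds the element originally at reverse_bits(i, 3).
/// assert_eq!(values, vec![0, 4, 2, 6, 1, 5, 3, 7]);
/// ```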
pub fn bit_reverse_indices<T>(buffer: &mut [T]) {
    let bits = log2_strict_usize(buffer.len()) as u32;

    // Raw pointers are neither Send nor Sync, so the base pointer is smuggled into the
    // parallel closure as a usize and cast back to *mut T inside. This is sound because the
    // swaps below touch disjoint pairs of elements.
    let buffer_ptr = buffer.as_mut_ptr() as usize;

    (0..buffer.len()).into_par_iter().for_each(|i| {
        let i_rev = reverse_bits(i, bits);
        if i < i_rev {
            // SAFETY: The i < i_rev condition guarantees that:
            // 1. Each (i, i_rev) pair is processed by exactly one thread (the one with i < i_rev)
            // 2. Since bit-reversal is bijective, no two threads access the same pair
            // 3. Therefore, ptr.add(i) and ptr.add(i_rev) point to disjoint memory locations
            // 4. No data races can occur
            // 5. buffer_ptr is valid for the lifetime of this closure
            unsafe {
                let ptr = buffer_ptr as *mut T;
                let ptr_i = ptr.add(i);
                let ptr_i_rev = ptr.add(i_rev);
                std::ptr::swap_nonoverlapping(ptr_i, ptr_i_rev, 1);
            }
        }
    });
}

#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_utils::{Packed128b, random_field_buffer};

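    // Sanity check on a plain integer slice: bit_reverse_indices should place the element
    // originally at index i at index reverse_bits(i, log_len). Uses only items from this
    // module.
    #[test]
    fn test_bit_reverse_indices_matches_reference() {
        let log_len = 10u32;
        let mut values: Vec<u32> = (0..1u32 << log_len).collect();
        bit_reverse_indices(&mut values);

        for (i, &v) in values.iter().enumerate() {
            assert_eq!(v as usize, reverse_bits(i, log_len));
        }
    }
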
    // For Packed128b (PackedBinaryGhash4x128b), LOG_WIDTH = 2, so 2 * LOG_WIDTH = 4.
    // Test three cases around the threshold where bit_reverse_packed switches between the
    // naive and optimized implementations.
    #[rstest::rstest]
    #[case::below_threshold(3)] // log_d < 2 * P::LOG_WIDTH
    #[case::at_threshold(4)] // log_d == 2 * P::LOG_WIDTH
    #[case::above_threshold(8)] // log_d > 2 * P::LOG_WIDTH
    fn test_bit_reverse_packed_equivalence(#[case] log_d: usize) {
        let mut rng = StdRng::seed_from_u64(0);

        let data_orig = random_field_buffer::<Packed128b>(&mut rng, log_d);

        let mut data_optimized = data_orig.clone();
        let mut data_naive = data_orig.clone();

        bit_reverse_packed(data_optimized.to_mut());
        bit_reverse_packed_naive(data_naive.to_mut());

        assert_eq!(data_optimized, data_naive, "Mismatch at log_d={}", log_d);
    }
}
169}