binius_math/batch_invert.rs
1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField128bGhash as Ghash, Field};
4
5/// Efficiently inverts multiple field elements simultaneously using Montgomery's batch inversion
6/// trick.
7///
8/// This function implements a binary tree approach to batch field inversion, which is significantly
9/// more efficient than inverting each element individually when dealing with many elements.
10///
11/// # Algorithm Overview
12/// 1. **Product Tree Construction**: Build a binary tree bottom-up where leaves are input elements
13/// and each internal node stores the product of its children using linear addressing
14/// 2. **Root Inversion**: Perform single expensive field inversion on the root (product of all
15/// elements)
16/// 3. **Inverse Propagation**: Propagate the root inverse down through tree levels to compute
17/// individual element inverses
18///
19/// # Performance Benefits
20/// - **Single inversion**: Only one expensive field inversion instead of N inversions
21/// - **Linear addressing**: Simple addition-based indexing for better cache locality
22/// - **Zero handling**: Graceful handling of zero elements without division-by-zero errors
23///
24/// # Parameters
25/// - `elements`: Array of field elements to invert (modified in-place)
26/// - `scratchpad`: Working memory buffer, must be at least `2*N-1` elements
27///
28/// # Requirements
29/// - `N` must be a power of 2 and ≥ 2
30/// - `scratchpad.len() >= 2*N-1` for intermediate computations
31#[inline]
32pub fn batch_invert<const N: usize>(elements: &mut [Ghash], scratchpad: &mut [Ghash]) {
33 assert!(N.is_power_of_two() && N >= 2, "N must be a power of 2 and >= 2");
34 assert_eq!(elements.len(), N);
35 assert!(scratchpad.len() >= 2 * N - 1, "scratchpad too small");
36
37 let zero = Ghash::ZERO;
38 let one = Ghash::ONE;
39 let levels = N.ilog2() as usize;
40
41 // Phase 1: Setup - Copy input elements, replacing zeros with ones
42 // This prevents division-by-zero while preserving zero semantics in final output
43 for i in 0..N {
44 scratchpad[i] = if elements[i] == zero {
45 one // Temporary replacement - restored to zero in final phase
46 } else {
47 elements[i]
48 };
49 }
50
51 // Phase 2: Build product tree bottom-up using linear addressing
52 // Each level combines pairs from the previous level into products
53 let mut dest_offset = N; // Current write position in scratchpad
54
55 // Build intermediate tree levels (N/2, N/4, N/8, ... down to 2 elements)
56 for level in 1..levels {
57 let level_size = N >> level; // Number of products at this level
58 let src_offset = dest_offset - (level_size * 2); // Read from previous level
59
60 // Combine adjacent pairs: scratchpad[2*i] * scratchpad[2*i+1]
61 for i in 0..level_size {
62 scratchpad[dest_offset + i] =
63 scratchpad[src_offset + 2 * i] * scratchpad[src_offset + 2 * i + 1];
64 }
65 dest_offset += level_size; // Move to next level's storage
66 }
67
68 // Final level: multiply the last two products to get root
69 let src_offset = dest_offset - 2;
70 scratchpad[dest_offset] = scratchpad[src_offset] * scratchpad[src_offset + 1];
71
72 // Phase 3: Invert root product (Montgomery's key insight: single inversion)
73 scratchpad[dest_offset] = scratchpad[dest_offset]
74 .invert()
75 .expect("factors are non-zero, so product is non-zero");
76
77 // Phase 4: Propagate inverse down tree levels (reverse order)
78 // Each level computes inverses from the level above
79 for level in 1..levels {
80 let level_size = 1 << level; // Size doubles each level going down
81 let src_offset = dest_offset; // Read from current position
82 dest_offset -= level_size; // Move down to previous level
83
84 // For each pair, compute inverses using: inv(a*b) * b = inv(a), inv(a*b) * a = inv(b)
85 for i in 0..level_size >> 1 {
86 let left_product = scratchpad[dest_offset + 2 * i]; // Original product a
87 scratchpad[dest_offset + 2 * i] =
88 scratchpad[dest_offset + 2 * i + 1] * scratchpad[src_offset + i]; // inv(a) = b * inv(a*b)
89 scratchpad[dest_offset + 2 * i + 1] = left_product * scratchpad[src_offset + i]; // inv(b) = a * inv(a*b)
90 }
91 }
92
93 // Phase 5: Extract final inverses and restore zero semantics
94 // The last layer of products could be done in the loop immediately above,
95 // but for speed we avoid an extra copy by merging it copying from the
96 // scratchpad.
97 for i in 0..N / 2 {
98 let j = 2 * i;
99 // Restore original zeros (marked in Phase 1)
100 elements[j] = if elements[j] == zero {
101 zero
102 } else {
103 scratchpad[j + 1] * scratchpad[dest_offset + i]
104 };
105 elements[j + 1] = if elements[j + 1] == zero {
106 zero
107 } else {
108 scratchpad[j] * scratchpad[dest_offset + i]
109 };
110 }
111}
112
#[cfg(test)]
mod tests {
    use binius_field::{Field, Random, arithmetic_traits::InvertOrZero};
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Fills an `N`-element array with a random mix of zeros and random field elements, then
    /// checks that `batch_invert` agrees with element-wise `invert_or_zero`.
    fn test_batch_invert_for_size<const N: usize>(rng: &mut StdRng) {
        let mut state = [Ghash::ZERO; N];
        for slot in state.iter_mut() {
            // A coin flip decides whether this slot is zero, so the zero path is exercised.
            let use_zero = rng.random::<bool>();
            *slot = if use_zero {
                Ghash::ZERO
            } else {
                Ghash::random(&mut *rng)
            };
        }

        // Reference result, computed one element at a time before `state` is overwritten.
        let expected: [Ghash; N] = state.map(|x| x.invert_or_zero());

        // Smallest scratchpad the function accepts.
        let mut scratchpad = vec![Ghash::ZERO; 2 * N - 1];
        batch_invert::<N>(&mut state, &mut scratchpad);

        assert_eq!(state, expected);
    }

    #[test]
    fn test_batch_invert() {
        // Expands to one call per listed size, in the order given.
        macro_rules! check_sizes {
            ($rng:expr; $($n:literal),+ $(,)?) => {
                $( test_batch_invert_for_size::<$n>($rng); )+
            };
        }

        // Fixed seed keeps the test deterministic.
        let mut rng = StdRng::seed_from_u64(0);

        // Four rounds over every power of two from 2 to 256.
        for _ in 0..4 {
            check_sizes!(&mut rng; 2, 4, 8, 16, 32, 64, 128, 256);
        }
    }
}