binius_math/
batch_invert.rs

// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField128bGhash as Ghash, Field};
4
5/// Efficiently inverts multiple field elements simultaneously using Montgomery's batch inversion
6/// trick.
7///
8/// This function implements a binary tree approach to batch field inversion, which is significantly
9/// more efficient than inverting each element individually when dealing with many elements.
10///
11/// # Algorithm Overview
12/// 1. **Product Tree Construction**: Build a binary tree bottom-up where leaves are input elements
13///    and each internal node stores the product of its children using linear addressing
14/// 2. **Root Inversion**: Perform single expensive field inversion on the root (product of all
15///    elements)
16/// 3. **Inverse Propagation**: Propagate the root inverse down through tree levels to compute
17///    individual element inverses
18///
19/// # Performance Benefits
20/// - **Single inversion**: Only one expensive field inversion instead of N inversions
21/// - **Linear addressing**: Simple addition-based indexing for better cache locality
22/// - **Zero handling**: Graceful handling of zero elements without division-by-zero errors
23///
24/// # Parameters
25/// - `elements`: Array of field elements to invert (modified in-place)
26/// - `scratchpad`: Working memory buffer, must be at least `2*N-1` elements
27///
28/// # Requirements
29/// - `N` must be a power of 2 and ≥ 2
30/// - `scratchpad.len() >= 2*N-1` for intermediate computations
31#[inline]
32pub fn batch_invert<const N: usize>(elements: &mut [Ghash], scratchpad: &mut [Ghash]) {
33	assert!(N.is_power_of_two() && N >= 2, "N must be a power of 2 and >= 2");
34	assert_eq!(elements.len(), N);
35	assert!(scratchpad.len() >= 2 * N - 1, "scratchpad too small");
36
37	let zero = Ghash::ZERO;
38	let one = Ghash::ONE;
39	let levels = N.ilog2() as usize;
40
41	// Phase 1: Setup - Copy input elements, replacing zeros with ones
42	// This prevents division-by-zero while preserving zero semantics in final output
43	for i in 0..N {
44		scratchpad[i] = if elements[i] == zero {
45			one // Temporary replacement - restored to zero in final phase
46		} else {
47			elements[i]
48		};
49	}
50
51	// Phase 2: Build product tree bottom-up using linear addressing
52	// Each level combines pairs from the previous level into products
53	let mut dest_offset = N; // Current write position in scratchpad
54
55	// Build intermediate tree levels (N/2, N/4, N/8, ... down to 2 elements)
56	for level in 1..levels {
57		let level_size = N >> level; // Number of products at this level
58		let src_offset = dest_offset - (level_size * 2); // Read from previous level
59
60		// Combine adjacent pairs: scratchpad[2*i] * scratchpad[2*i+1]
61		for i in 0..level_size {
62			scratchpad[dest_offset + i] =
63				scratchpad[src_offset + 2 * i] * scratchpad[src_offset + 2 * i + 1];
64		}
65		dest_offset += level_size; // Move to next level's storage
66	}
67
68	// Final level: multiply the last two products to get root
69	let src_offset = dest_offset - 2;
70	scratchpad[dest_offset] = scratchpad[src_offset] * scratchpad[src_offset + 1];
71
72	// Phase 3: Invert root product (Montgomery's key insight: single inversion)
73	scratchpad[dest_offset] = scratchpad[dest_offset]
74		.invert()
75		.expect("factors are non-zero, so product is non-zero");
76
77	// Phase 4: Propagate inverse down tree levels (reverse order)
78	// Each level computes inverses from the level above
79	for level in 1..levels {
80		let level_size = 1 << level; // Size doubles each level going down
81		let src_offset = dest_offset; // Read from current position
82		dest_offset -= level_size; // Move down to previous level
83
84		// For each pair, compute inverses using: inv(a*b) * b = inv(a), inv(a*b) * a = inv(b)
85		for i in 0..level_size >> 1 {
86			let left_product = scratchpad[dest_offset + 2 * i]; // Original product a
87			scratchpad[dest_offset + 2 * i] =
88				scratchpad[dest_offset + 2 * i + 1] * scratchpad[src_offset + i]; // inv(a) = b * inv(a*b)
89			scratchpad[dest_offset + 2 * i + 1] = left_product * scratchpad[src_offset + i]; // inv(b) = a * inv(a*b)
90		}
91	}
92
93	// Phase 5: Extract final inverses and restore zero semantics
94	// The last layer of products could be done in the loop immediately above,
95	// but for speed we avoid an extra copy by merging it copying from the
96	// scratchpad.
97	for i in 0..N / 2 {
98		let j = 2 * i;
99		// Restore original zeros (marked in Phase 1)
100		elements[j] = if elements[j] == zero {
101			zero
102		} else {
103			scratchpad[j + 1] * scratchpad[dest_offset + i]
104		};
105		elements[j + 1] = if elements[j + 1] == zero {
106			zero
107		} else {
108			scratchpad[j] * scratchpad[dest_offset + i]
109		};
110	}
111}
112
#[cfg(test)]
mod tests {
	use binius_field::{Field, Random, arithmetic_traits::InvertOrZero};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;

	/// Runs `batch_invert` on `state` with a zero-filled scratchpad of `scratch_len` elements and
	/// checks every output against `invert_or_zero` of the corresponding input.
	fn check_against_invert_or_zero<const N: usize>(mut state: [Ghash; N], scratch_len: usize) {
		let expected: [Ghash; N] = state.map(|x| x.invert_or_zero());

		let mut scratchpad = vec![Ghash::ZERO; scratch_len];
		batch_invert::<N>(&mut state, &mut scratchpad);

		assert_eq!(state, expected);
	}

	/// Random input in which each element is independently zero with probability 1/2.
	fn test_batch_invert_for_size<const N: usize>(rng: &mut StdRng) {
		let state: [Ghash; N] = std::array::from_fn(|_| {
			if rng.random::<bool>() {
				Ghash::ZERO
			} else {
				Ghash::random(&mut *rng)
			}
		});
		check_against_invert_or_zero(state, 2 * N - 1);
	}

	#[test]
	fn test_batch_invert() {
		let mut rng = StdRng::seed_from_u64(0);

		for _ in 0..4 {
			test_batch_invert_for_size::<2>(&mut rng);
			test_batch_invert_for_size::<4>(&mut rng);
			test_batch_invert_for_size::<8>(&mut rng);
			test_batch_invert_for_size::<16>(&mut rng);
			test_batch_invert_for_size::<32>(&mut rng);
			test_batch_invert_for_size::<64>(&mut rng);
			test_batch_invert_for_size::<128>(&mut rng);
			test_batch_invert_for_size::<256>(&mut rng);
		}
	}

	#[test]
	fn test_batch_invert_all_zeros() {
		check_against_invert_or_zero([Ghash::ZERO; 2], 2 * 2 - 1);
		check_against_invert_or_zero([Ghash::ZERO; 8], 2 * 8 - 1);
	}

	#[test]
	fn test_batch_invert_all_ones() {
		check_against_invert_or_zero([Ghash::ONE; 4], 2 * 4 - 1);
	}

	#[test]
	fn test_batch_invert_without_zeros() {
		let mut rng = StdRng::seed_from_u64(1);
		let state: [Ghash; 32] = std::array::from_fn(|_| Ghash::random(&mut rng));
		check_against_invert_or_zero(state, 2 * 32 - 1);
	}

	#[test]
	fn test_batch_invert_ignores_dirty_oversized_scratchpad() {
		let mut rng = StdRng::seed_from_u64(2);
		let mut state: [Ghash; 16] = std::array::from_fn(|_| Ghash::random(&mut rng));
		state[3] = Ghash::ZERO;
		let expected: [Ghash; 16] = state.map(|x| x.invert_or_zero());

		// Larger than required and pre-filled with garbage: neither should affect the result.
		let mut scratchpad: Vec<Ghash> = (0..64).map(|_| Ghash::random(&mut rng)).collect();
		batch_invert::<16>(&mut state, &mut scratchpad);

		assert_eq!(state, expected);
	}

	#[test]
	#[should_panic(expected = "scratchpad too small")]
	fn test_batch_invert_rejects_small_scratchpad() {
		let mut state = [Ghash::ONE; 4];
		let mut scratchpad = vec![Ghash::ZERO; 2 * 4 - 2];
		batch_invert::<4>(&mut state, &mut scratchpad);
	}

	#[test]
	#[should_panic(expected = "N must be a power of 2 and >= 2")]
	fn test_batch_invert_rejects_non_power_of_two() {
		let mut state = [Ghash::ONE; 3];
		let mut scratchpad = vec![Ghash::ZERO; 2 * 3 - 1];
		batch_invert::<3>(&mut state, &mut scratchpad);
	}

	#[test]
	#[should_panic]
	fn test_batch_invert_rejects_length_mismatch() {
		let mut state = [Ghash::ONE; 2];
		let mut scratchpad = vec![Ghash::ZERO; 2 * 4 - 1];
		batch_invert::<4>(&mut state, &mut scratchpad);
	}
}