binius_math/batch_invert.rs

// Copyright 2025-2026 The Binius Developers
// Copyright 2025 Irreducible Inc.

use std::iter;

use binius_field::{Field, PackedField};

/// Reusable batch inversion context that owns its scratch buffers.
///
/// This struct manages the memory needed for batch field inversions, allowing
/// efficient reuse across multiple inversion operations of the same size.
///
/// # Example
/// ```
/// use binius_field::{BinaryField128bGhash, Field};
/// use binius_math::batch_invert::BatchInversion;
///
/// let mut inverter = BatchInversion::<BinaryField128bGhash>::new(8);
/// let mut elements = [BinaryField128bGhash::ONE; 8];
/// inverter.invert_or_zero(&mut elements);
/// ```
pub struct BatchInversion<P: PackedField> {
	n: usize,
	scratchpad: Vec<P>,
	is_zero: Vec<bool>,
	/// Nested inverter for handling the base case when WIDTH > 1.
	/// When we recurse down to a single packed element, we need to
	/// batch-invert its WIDTH scalar elements.
	scalar_inverter: Option<Box<BatchInversion<P::Scalar>>>,
}

impl<P: PackedField> BatchInversion<P> {
	/// Creates a new batch inversion context for slices of size `n`.
	///
	/// Allocates the necessary scratch space:
	/// - `scratchpad`: Storage for intermediate products during recursion
	/// - `is_zero`: Tracking vector for zero elements (one per scalar)
	/// - `scalar_inverter`: Nested inverter for base case when WIDTH > 1
	///
	/// # Parameters
	/// - `n`: The number of packed elements this instance will handle
	///
	/// # Panics
	/// Panics if `n` is 0.
	pub fn new(n: usize) -> Self {
		assert!(n > 0, "n must be greater than 0");

		let scratchpad_size = min_scratchpad_size(n);
		let scalar_inverter = if P::WIDTH > 1 {
			Some(Box::new(BatchInversion::<P::Scalar>::new(P::WIDTH)))
		} else {
			None
		};
		Self {
			n,
			scratchpad: vec![P::zero(); scratchpad_size],
			is_zero: vec![false; n * P::WIDTH],
			scalar_inverter,
		}
	}

	/// Inverts non-zero elements in-place.
	///
	/// # Parameters
	/// - `elements`: Mutable slice to invert in-place
	///
	/// # Preconditions
	/// All scalar elements must be non-zero. If any scalar is zero, the result is
	/// unspecified and the call may panic.
	///
	/// # Panics
	/// Panics if `elements.len() != n` (the size specified at construction).
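	///
	/// # Example
	/// A minimal sketch using the same field as the struct-level example; every
	/// element here is non-zero, as the precondition requires.
	/// ```
	/// use binius_field::{BinaryField128bGhash, Field};
	/// use binius_math::batch_invert::BatchInversion;
	///
	/// let mut inverter = BatchInversion::<BinaryField128bGhash>::new(4);
	/// let mut elements = [BinaryField128bGhash::ONE; 4];
	/// inverter.invert_nonzero(&mut elements);
	/// // ONE is its own inverse, so the slice is unchanged.
	/// assert_eq!(elements, [BinaryField128bGhash::ONE; 4]);
	/// ```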
	pub fn invert_nonzero(&mut self, elements: &mut [P]) {
		assert_eq!(
			elements.len(),
			self.n,
			"elements.len() must equal n (expected {}, got {})",
			self.n,
			elements.len()
		);

		self.batch_invert_nonzero(elements);
	}

	/// Inverts elements in-place, handling zeros gracefully.
	///
	/// Zero scalar elements remain zero after inversion, while non-zero scalars
	/// are replaced with their multiplicative inverses.
	///
	/// # Parameters
	/// - `elements`: Mutable slice to invert in-place
	///
	/// # Panics
	/// Panics if `elements.len() != n` (the size specified at construction).
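	///
	/// # Example
	/// An illustrative sketch with the same field as the struct-level example,
	/// mixing zero and non-zero elements.
	/// ```
	/// use binius_field::{BinaryField128bGhash, Field};
	/// use binius_math::batch_invert::BatchInversion;
	///
	/// let mut inverter = BatchInversion::<BinaryField128bGhash>::new(2);
	/// let mut elements = [BinaryField128bGhash::ZERO, BinaryField128bGhash::ONE];
	/// inverter.invert_or_zero(&mut elements);
	/// // Zeros stay zero; ONE is its own inverse.
	/// assert_eq!(elements, [BinaryField128bGhash::ZERO, BinaryField128bGhash::ONE]);
	/// ```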
	pub fn invert_or_zero(&mut self, elements: &mut [P]) {
		assert_eq!(
			elements.len(),
			self.n,
			"elements.len() must equal n (expected {}, got {})",
			self.n,
			elements.len()
		);

		// Mark zeros at scalar level and replace with ones
		for (packed_idx, packed) in elements.iter_mut().enumerate() {
			for lane in 0..P::WIDTH {
				let scalar_idx = packed_idx * P::WIDTH + lane;
				let scalar = packed.get(lane);
				if scalar == P::Scalar::ZERO {
					packed.set(lane, P::Scalar::ONE);
					self.is_zero[scalar_idx] = true;
				} else {
					self.is_zero[scalar_idx] = false;
				}
			}
		}

		// Perform inversion on non-zero elements
		self.invert_nonzero(elements);

		// Restore zeros at scalar level
		for (packed_idx, packed) in elements.iter_mut().enumerate() {
			for lane in 0..P::WIDTH {
				let scalar_idx = packed_idx * P::WIDTH + lane;
				if self.is_zero[scalar_idx] {
					packed.set(lane, P::Scalar::ZERO);
				}
			}
		}
	}
}

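/// Returns the scratchpad length needed to batch-invert `n` elements.
///
/// Each recursion layer halves the (rounded-up) element count and needs that many
/// scratch slots, so the total is the sum of the shrinking layer sizes. For
/// example, `n = 8` needs `4 + 2 + 1 = 7` slots and `n = 5` needs `3 + 2 + 1 = 6`.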
fn min_scratchpad_size(mut n: usize) -> usize {
	assert!(n > 0);

	let mut size = 0;
	while n > 1 {
		n = n.div_ceil(2);
		size += n;
	}
	size
}

impl<P: PackedField> BatchInversion<P> {
	fn batch_invert_nonzero(&mut self, elements: &mut [P]) {
		batch_invert_nonzero_with_scratchpad(
			elements,
			&mut self.scratchpad,
			self.scalar_inverter.as_deref_mut(),
		);
	}
}

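/// Recursive core of batch inversion.
///
/// Each level pairs up the elements with `product_layer`, writing the products into
/// the front of `scratchpad` (halving the problem size), recurses to invert those
/// products, then unwinds them with `unproduct_layer` to recover the individual
/// inverses. Actual field inversions are performed only in the single-element base case.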
fn batch_invert_nonzero_with_scratchpad<P: PackedField>(
	elements: &mut [P],
	scratchpad: &mut [P],
	scalar_inverter: Option<&mut BatchInversion<P::Scalar>>,
) {
	debug_assert!(!elements.is_empty());

	if elements.len() == 1 {
		let packed = &mut elements[0];
		if P::WIDTH == 1 {
			// Direct scalar inversion
			let scalar = packed.get(0);
			let inv = scalar
				.invert()
				.expect("precondition: elements contains no zeros");
			packed.set(0, inv);
		} else {
			// Unpack, batch invert scalars, repack
			let mut scalars = packed.into_iter().collect::<Vec<_>>();
			scalar_inverter
				.expect("scalar_inverter must be Some when WIDTH > 1")
				.invert_nonzero(&mut scalars);
			*packed = P::from_scalars(scalars);
		}
		return;
	}

	let next_layer_len = elements.len().div_ceil(2);
	let (next_layer, remaining) = scratchpad.split_at_mut(next_layer_len);
	product_layer(elements, next_layer);
	batch_invert_nonzero_with_scratchpad(next_layer, remaining, scalar_inverter);
	unproduct_layer(next_layer, elements);
}

/// Computes element-wise products of the lo and hi halves of the input.
///
/// Pairs `input[i]` with `input[half + i]` for parallel efficiency.
/// For odd-length inputs, the middle element is copied through.
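///
/// For example, `[a, b, c, d, e]` becomes `[a * d, b * e, c]`.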
#[inline]
fn product_layer<P: PackedField>(input: &[P], output: &mut [P]) {
	debug_assert_eq!(output.len(), input.len().div_ceil(2));

	let (lo, hi) = input.split_at(output.len());
	let mut out_lo_iter = iter::zip(output, lo);

	if hi.len() < out_lo_iter.len() {
		let Some((out_i, lo_i)) = out_lo_iter.next_back() else {
			unreachable!("out_lo_iter.len() must be greater than zero");
		};
		*out_i = *lo_i;
	}
	for ((out_i, &lo_i), &hi_i) in iter::zip(out_lo_iter, hi) {
		*out_i = lo_i * hi_i;
	}
}

/// Unwinds `product_layer` to recover individual inverses.
///
/// Given the inverted products in `input` and the original values in `output`,
/// recovers (reading both halves of `output` before overwriting them):
/// - `output[i] = input[i] * output[half + i]` (inverse of the lo half element)
/// - `output[half + i] = input[i] * output[i]` (inverse of the hi half element)
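///
/// For example, with `input = [inv(a*d), inv(b*e), inv(c)]` and
/// `output = [a, b, c, d, e]`, the result is
/// `[inv(a), inv(b), inv(c), inv(d), inv(e)]`.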
#[inline]
fn unproduct_layer<P: PackedField>(input: &[P], output: &mut [P]) {
	debug_assert_eq!(input.len(), output.len().div_ceil(2));

	let (lo, hi) = output.split_at_mut(input.len());
	let mut lo_in_iter = iter::zip(lo, input);

	if hi.len() < lo_in_iter.len() {
		let Some((lo_i, in_i)) = lo_in_iter.next_back() else {
			unreachable!("lo_in_iter.len() must be greater than zero");
		};
		*lo_i = *in_i;
	}
	for ((lo_i, &in_i), hi_i) in iter::zip(lo_in_iter, hi) {
		let lo_tmp = *lo_i;
		let hi_tmp = *hi_i;
		*lo_i = in_i * hi_tmp;
		*hi_i = in_i * lo_tmp;
	}
}

#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128bGhash as Ghash, Random, arithmetic_traits::InvertOrZero};
	use proptest::prelude::*;
	use rand::{Rng, SeedableRng, rngs::StdRng, seq::IteratorRandom};

	use super::*;

	/// Shared helper to test batch inversion with a given inverter.
	fn invert_with_inverter(
		inverter: &mut BatchInversion<Ghash>,
		n: usize,
		n_zeros: usize,
		rng: &mut impl Rng,
	) {
		assert!(n_zeros <= n, "n_zeros must be <= n");

		// Sample indices for zeros without replacement
		let zero_indices: Vec<usize> = (0..n).choose_multiple(rng, n_zeros);

		// Create state vector with zeros at sampled indices
		let mut state = Vec::with_capacity(n);
		for i in 0..n {
			if zero_indices.contains(&i) {
				state.push(Ghash::ZERO);
			} else {
				state.push(Ghash::random(&mut *rng));
			}
		}

		let expected: Vec<Ghash> = state
			.iter()
			.map(|x| InvertOrZero::invert_or_zero(*x))
			.collect();

		inverter.invert_or_zero(&mut state);

		assert_eq!(state, expected);
	}

	fn test_batch_inversion_for_size(n: usize, n_zeros: usize, rng: &mut impl Rng) {
		let mut inverter = BatchInversion::<Ghash>::new(n);
		invert_with_inverter(&mut inverter, n, n_zeros, rng);
	}

	fn test_batch_inversion_nonzero_for_size(n: usize, rng: &mut impl Rng) {
		let mut state = Vec::with_capacity(n);
		for _ in 0..n {
			state.push(Ghash::random(&mut *rng));
		}

		let expected: Vec<Ghash> = state
			.iter()
			.map(|x| InvertOrZero::invert_or_zero(*x))
			.collect();

		let mut inverter = BatchInversion::<Ghash>::new(n);
		inverter.invert_nonzero(&mut state);

		assert_eq!(state, expected);
	}

	proptest! {
		#[test]
		fn test_batch_inversion(n in 1usize..=16, n_zeros in 0usize..=16) {
			prop_assume!(n_zeros <= n);
			let mut rng = StdRng::seed_from_u64(0);
			test_batch_inversion_for_size(n, n_zeros, &mut rng);
		}

		#[test]
		fn test_batch_inversion_nonzero(n in 1usize..=16) {
			let mut rng = StdRng::seed_from_u64(0);
			test_batch_inversion_nonzero_for_size(n, &mut rng);
		}
	}

	#[test]
	fn test_batch_inversion_reuse() {
		let mut rng = StdRng::seed_from_u64(0);
		let mut inverter = BatchInversion::<Ghash>::new(8);

		// Test reusing the same inverter multiple times
		for n_zeros in 0..=8 {
			invert_with_inverter(&mut inverter, 8, n_zeros, &mut rng);
		}
	}

	/// Test batch inversion with a packed field (WIDTH > 1)
	#[test]
	fn test_batch_inversion_packed() {
		use crate::test_utils::Packed128b;

		let mut rng = StdRng::seed_from_u64(0);
		const N: usize = 4;

		// Create packed elements with some zeros at various positions
		let mut state: Vec<Packed128b> = (0..N)
			.map(|i| {
				Packed128b::from_fn(|lane| {
					// Put zeros at specific positions
					if (i == 1 && lane == 0) || (i == 2 && lane == 2) {
						Ghash::ZERO
					} else {
						Ghash::random(&mut rng)
					}
				})
			})
			.collect();

		// Compute expected by inverting each scalar
		let expected: Vec<Packed128b> = state
			.iter()
			.map(|packed| Packed128b::from_scalars(packed.iter().map(InvertOrZero::invert_or_zero)))
			.collect();

		let mut inverter = BatchInversion::<Packed128b>::new(N);
		inverter.invert_or_zero(&mut state);

		assert_eq!(state, expected);
	}
}