binius_math/
batch_invert.rs

1// Copyright 2025 The Binius Developers
2// Copyright 2025 Irreducible Inc.
3
4use std::iter;
5
6use binius_field::Field;
7
8/// Reusable batch inversion context that owns its scratch buffers.
9///
10/// This struct manages the memory needed for batch field inversions, allowing
11/// efficient reuse across multiple inversion operations of the same size.
12///
13/// # Example
14/// ```
15/// use binius_field::{BinaryField128bGhash, Field};
16/// use binius_math::batch_invert::BatchInversion;
17///
18/// let mut inverter = BatchInversion::<BinaryField128bGhash>::new(8);
19/// let mut elements = [BinaryField128bGhash::ONE; 8];
20/// inverter.invert_or_zero(&mut elements);
21/// ```
22pub struct BatchInversion<F: Field> {
23	n: usize,
24	scratchpad: Vec<F>,
25	is_zero: Vec<bool>,
26}
27
28impl<F: Field> BatchInversion<F> {
29	/// Creates a new batch inversion context for slices of size `n`.
30	///
31	/// Allocates the necessary scratch space:
32	/// - `scratchpad`: Storage for intermediate products during recursion
33	/// - `is_zero`: Tracking vector for zero elements
34	///
35	/// # Parameters
36	/// - `n`: The size of slices this instance will handle
37	///
38	/// # Panics
39	/// Panics if `n` is 0.
40	pub fn new(n: usize) -> Self {
41		assert!(n > 0, "n must be greater than 0");
42
43		let scratchpad_size = min_scratchpad_size(n);
44		Self {
45			n,
46			scratchpad: vec![F::ZERO; scratchpad_size],
47			is_zero: vec![false; n],
48		}
49	}
50
51	/// Inverts non-zero elements in-place.
52	///
53	/// # Parameters
54	/// - `elements`: Mutable slice to invert in-place
55	///
56	/// # Preconditions
57	/// All elements must be non-zero. Behavior is undefined if any element is zero.
58	///
59	/// # Panics
60	/// Panics if `elements.len() != n` (the size specified at construction).
61	pub fn invert_nonzero(&mut self, elements: &mut [F]) {
62		assert_eq!(
63			elements.len(),
64			self.n,
65			"elements.len() must equal n (expected {}, got {})",
66			self.n,
67			elements.len()
68		);
69
70		batch_invert_nonzero(elements, &mut self.scratchpad);
71	}
72
73	/// Inverts elements in-place, handling zeros gracefully.
74	///
75	/// Zero elements remain zero after inversion, while non-zero elements
76	/// are replaced with their multiplicative inverses.
77	///
78	/// # Parameters
79	/// - `elements`: Mutable slice to invert in-place
80	///
81	/// # Panics
82	/// Panics if `elements.len() != n` (the size specified at construction).
83	pub fn invert_or_zero(&mut self, elements: &mut [F]) {
84		assert_eq!(
85			elements.len(),
86			self.n,
87			"elements.len() must equal n (expected {}, got {})",
88			self.n,
89			elements.len()
90		);
91
92		// Mark zeros and replace with ones
93		for (element_i, is_zero_i) in iter::zip(&mut *elements, &mut self.is_zero) {
94			if *element_i == F::ZERO {
95				*element_i = F::ONE;
96				*is_zero_i = true;
97			} else {
98				*is_zero_i = false;
99			}
100		}
101
102		// Perform inversion on non-zero elements
103		self.invert_nonzero(elements);
104
105		// Restore zeros
106		for (element_i, is_zero_i) in iter::zip(elements, &self.is_zero) {
107			if *is_zero_i {
108				*element_i = F::ZERO;
109			}
110		}
111	}
112}
113
/// Returns the number of scratch elements needed to batch-invert a slice of length `n`.
///
/// The length is halved (rounding up) repeatedly until it reaches 1, and the sizes
/// of all the halved layers are added together. A single element needs no scratch.
///
/// # Panics
/// Panics if `n` is 0.
fn min_scratchpad_size(n: usize) -> usize {
	assert!(n > 0);

	// Layer lengths: n, ceil(n / 2), ceil(n / 4), ..., 1. The first entry is the
	// input itself, which lives in the caller's slice, so it is skipped.
	iter::successors(Some(n), |&len: &usize| (len > 1).then(|| len.div_ceil(2)))
		.skip(1)
		.sum()
}
124
125fn batch_invert_nonzero<F: Field>(elements: &mut [F], scratchpad: &mut [F]) {
126	debug_assert!(!elements.is_empty());
127
128	if elements.len() == 1 {
129		let element = elements.first_mut().expect("len == 1");
130		let inv = element
131			.invert()
132			.expect("precondition: elements contains no zeros");
133		*element = inv;
134		return;
135	}
136
137	let next_layer_len = elements.len().div_ceil(2);
138	let (next_layer, remaining) = scratchpad.split_at_mut(next_layer_len);
139	product_layer(elements, next_layer);
140	batch_invert_nonzero(next_layer, remaining);
141	unproduct_layer(next_layer, elements);
142}
143
144#[inline]
145fn product_layer<F: Field>(input: &[F], output: &mut [F]) {
146	debug_assert_eq!(output.len(), input.len().div_ceil(2));
147
148	let (in_pairs, in_remaining) = input.as_chunks::<2>();
149	let (out_head, out_remaining) = output.split_at_mut(in_pairs.len());
150	for (out_i, [in_lhs, in_rhs]) in iter::zip(out_head, in_pairs) {
151		*out_i = *in_lhs * *in_rhs;
152	}
153	if !out_remaining.is_empty() {
154		out_remaining[0] = in_remaining[0];
155	}
156}
157
158#[inline]
159fn unproduct_layer<F: Field>(input: &[F], output: &mut [F]) {
160	debug_assert_eq!(input.len(), output.len().div_ceil(2));
161
162	let (out_pairs, out_remaining) = output.as_chunks_mut::<2>();
163	let (in_head, in_remaining) = input.split_at(out_pairs.len());
164	for (in_i, [out_lhs, out_rhs]) in iter::zip(in_head, out_pairs) {
165		let out_lhs_tmp = *out_lhs;
166		let out_rhs_tmp = *out_rhs;
167		*out_lhs = *in_i * out_rhs_tmp;
168		*out_rhs = *in_i * out_lhs_tmp;
169	}
170	if !out_remaining.is_empty() {
171		out_remaining[0] = in_remaining[0];
172	}
173}
174
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128bGhash as Ghash, Random, arithmetic_traits::InvertOrZero};
	use proptest::prelude::*;
	use rand::{Rng, SeedableRng, rngs::StdRng, seq::IteratorRandom};

	use super::*;

	/// Reference result: each value inverted on its own via `invert_or_zero`.
	fn reference_inverses(values: &[Ghash]) -> Vec<Ghash> {
		values.iter().map(|x| x.invert_or_zero()).collect()
	}

	/// Shared helper: runs `invert_or_zero` on a vector of `n` values, `n_zeros` of
	/// which are zero, and compares the result against the element-wise reference.
	fn invert_with_inverter(
		inverter: &mut BatchInversion<Ghash>,
		n: usize,
		n_zeros: usize,
		rng: &mut impl Rng,
	) {
		assert!(n_zeros <= n, "n_zeros must be <= n");

		// Pick the zero positions first (sampling without replacement) ...
		let zero_indices: Vec<usize> = (0..n).choose_multiple(rng, n_zeros);

		// ... then fill every other position with a random field element.
		let mut state: Vec<Ghash> = (0..n)
			.map(|i| {
				if zero_indices.contains(&i) {
					Ghash::ZERO
				} else {
					Ghash::random(&mut *rng)
				}
			})
			.collect();

		let expected = reference_inverses(&state);

		inverter.invert_or_zero(&mut state);

		assert_eq!(state, expected);
	}

	/// Builds a fresh inverter of size `n` and checks `invert_or_zero` with it.
	fn test_batch_inversion_for_size(n: usize, n_zeros: usize, rng: &mut impl Rng) {
		let mut inverter = BatchInversion::<Ghash>::new(n);
		invert_with_inverter(&mut inverter, n, n_zeros, rng);
	}

	/// Checks `invert_nonzero` on `n` random elements against the element-wise reference.
	fn test_batch_inversion_nonzero_for_size(n: usize, rng: &mut impl Rng) {
		let mut state: Vec<Ghash> = (0..n).map(|_| Ghash::random(&mut *rng)).collect();

		let expected = reference_inverses(&state);

		let mut inverter = BatchInversion::<Ghash>::new(n);
		inverter.invert_nonzero(&mut state);

		assert_eq!(state, expected);
	}

	proptest! {
		#[test]
		fn test_batch_inversion(n in 1usize..=16, n_zeros in 0usize..=16) {
			// Discard combinations asking for more zeros than there are elements.
			prop_assume!(n_zeros <= n);
			let mut rng = StdRng::seed_from_u64(0);
			test_batch_inversion_for_size(n, n_zeros, &mut rng);
		}

		#[test]
		fn test_batch_inversion_nonzero(n in 1usize..=16) {
			let mut rng = StdRng::seed_from_u64(0);
			test_batch_inversion_nonzero_for_size(n, &mut rng);
		}
	}

	#[test]
	fn test_batch_inversion_reuse() {
		let mut rng = StdRng::seed_from_u64(0);
		let mut inverter = BatchInversion::<Ghash>::new(8);

		// One inverter, many calls: the scratch buffers must not leak state between
		// runs with different numbers of zeros.
		for n_zeros in 0..=8 {
			invert_with_inverter(&mut inverter, 8, n_zeros, &mut rng);
		}
	}
}
256}