binius_math/
batch_invert.rs1use std::iter;
5
6use binius_field::{Field, PackedField};
7
8pub struct BatchInversion<P: PackedField> {
23 n: usize,
24 scratchpad: Vec<P>,
25 is_zero: Vec<bool>,
26 scalar_inverter: Option<Box<BatchInversion<P::Scalar>>>,
30}
31
32impl<P: PackedField> BatchInversion<P> {
33 pub fn new(n: usize) -> Self {
46 assert!(n > 0, "n must be greater than 0");
47
48 let scratchpad_size = min_scratchpad_size(n);
49 let scalar_inverter = if P::WIDTH > 1 {
50 Some(Box::new(BatchInversion::<P::Scalar>::new(P::WIDTH)))
51 } else {
52 None
53 };
54 Self {
55 n,
56 scratchpad: vec![P::zero(); scratchpad_size],
57 is_zero: vec![false; n * P::WIDTH],
58 scalar_inverter,
59 }
60 }
61
62 pub fn invert_nonzero(&mut self, elements: &mut [P]) {
73 assert_eq!(
74 elements.len(),
75 self.n,
76 "elements.len() must equal n (expected {}, got {})",
77 self.n,
78 elements.len()
79 );
80
81 self.batch_invert_nonzero(elements);
82 }
83
84 pub fn invert_or_zero(&mut self, elements: &mut [P]) {
95 assert_eq!(
96 elements.len(),
97 self.n,
98 "elements.len() must equal n (expected {}, got {})",
99 self.n,
100 elements.len()
101 );
102
103 for (packed_idx, packed) in elements.iter_mut().enumerate() {
105 for lane in 0..P::WIDTH {
106 let scalar_idx = packed_idx * P::WIDTH + lane;
107 let scalar = packed.get(lane);
108 if scalar == P::Scalar::ZERO {
109 packed.set(lane, P::Scalar::ONE);
110 self.is_zero[scalar_idx] = true;
111 } else {
112 self.is_zero[scalar_idx] = false;
113 }
114 }
115 }
116
117 self.invert_nonzero(elements);
119
120 for (packed_idx, packed) in elements.iter_mut().enumerate() {
122 for lane in 0..P::WIDTH {
123 let scalar_idx = packed_idx * P::WIDTH + lane;
124 if self.is_zero[scalar_idx] {
125 packed.set(lane, P::Scalar::ZERO);
126 }
127 }
128 }
129 }
130}
131
/// Returns the number of scratchpad slots needed to batch-invert `n` elements.
///
/// Starting from `n`, the length is repeatedly replaced by `ceil(len / 2)` until it
/// reaches 1; the result is the sum of all those halved lengths (0 when `n == 1`).
///
/// # Panics
///
/// Panics if `n` is zero.
fn min_scratchpad_size(n: usize) -> usize {
    assert!(n > 0);

    // Yields n, ceil(n/2), ceil(n/4), ..., 1. The first item is the input itself, which
    // needs no scratch space, so it is skipped before summing.
    std::iter::successors(Some(n), |&len| (len > 1).then(|| len.div_ceil(2)))
        .skip(1)
        .sum()
}
142
143impl<P: PackedField> BatchInversion<P> {
144 fn batch_invert_nonzero(&mut self, elements: &mut [P]) {
145 batch_invert_nonzero_with_scratchpad(
146 elements,
147 &mut self.scratchpad,
148 self.scalar_inverter.as_deref_mut(),
149 );
150 }
151}
152
153fn batch_invert_nonzero_with_scratchpad<P: PackedField>(
154 elements: &mut [P],
155 scratchpad: &mut [P],
156 scalar_inverter: Option<&mut BatchInversion<P::Scalar>>,
157) {
158 debug_assert!(!elements.is_empty());
159
160 if elements.len() == 1 {
161 let packed = &mut elements[0];
162 if P::WIDTH == 1 {
163 let scalar = packed.get(0);
165 let inv = scalar
166 .invert()
167 .expect("precondition: elements contains no zeros");
168 packed.set(0, inv);
169 } else {
170 let mut scalars = packed.into_iter().collect::<Vec<_>>();
172 scalar_inverter
173 .expect("scalar_inverter must be Some when WIDTH > 1")
174 .invert_nonzero(&mut scalars);
175 *packed = P::from_scalars(scalars);
176 }
177 return;
178 }
179
180 let next_layer_len = elements.len().div_ceil(2);
181 let (next_layer, remaining) = scratchpad.split_at_mut(next_layer_len);
182 product_layer(elements, next_layer);
183 batch_invert_nonzero_with_scratchpad(next_layer, remaining, scalar_inverter);
184 unproduct_layer(next_layer, elements);
185}
186
187#[inline]
192fn product_layer<P: PackedField>(input: &[P], output: &mut [P]) {
193 debug_assert_eq!(output.len(), input.len().div_ceil(2));
194
195 let (lo, hi) = input.split_at(output.len());
196 let mut out_lo_iter = iter::zip(output, lo);
197
198 if hi.len() < out_lo_iter.len() {
199 let Some((out_i, lo_i)) = out_lo_iter.next_back() else {
200 unreachable!("out_lo_iter.len() must be greater than zero");
201 };
202 *out_i = *lo_i;
203 }
204 for ((out_i, &lo_i), &hi_i) in iter::zip(out_lo_iter, hi) {
205 *out_i = lo_i * hi_i;
206 }
207}
208
209#[inline]
215fn unproduct_layer<P: PackedField>(input: &[P], output: &mut [P]) {
216 debug_assert_eq!(input.len(), output.len().div_ceil(2));
217
218 let (lo, hi) = output.split_at_mut(input.len());
219 let mut lo_in_iter = iter::zip(lo, input);
220
221 if hi.len() < lo_in_iter.len() {
222 let Some((lo_i, in_i)) = lo_in_iter.next_back() else {
223 unreachable!("out_lo_iter.len() must be greater than zero");
224 };
225 *lo_i = *in_i;
226 }
227 for ((lo_i, &in_i), hi_i) in iter::zip(lo_in_iter, hi) {
228 let lo_tmp = *lo_i;
229 let hi_tmp = *hi_i;
230 *lo_i = in_i * hi_tmp;
231 *hi_i = in_i * lo_tmp;
232 }
233}
234
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128bGhash as Ghash, Random, arithmetic_traits::InvertOrZero};
    use proptest::prelude::*;
    use rand::{Rng, SeedableRng, rngs::StdRng, seq::IteratorRandom};

    use super::*;

    /// Runs `invert_or_zero` on `n` elements, `n_zeros` of which are zero at randomly
    /// chosen positions, and compares against element-wise `InvertOrZero`.
    fn invert_with_inverter(
        inverter: &mut BatchInversion<Ghash>,
        n: usize,
        n_zeros: usize,
        rng: &mut impl Rng,
    ) {
        assert!(n_zeros <= n, "n_zeros must be <= n");

        let zero_positions: Vec<usize> = (0..n).choose_multiple(rng, n_zeros);

        let mut values: Vec<Ghash> = (0..n)
            .map(|i| {
                if zero_positions.contains(&i) {
                    Ghash::ZERO
                } else {
                    Ghash::random(&mut *rng)
                }
            })
            .collect();

        let expected: Vec<Ghash> = values
            .iter()
            .copied()
            .map(InvertOrZero::invert_or_zero)
            .collect();

        inverter.invert_or_zero(&mut values);

        assert_eq!(values, expected);
    }

    /// Checks `invert_or_zero` with a freshly constructed inverter of size `n`.
    fn test_batch_inversion_for_size(n: usize, n_zeros: usize, rng: &mut impl Rng) {
        let mut inverter = BatchInversion::<Ghash>::new(n);
        invert_with_inverter(&mut inverter, n, n_zeros, rng);
    }

    /// Checks `invert_nonzero` on `n` random elements against element-wise `InvertOrZero`.
    fn test_batch_inversion_nonzero_for_size(n: usize, rng: &mut impl Rng) {
        let mut values: Vec<Ghash> = (0..n).map(|_| Ghash::random(&mut *rng)).collect();

        let expected: Vec<Ghash> = values
            .iter()
            .copied()
            .map(InvertOrZero::invert_or_zero)
            .collect();

        let mut inverter = BatchInversion::<Ghash>::new(n);
        inverter.invert_nonzero(&mut values);

        assert_eq!(values, expected);
    }

    proptest! {
        #[test]
        fn test_batch_inversion(n in 1usize..=16, n_zeros in 0usize..=16) {
            prop_assume!(n_zeros <= n);
            let mut rng = StdRng::seed_from_u64(0);
            test_batch_inversion_for_size(n, n_zeros, &mut rng);
        }

        #[test]
        fn test_batch_inversion_nonzero(n in 1usize..=16) {
            let mut rng = StdRng::seed_from_u64(0);
            test_batch_inversion_nonzero_for_size(n, &mut rng);
        }
    }

    /// A single inverter must give correct results across repeated calls, whatever
    /// zero flags the previous call left behind.
    #[test]
    fn test_batch_inversion_reuse() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut inverter = BatchInversion::<Ghash>::new(8);

        (0..=8).for_each(|n_zeros| invert_with_inverter(&mut inverter, 8, n_zeros, &mut rng));
    }

    /// Exercises the multi-lane path with zeros placed in two specific lanes.
    #[test]
    fn test_batch_inversion_packed() {
        use crate::test_utils::Packed128b;

        const N: usize = 4;
        let mut rng = StdRng::seed_from_u64(0);

        let mut state: Vec<Packed128b> = (0..N)
            .map(|i| {
                Packed128b::from_fn(|lane| {
                    // Zero lanes: element 1 lane 0, and element 2 lane 2.
                    if matches!((i, lane), (1, 0) | (2, 2)) {
                        Ghash::ZERO
                    } else {
                        Ghash::random(&mut rng)
                    }
                })
            })
            .collect();

        let expected: Vec<Packed128b> = state
            .iter()
            .map(|packed| Packed128b::from_scalars(packed.iter().map(InvertOrZero::invert_or_zero)))
            .collect();

        let mut inverter = BatchInversion::<Packed128b>::new(N);
        inverter.invert_or_zero(&mut state);

        assert_eq!(state, expected);
    }
}