binius_math/
batch_invert.rs1use std::iter;
5
6use binius_field::Field;
7
8pub struct BatchInversion<F: Field> {
23 n: usize,
24 scratchpad: Vec<F>,
25 is_zero: Vec<bool>,
26}
27
28impl<F: Field> BatchInversion<F> {
29 pub fn new(n: usize) -> Self {
41 assert!(n > 0, "n must be greater than 0");
42
43 let scratchpad_size = min_scratchpad_size(n);
44 Self {
45 n,
46 scratchpad: vec![F::ZERO; scratchpad_size],
47 is_zero: vec![false; n],
48 }
49 }
50
51 pub fn invert_nonzero(&mut self, elements: &mut [F]) {
62 assert_eq!(
63 elements.len(),
64 self.n,
65 "elements.len() must equal n (expected {}, got {})",
66 self.n,
67 elements.len()
68 );
69
70 batch_invert_nonzero(elements, &mut self.scratchpad);
71 }
72
73 pub fn invert_or_zero(&mut self, elements: &mut [F]) {
84 assert_eq!(
85 elements.len(),
86 self.n,
87 "elements.len() must equal n (expected {}, got {})",
88 self.n,
89 elements.len()
90 );
91
92 for (element_i, is_zero_i) in iter::zip(&mut *elements, &mut self.is_zero) {
94 if *element_i == F::ZERO {
95 *element_i = F::ONE;
96 *is_zero_i = true;
97 } else {
98 *is_zero_i = false;
99 }
100 }
101
102 self.invert_nonzero(elements);
104
105 for (element_i, is_zero_i) in iter::zip(elements, &self.is_zero) {
107 if *is_zero_i {
108 *element_i = F::ZERO;
109 }
110 }
111 }
112}
113
/// Returns how many scratch elements `batch_invert_nonzero` needs for a batch of `n` elements.
///
/// Each level of the product tree halves the layer length (rounding up) until a single element
/// remains. The scratchpad must hold every layer except the input itself, so the result is the
/// sum of those successive halved lengths. A batch of one element needs no scratch space.
///
/// Panics if `n == 0`.
fn min_scratchpad_size(n: usize) -> usize {
    assert!(n > 0);

    // Layer lengths: n, ceil(n/2), ceil(n/4), ..., 1. The first is the input and is skipped.
    iter::successors(Some(n), |&layer_len| (layer_len > 1).then(|| layer_len.div_ceil(2)))
        .skip(1)
        .sum()
}
124
125fn batch_invert_nonzero<F: Field>(elements: &mut [F], scratchpad: &mut [F]) {
126 debug_assert!(!elements.is_empty());
127
128 if elements.len() == 1 {
129 let element = elements.first_mut().expect("len == 1");
130 let inv = element
131 .invert()
132 .expect("precondition: elements contains no zeros");
133 *element = inv;
134 return;
135 }
136
137 let next_layer_len = elements.len().div_ceil(2);
138 let (next_layer, remaining) = scratchpad.split_at_mut(next_layer_len);
139 product_layer(elements, next_layer);
140 batch_invert_nonzero(next_layer, remaining);
141 unproduct_layer(next_layer, elements);
142}
143
144#[inline]
145fn product_layer<F: Field>(input: &[F], output: &mut [F]) {
146 debug_assert_eq!(output.len(), input.len().div_ceil(2));
147
148 let (in_pairs, in_remaining) = input.as_chunks::<2>();
149 let (out_head, out_remaining) = output.split_at_mut(in_pairs.len());
150 for (out_i, [in_lhs, in_rhs]) in iter::zip(out_head, in_pairs) {
151 *out_i = *in_lhs * *in_rhs;
152 }
153 if !out_remaining.is_empty() {
154 out_remaining[0] = in_remaining[0];
155 }
156}
157
158#[inline]
159fn unproduct_layer<F: Field>(input: &[F], output: &mut [F]) {
160 debug_assert_eq!(input.len(), output.len().div_ceil(2));
161
162 let (out_pairs, out_remaining) = output.as_chunks_mut::<2>();
163 let (in_head, in_remaining) = input.split_at(out_pairs.len());
164 for (in_i, [out_lhs, out_rhs]) in iter::zip(in_head, out_pairs) {
165 let out_lhs_tmp = *out_lhs;
166 let out_rhs_tmp = *out_rhs;
167 *out_lhs = *in_i * out_rhs_tmp;
168 *out_rhs = *in_i * out_lhs_tmp;
169 }
170 if !out_remaining.is_empty() {
171 out_remaining[0] = in_remaining[0];
172 }
173}
174
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128bGhash as Ghash, Random, arithmetic_traits::InvertOrZero};
    use proptest::prelude::*;
    use rand::{Rng, SeedableRng, rngs::StdRng, seq::IteratorRandom};

    use super::*;

    /// Builds a length-`n` batch with `n_zeros` zeros at random positions and random values
    /// elsewhere, runs `inverter.invert_or_zero` on it, and compares the result against
    /// element-wise `InvertOrZero`.
    fn invert_with_inverter(
        inverter: &mut BatchInversion<Ghash>,
        n: usize,
        n_zeros: usize,
        rng: &mut impl Rng,
    ) {
        assert!(n_zeros <= n, "n_zeros must be <= n");

        let zero_indices: Vec<usize> = (0..n).choose_multiple(rng, n_zeros);

        let mut state = Vec::with_capacity(n);
        for i in 0..n {
            if zero_indices.contains(&i) {
                state.push(Ghash::ZERO);
            } else {
                state.push(Ghash::random(&mut *rng));
            }
        }

        let expected: Vec<Ghash> = state.iter().map(|x| x.invert_or_zero()).collect();

        inverter.invert_or_zero(&mut state);

        assert_eq!(state, expected);
    }

    fn test_batch_inversion_for_size(n: usize, n_zeros: usize, rng: &mut impl Rng) {
        let mut inverter = BatchInversion::<Ghash>::new(n);
        invert_with_inverter(&mut inverter, n, n_zeros, rng);
    }

    fn test_batch_inversion_nonzero_for_size(n: usize, rng: &mut impl Rng) {
        let mut state = Vec::with_capacity(n);
        for _ in 0..n {
            state.push(Ghash::random(&mut *rng));
        }

        let expected: Vec<Ghash> = state.iter().map(|x| x.invert_or_zero()).collect();

        let mut inverter = BatchInversion::<Ghash>::new(n);
        inverter.invert_nonzero(&mut state);

        assert_eq!(state, expected);
    }

    proptest! {
        // Generate only valid (n, n_zeros) pairs instead of rejecting ~44% via prop_assume.
        #[test]
        fn test_batch_inversion(
            (n, n_zeros) in (1usize..=16).prop_flat_map(|n| (Just(n), 0..=n))
        ) {
            let mut rng = StdRng::seed_from_u64(0);
            test_batch_inversion_for_size(n, n_zeros, &mut rng);
        }

        #[test]
        fn test_batch_inversion_nonzero(n in 1usize..=16) {
            let mut rng = StdRng::seed_from_u64(0);
            test_batch_inversion_nonzero_for_size(n, &mut rng);
        }
    }

    #[test]
    fn test_batch_inversion_reuse() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut inverter = BatchInversion::<Ghash>::new(8);

        for n_zeros in 0..=8 {
            invert_with_inverter(&mut inverter, 8, n_zeros, &mut rng);
        }
    }

    /// Exercises deeper product trees with both odd and power-of-two sizes.
    #[test]
    fn test_batch_inversion_larger_sizes() {
        let mut rng = StdRng::seed_from_u64(0);
        for n in [31, 64, 100, 257] {
            test_batch_inversion_for_size(n, n / 3, &mut rng);
        }
    }

    #[test]
    fn test_min_scratchpad_size() {
        assert_eq!(min_scratchpad_size(1), 0);
        assert_eq!(min_scratchpad_size(2), 1);
        assert_eq!(min_scratchpad_size(3), 3);
        assert_eq!(min_scratchpad_size(4), 3);
        assert_eq!(min_scratchpad_size(5), 6);
        assert_eq!(min_scratchpad_size(8), 7);
    }

    #[test]
    #[should_panic(expected = "n must be greater than 0")]
    fn test_new_rejects_zero_size() {
        let _ = BatchInversion::<Ghash>::new(0);
    }

    #[test]
    #[should_panic(expected = "elements.len() must equal n")]
    fn test_length_mismatch_panics() {
        let mut inverter = BatchInversion::<Ghash>::new(4);
        let mut elements = vec![Ghash::ONE; 3];
        inverter.invert_or_zero(&mut elements);
    }

    #[test]
    #[should_panic(expected = "precondition: elements contains no zeros")]
    fn test_invert_nonzero_panics_on_zero() {
        let mut inverter = BatchInversion::<Ghash>::new(4);
        let mut elements = vec![Ghash::ONE, Ghash::ZERO, Ghash::ONE, Ghash::ONE];
        inverter.invert_nonzero(&mut elements);
    }
}