// binius_frontend/compiler/hints/mod_inverse.rs
use binius_core::Word;
use num_bigint::BigUint;

use super::Hint;
use crate::util::num_biguint_from_u64_limbs;

/// Witness hint registered as `"binius.mod_inverse"`.
///
/// Given a base and a modulus (as 64-bit limbs), it emits a modular inverse together with the
/// quotient needed to check it. The hint is stateless, so it is a cheap, copyable unit struct.
#[derive(Debug, Clone, Copy)]
pub struct ModInverseHint;

12impl ModInverseHint {
13 pub fn new() -> Self {
14 Self
15 }
16}
17
18impl Default for ModInverseHint {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Hint for ModInverseHint {
25 const NAME: &'static str = "binius.mod_inverse";
26
27 fn shape(&self, dimensions: &[usize]) -> (usize, usize) {
28 let [base_limbs, mod_limbs] = dimensions else {
29 panic!("ModInverse requires 2 dimensions");
30 };
31 (*base_limbs + *mod_limbs, 2 * *mod_limbs)
32 }
33
34 fn execute(&self, dimensions: &[usize], inputs: &[Word], outputs: &mut [Word]) {
35 let [n_base, n_mod] = dimensions else {
36 panic!("ModInverse requires 2 dimensions");
37 };
38
39 let base_limbs = &inputs[0..*n_base];
40 let mod_limbs = &inputs[*n_base..];
41
42 let base = num_biguint_from_u64_limbs(base_limbs.iter().map(|w| w.as_u64()));
43 let modulus = num_biguint_from_u64_limbs(mod_limbs.iter().map(|w| w.as_u64()));
44
45 let zero = num_bigint::BigUint::ZERO;
46 let (quotient, inverse) = if let Some(inverse) = base.modinv(&modulus) {
47 let quotient = (base * &inverse - num_bigint::BigUint::from(1usize)) / &modulus;
48 (quotient, inverse)
49 } else {
50 (zero.clone(), zero)
51 };
52
53 assert_eq!(outputs.len(), 2 * *n_mod);
54 let (quotient_words, inverse_words) = outputs.split_at_mut(*n_mod);
55
56 for (i, limb) in quotient.iter_u64_digits().enumerate() {
58 quotient_words[i] = Word::from_u64(limb);
59 }
60
61 for i in quotient.iter_u64_digits().len()..*n_mod {
63 quotient_words[i] = Word::ZERO;
64 }
65
66 for (i, limb) in inverse.iter_u64_digits().enumerate() {
68 inverse_words[i] = Word::from_u64(limb);
69 }
70 for i in inverse.iter_u64_digits().len()..*n_mod {
72 inverse_words[i] = Word::ZERO;
73 }
74 }
75}