Skip to main content

binius_frontend/compiler/hints/
mod_inverse.rs

// Copyright 2025 Irreducible Inc.
//! Modular inverse hint implementation
3
4use binius_core::Word;
5
6use super::Hint;
7use crate::util::num_biguint_from_u64_limbs;
8
/// Hint that computes a modular inverse, together with the quotient that certifies it,
/// for big-integer operands given as 64-bit limbs.
///
/// The type is a stateless unit struct; [`ModInverseHint::new`] and [`Default::default`]
/// produce the same value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ModInverseHint;

impl ModInverseHint {
	/// Creates a new `ModInverseHint`.
	///
	/// This is a `const fn` so the hint can be built in constant contexts. It is
	/// equivalent to [`Default::default`].
	pub const fn new() -> Self {
		Self
	}
}
23
24impl Hint for ModInverseHint {
25	const NAME: &'static str = "binius.mod_inverse";
26
27	fn shape(&self, dimensions: &[usize]) -> (usize, usize) {
28		let [base_limbs, mod_limbs] = dimensions else {
29			panic!("ModInverse requires 2 dimensions");
30		};
31		(*base_limbs + *mod_limbs, 2 * *mod_limbs)
32	}
33
34	fn execute(&self, dimensions: &[usize], inputs: &[Word], outputs: &mut [Word]) {
35		let [n_base, n_mod] = dimensions else {
36			panic!("ModInverse requires 2 dimensions");
37		};
38
39		let base_limbs = &inputs[0..*n_base];
40		let mod_limbs = &inputs[*n_base..];
41
42		let base = num_biguint_from_u64_limbs(base_limbs.iter().map(|w| w.as_u64()));
43		let modulus = num_biguint_from_u64_limbs(mod_limbs.iter().map(|w| w.as_u64()));
44
45		let zero = num_bigint::BigUint::ZERO;
46		let (quotient, inverse) = if let Some(inverse) = base.modinv(&modulus) {
47			let quotient = (base * &inverse - num_bigint::BigUint::from(1usize)) / &modulus;
48			(quotient, inverse)
49		} else {
50			(zero.clone(), zero)
51		};
52
53		assert_eq!(outputs.len(), 2 * *n_mod);
54		let (quotient_words, inverse_words) = outputs.split_at_mut(*n_mod);
55
56		// Fill output quotient limbs
57		for (i, limb) in quotient.iter_u64_digits().enumerate() {
58			quotient_words[i] = Word::from_u64(limb);
59		}
60
61		// Zero remaining outputs if quotient has fewer limbs
62		for i in quotient.iter_u64_digits().len()..*n_mod {
63			quotient_words[i] = Word::ZERO;
64		}
65
66		// Fill output inverse limbs
67		for (i, limb) in inverse.iter_u64_digits().enumerate() {
68			inverse_words[i] = Word::from_u64(limb);
69		}
70		// Zero remaining outputs if inverse has fewer limbs
71		for i in inverse.iter_u64_digits().len()..*n_mod {
72			inverse_words[i] = Word::ZERO;
73		}
74	}
75}