binius_ntt/
odd_interpolate.rs1use binius_field::BinaryField;
4use binius_math::Matrix;
5use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
6
7use crate::{
8 additive_ntt::{AdditiveNTT, NTTShape},
9 error::Error,
10};
11
/// Interpolates from evaluations on `d` cosets of `2^ell` points each, where `d` need not be a
/// power of two.
///
/// Construction precomputes the inverse of the `d × d` matrix built by `novel_vandermonde`;
/// `inverse_transform` then combines per-coset inverse NTTs with a multiplication by that inverse.
#[derive(Debug)]
pub struct OddInterpolate<'a, F: BinaryField, NTT: AdditiveNTT<F>> {
	// Inverse of the `d × d` matrix from `novel_vandermonde`; its row count is `d`.
	vandermonde_inverse: Matrix<F>,
	// Log2 of the number of evaluations in each coset.
	ell: usize,
	// Number of bits used to index cosets when calling into the NTT.
	coset_bits: usize,
	// NTT used for the per-coset inverse transforms.
	ntt: &'a NTT,
}
27
28impl<'a, F: BinaryField, NTT: AdditiveNTT<F>> OddInterpolate<'a, F, NTT> {
29 pub fn new(ntt: &'a NTT, d: usize, ell: usize, coset_bits: usize) -> Result<Self, Error> {
33 if d > (1 << coset_bits) {
34 bail!(Error::CosetIndexOutOfBounds {
35 coset: d - 1,
36 coset_bits
37 });
38 }
39
40 let log_required_domain_size = coset_bits + ell;
41 if ntt.log_domain_size() < log_required_domain_size {
42 bail!(Error::DomainTooSmall {
43 log_required_domain_size
44 });
45 }
46
47 let vandermonde = novel_vandermonde(ntt, d, coset_bits)?;
48
49 let mut vandermonde_inverse = Matrix::zeros(d, d);
50 vandermonde.inverse_into(&mut vandermonde_inverse)?;
51
52 Ok(Self {
53 ntt,
54 vandermonde_inverse,
55 ell,
56 coset_bits,
57 })
58 }
59
60 pub fn inverse_transform(&self, data: &mut [F]) -> Result<(), Error> {
74 let d = self.vandermonde_inverse.m();
76 let ell = self.ell;
77
78 if data.len() != d << ell {
79 bail!(Error::OddInterpolateIncorrectLength {
80 expected_len: d << ell
81 });
82 }
83
84 let shape = NTTShape {
85 log_y: ell,
86 ..Default::default()
87 };
88 for (i, chunk) in data.chunks_exact_mut(1 << ell).enumerate() {
89 self.ntt
90 .inverse_transform(chunk, shape, i, self.coset_bits, 0)?;
91 }
92
93 let mut bases = vec![F::ZERO; d];
100 let mut novel = vec![F::ZERO; d];
101 for stride in 0..1 << ell {
104 (0..d).for_each(|i| bases[i] = data[i << ell | stride]);
105 self.vandermonde_inverse.mul_vec_into(&bases, &mut novel);
106 (0..d).for_each(|i| data[i << ell | stride] = novel[i]);
107 }
108
109 Ok(())
110 }
111}
112
113fn novel_vandermonde<F, NTT>(ntt: &NTT, d: usize, coset_bits: usize) -> Result<Matrix<F>, Error>
118where
119 F: BinaryField,
120 NTT: AdditiveNTT<F>,
121{
122 let mut x_ell = Matrix::zeros(d, d);
125
126 (0..d).for_each(|j| x_ell[(j, 0)] = F::ONE);
128
129 if d == 0 {
130 return Ok(x_ell);
131 }
132
133 let log_d = log2_ceil_usize(d);
134 for j in 0..log_d {
135 let subspace_idx = ntt.log_domain_size() - coset_bits + j;
136 for i in 0..d {
137 x_ell[(i, 1 << j)] = ntt.get_subspace_eval(subspace_idx, i >> (j + 1))
138 + if (i >> j) & 1 == 1 { F::ONE } else { F::ZERO };
139 }
140
141 for k in 1..(1 << j).min(d - (1 << j)) {
144 for t in 0..d {
145 x_ell[(t, k + (1 << j))] = x_ell[(t, k)] * x_ell[(t, 1 << j)];
146 }
147 }
148 }
149
150 Ok(x_ell)
151}
152
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField32b, Field};
	use binius_utils::checked_arithmetics::log2_ceil_usize;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::single_threaded::SingleThreadedNTT;

	/// Round trip: evaluate random coefficients with one zero-padded power-of-two forward NTT,
	/// then check that `OddInterpolate` recovers them from the first `d << ell` evaluations.
	#[test]
	fn test_interpolate_odd() {
		type F = BinaryField32b;

		// Exclusive upper bounds of the `ell` and `d` sweeps.
		let ell_bound = 8;
		let d_bound = 10;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<F>::new(ell_bound + log2_ceil_usize(d_bound)).unwrap();

		for ell in 0..ell_bound {
			for d in 0..d_bound {
				let len = d << ell;
				let coeffs = (0..len)
					.map(|_| F::random(&mut rng))
					.collect::<Vec<_>>();

				// Zero-pad to the next power of two and evaluate with a single forward NTT.
				let log_n = log2_ceil_usize(len);
				let mut evals = coeffs.clone();
				evals.resize(1 << log_n, F::ZERO);
				let full_shape = NTTShape {
					log_y: log_n,
					..Default::default()
				};
				ntt.forward_transform(&mut evals, full_shape, 0, 0, 0)
					.unwrap();

				// Treat the evaluations as cosets of `2^ell` points and interpolate from the
				// first `d` of them (saturating so that `d == 0` yields zero coset bits).
				let coset_bits = log_n.saturating_sub(ell);
				let interpolator = OddInterpolate::new(&ntt, d, ell, coset_bits).unwrap();
				interpolator.inverse_transform(&mut evals[..len]).unwrap();

				assert_eq!(coeffs, &evals[..len]);
			}
		}
	}
}