binius_ntt/
odd_interpolate.rs1use binius_field::BinaryField;
4use binius_math::Matrix;
5use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
6
7use crate::{additive_ntt::AdditiveNTT, error::Error, twiddle::TwiddleAccess};
8
9pub struct OddInterpolate<F: BinaryField> {
10 vandermonde_inverse: Matrix<F>,
11 ell: usize,
12}
13
14impl<F: BinaryField> OddInterpolate<F> {
15 pub fn new<TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Self, Error>
19 where
20 TA: TwiddleAccess<F>,
21 {
22 let vandermonde = novel_vandermonde(d, ell, twiddle_access)?;
26
27 let mut vandermonde_inverse = Matrix::zeros(d, d);
28 vandermonde.inverse_into(&mut vandermonde_inverse)?;
29
30 Ok(Self {
31 vandermonde_inverse,
32 ell,
33 })
34 }
35
36 pub fn inverse_transform<NTT>(&self, ntt: &NTT, data: &mut [F]) -> Result<(), Error>
49 where
50 NTT: AdditiveNTT<F>,
52 {
53 let d = self.vandermonde_inverse.m();
54 let ell = self.ell;
55
56 if data.len() != d << ell {
57 bail!(Error::OddInterpolateIncorrectLength {
58 expected_len: d << ell
59 });
60 }
61
62 let log_required_domain_size = log2_ceil_usize(d) + ell;
63 if ntt.log_domain_size() < log_required_domain_size {
64 bail!(Error::DomainTooSmall {
65 log_required_domain_size
66 });
67 }
68
69 for (i, chunk) in data.chunks_exact_mut(1 << ell).enumerate() {
70 ntt.inverse_transform(chunk, i as u32, 0, ell)?;
71 }
72
73 let mut bases = vec![F::ZERO; d];
79 let mut novel = vec![F::ZERO; d];
80 for stride in 0..1 << ell {
82 (0..d).for_each(|i| bases[i] = data[i << ell | stride]);
83 self.vandermonde_inverse.mul_vec_into(&bases, &mut novel);
84 (0..d).for_each(|i| data[i << ell | stride] = novel[i]);
85 }
86
87 Ok(())
88 }
89}
90
91fn novel_vandermonde<F, TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Matrix<F>, Error>
95where
96 F: BinaryField,
97 TA: TwiddleAccess<F>,
98{
99 if d == 0 {
100 return Ok(Matrix::zeros(0, 0));
101 }
102
103 let log_d = log2_ceil_usize(d);
104
105 let mut x_ell = Matrix::zeros(d, d);
107
108 (0..d).for_each(|j| x_ell[(j, 0)] = F::ONE);
110
111 let log_required_domain_size = log_d + ell;
112 if twiddle_access.len() < log_required_domain_size {
113 bail!(Error::DomainTooSmall {
114 log_required_domain_size
115 });
116 }
117
118 for (j, twiddle_access_j_plus_ell) in twiddle_access[ell..ell + log_d].iter().enumerate() {
119 assert!(twiddle_access_j_plus_ell.log_n() >= log_d - 1 - j);
120
121 for i in 0..d {
122 x_ell[(i, 1 << j)] = twiddle_access_j_plus_ell.get(i >> (j + 1))
123 + if (i >> j) & 1 == 1 { F::ONE } else { F::ZERO };
124 }
125
126 for k in 1..(1 << j).min(d - (1 << j)) {
128 for t in 0..d {
129 x_ell[(t, k + (1 << j))] = x_ell[(t, k)] * x_ell[(t, 1 << j)];
130 }
131 }
132 }
133
134 Ok(x_ell)
135}
136
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, Field};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::single_threaded::SingleThreadedNTT;

    /// Round-trip check: random data is zero-padded to a power-of-two length and run through
    /// the forward NTT; applying `OddInterpolate::inverse_transform` to the first `d << ell`
    /// evaluations must give back the original data. Exercised for every `ell < 8` and
    /// `d < 10`, including `d == 0`.
    #[test]
    fn test_interpolate_odd() {
        type F = BinaryField32b;
        let max_ell = 8;
        let max_d = 10;

        let mut rng = StdRng::seed_from_u64(0);
        let ntt = SingleThreadedNTT::<F>::new(max_ell + log2_ceil_usize(max_d)).unwrap();

        // Same visiting order as two nested loops with `ell` outermost.
        let cases = (0..max_ell).flat_map(|ell| (0..max_d).map(move |d| (d, ell)));
        for (d, ell) in cases {
            let len = d << ell;
            let expected_novel: Vec<F> = repeat_with(|| F::random(&mut rng)).take(len).collect();

            // Pad with zeros up to the next power of two before the forward transform.
            let padded_log_len = log2_ceil_usize(expected_novel.len());
            let mut ntt_evals = expected_novel.clone();
            ntt_evals.resize(1 << padded_log_len, F::ZERO);
            ntt.forward_transform(&mut ntt_evals, 0, 0, padded_log_len)
                .unwrap();

            let interpolator = OddInterpolate::new(d, ell, &ntt.s_evals).unwrap();
            interpolator
                .inverse_transform(&ntt, &mut ntt_evals[..len])
                .unwrap();

            assert_eq!(expected_novel, &ntt_evals[..len]);
        }
    }
}