// binius_ntt/odd_interpolate.rs

use binius_field::BinaryField;
use binius_math::Matrix;
use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};

use crate::{
    additive_ntt::{AdditiveNTT, NTTShape},
    error::Error,
    twiddle::TwiddleAccess,
};

13pub struct OddInterpolate<F: BinaryField> {
14 vandermonde_inverse: Matrix<F>,
15 ell: usize,
16}
17
18impl<F: BinaryField> OddInterpolate<F> {
19 pub fn new<TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Self, Error>
23 where
24 TA: TwiddleAccess<F>,
25 {
26 let vandermonde = novel_vandermonde(d, ell, twiddle_access)?;
30
31 let mut vandermonde_inverse = Matrix::zeros(d, d);
32 vandermonde.inverse_into(&mut vandermonde_inverse)?;
33
34 Ok(Self {
35 vandermonde_inverse,
36 ell,
37 })
38 }
39
40 pub fn inverse_transform<NTT>(&self, ntt: &NTT, data: &mut [F]) -> Result<(), Error>
53 where
54 NTT: AdditiveNTT<F>,
56 {
57 let d = self.vandermonde_inverse.m();
58 let ell = self.ell;
59
60 if data.len() != d << ell {
61 bail!(Error::OddInterpolateIncorrectLength {
62 expected_len: d << ell
63 });
64 }
65
66 let log_required_domain_size = log2_ceil_usize(d) + ell;
67 if ntt.log_domain_size() < log_required_domain_size {
68 bail!(Error::DomainTooSmall {
69 log_required_domain_size
70 });
71 }
72
73 let shape = NTTShape {
74 log_y: ell,
75 ..Default::default()
76 };
77 for (i, chunk) in data.chunks_exact_mut(1 << ell).enumerate() {
78 ntt.inverse_transform(chunk, shape, i as u32, 0)?;
79 }
80
81 let mut bases = vec![F::ZERO; d];
87 let mut novel = vec![F::ZERO; d];
88 for stride in 0..1 << ell {
90 (0..d).for_each(|i| bases[i] = data[i << ell | stride]);
91 self.vandermonde_inverse.mul_vec_into(&bases, &mut novel);
92 (0..d).for_each(|i| data[i << ell | stride] = novel[i]);
93 }
94
95 Ok(())
96 }
97}
98
99fn novel_vandermonde<F, TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Matrix<F>, Error>
103where
104 F: BinaryField,
105 TA: TwiddleAccess<F>,
106{
107 if d == 0 {
108 return Ok(Matrix::zeros(0, 0));
109 }
110
111 let log_d = log2_ceil_usize(d);
112
113 let mut x_ell = Matrix::zeros(d, d);
115
116 (0..d).for_each(|j| x_ell[(j, 0)] = F::ONE);
118
119 let log_required_domain_size = log_d + ell;
120 if twiddle_access.len() < log_required_domain_size {
121 bail!(Error::DomainTooSmall {
122 log_required_domain_size
123 });
124 }
125
126 for (j, twiddle_access_j_plus_ell) in twiddle_access[ell..ell + log_d].iter().enumerate() {
127 assert!(twiddle_access_j_plus_ell.log_n() >= log_d - 1 - j);
128
129 for i in 0..d {
130 x_ell[(i, 1 << j)] = twiddle_access_j_plus_ell.get(i >> (j + 1))
131 + if (i >> j) & 1 == 1 { F::ONE } else { F::ZERO };
132 }
133
134 for k in 1..(1 << j).min(d - (1 << j)) {
136 for t in 0..d {
137 x_ell[(t, k + (1 << j))] = x_ell[(t, k)] * x_ell[(t, 1 << j)];
138 }
139 }
140 }
141
142 Ok(x_ell)
143}
144
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, Field};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::single_threaded::SingleThreadedNTT;

    /// Round trip: forward-transform zero-padded random novel-basis coefficients. Then check
    /// that `OddInterpolate::inverse_transform` recovers them from the first `d << ell`
    /// evaluations.
    #[test]
    fn test_interpolate_odd() {
        type F = BinaryField32b;
        const MAX_ELL: usize = 8;
        const MAX_D: usize = 10;

        let mut rng = StdRng::seed_from_u64(0);
        let ntt = SingleThreadedNTT::<F>::new(MAX_ELL + log2_ceil_usize(MAX_D)).unwrap();

        for (ell, d) in (0..MAX_ELL).flat_map(|ell| (0..MAX_D).map(move |d| (ell, d))) {
            let len = d << ell;
            let coeffs: Vec<F> = repeat_with(|| F::random(&mut rng)).take(len).collect();

            // Zero-pad up to the next power of two so a plain forward NTT applies.
            let log_padded = log2_ceil_usize(len);
            let mut evals = coeffs.clone();
            evals.resize(1 << log_padded, F::ZERO);
            let shape = NTTShape {
                log_y: log_padded,
                ..Default::default()
            };
            ntt.forward_transform(&mut evals, shape, 0, 0).unwrap();

            let interpolator = OddInterpolate::new(d, ell, &ntt.s_evals).unwrap();
            interpolator
                .inverse_transform(&ntt, &mut evals[..len])
                .unwrap();

            assert_eq!(coeffs, &evals[..len]);
        }
    }
}