binius_ntt/
odd_interpolate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::BinaryField;
4use binius_math::Matrix;
5use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
6
7use crate::{
8	additive_ntt::{AdditiveNTT, NTTShape},
9	error::Error,
10	twiddle::TwiddleAccess,
11};
12
13pub struct OddInterpolate<F: BinaryField> {
14	vandermonde_inverse: Matrix<F>,
15	ell: usize,
16}
17
18impl<F: BinaryField> OddInterpolate<F> {
19	/// Create a new odd interpolator into novel polynomial basis for domains of size $d \times 2^{\ell}$.
20	/// Takes a reference to NTT twiddle factors to seed the "Vandermonde" matrix and compute its inverse.
21	/// Time complexity is $\mathcal{O}(d^3).$
22	pub fn new<TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Self, Error>
23	where
24		TA: TwiddleAccess<F>,
25	{
26		// TODO: This constructor should accept an `impl AdditiveNTT` instead of an
27		// `impl TwiddleAccess`. It can use `AdditiveNTT::get_subspace_eval` instead of the twiddle
28		// accessors directly. `AdditiveNTT` is a more public interface.
29		let vandermonde = novel_vandermonde(d, ell, twiddle_access)?;
30
31		let mut vandermonde_inverse = Matrix::zeros(d, d);
32		vandermonde.inverse_into(&mut vandermonde_inverse)?;
33
34		Ok(Self {
35			vandermonde_inverse,
36			ell,
37		})
38	}
39
40	/// Let $L/\mathbb F_2$ be a binary field, and fix an $\mathbb F_2$-basis $1=:\beta_0,\ldots, \beta_{r-1}$ as usual.
41	/// Let $d\geq 1$ be an odd integer and let $\ell\geq 0$ be an integer. Let
42	/// $[a_0,\ldots, a_{d\times 2^{\ell} - 1}]$ be a list of elements of $L$. There is a unique univariate polynomial
43	/// $P(X)\in L\[X\]$ of degree less than $d\times 2^{\ell}$ such that the *evaluations* of $P$ on the "first" $d\times 2^{\ell}$
44	/// elements of $L$ (in little-Endian binary counting order with respect to the basis $\beta_0,\ldots, \beta_{r}$)
45	/// are precisely $a_0,\ldots, a_{d\times 2^{\ell} - 1}$.
46	///
47	/// We efficiently compute the coefficients of $P(X)$ with respect to the Novel Polynomial Basis (itself taken
48	/// with respect to the given ordered list $\beta_0,\ldots, \beta_{r-1}$).
49	///
50	/// Time complexity is $\mathcal{O}(d^2\times 2^{\ell} + \ell 2^{\ell})$, thus this routine is intended to be used
51	/// for small values of $d$.
52	pub fn inverse_transform<NTT>(&self, ntt: &NTT, data: &mut [F]) -> Result<(), Error>
53	where
54		// REVIEW: generalize this to any P: PackedField<Scalar=F>
55		NTT: AdditiveNTT<F>,
56	{
57		let d = self.vandermonde_inverse.m();
58		let ell = self.ell;
59
60		if data.len() != d << ell {
61			bail!(Error::OddInterpolateIncorrectLength {
62				expected_len: d << ell
63			});
64		}
65
66		let log_required_domain_size = log2_ceil_usize(d) + ell;
67		if ntt.log_domain_size() < log_required_domain_size {
68			bail!(Error::DomainTooSmall {
69				log_required_domain_size
70			});
71		}
72
73		let shape = NTTShape {
74			log_y: ell,
75			..Default::default()
76		};
77		for (i, chunk) in data.chunks_exact_mut(1 << ell).enumerate() {
78			ntt.inverse_transform(chunk, shape, i as u32, 0)?;
79		}
80
81		// Given M and a vector v, do the "strided product" M v. In more detail: we assume matrix is $d\times d$,
82		// and vector is $d\times 2^{\ell}$. For each $i$ in $0,\ldots, 2^{\ell-1}$, let $v_i$ be the subvect
83		// given by those entries whose index is congruent to $i$ mod $2^{\ell}$. Then this computes $M v_i$,
84		// and finally "interleaves" the result (which means that we treat $M v_i = w_i$ for each $i$ and then conjure
85		// up the associated vector $w$.)
86		let mut bases = vec![F::ZERO; d];
87		let mut novel = vec![F::ZERO; d];
88		// TODO: use `Matrix::mul_into`, implement when data is a slice of type `P: PackedField<Scalar=F>`.
89		for stride in 0..1 << ell {
90			(0..d).for_each(|i| bases[i] = data[i << ell | stride]);
91			self.vandermonde_inverse.mul_vec_into(&bases, &mut novel);
92			(0..d).for_each(|i| data[i << ell | stride] = novel[i]);
93		}
94
95		Ok(())
96	}
97}
98
99/// Compute the Vandermonde matrix: $X^{(\ell)}_i(w^{\ell}_j)$, where $w^{\ell}_j$ is the $j^{\text{th}}$ element of the field
100/// with respect to the $\beta^{(\ell)}_i$ in little Endian order. The matrix has dimensions $d\times d$.
101/// The key trick is that $\widehat{W}^{(\ell)}_i(\beta^{\ell}_j) = $\widehat{W}_{i+\ell}(\beta_{j+\ell})$.
102fn novel_vandermonde<F, TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Matrix<F>, Error>
103where
104	F: BinaryField,
105	TA: TwiddleAccess<F>,
106{
107	if d == 0 {
108		return Ok(Matrix::zeros(0, 0));
109	}
110
111	let log_d = log2_ceil_usize(d);
112
113	// This will contain the evaluations of $X^{(\ell)}_{j}(w^{(\ell)}_i)$. As usual, indexing goes from 0..d-1.
114	let mut x_ell = Matrix::zeros(d, d);
115
116	// $X_0$ is the function "1".
117	(0..d).for_each(|j| x_ell[(j, 0)] = F::ONE);
118
119	let log_required_domain_size = log_d + ell;
120	if twiddle_access.len() < log_required_domain_size {
121		bail!(Error::DomainTooSmall {
122			log_required_domain_size
123		});
124	}
125
126	for (j, twiddle_access_j_plus_ell) in twiddle_access[ell..ell + log_d].iter().enumerate() {
127		assert!(twiddle_access_j_plus_ell.log_n() >= log_d - 1 - j);
128
129		for i in 0..d {
130			x_ell[(i, 1 << j)] = twiddle_access_j_plus_ell.get(i >> (j + 1))
131				+ if (i >> j) & 1 == 1 { F::ONE } else { F::ZERO };
132		}
133
134		// Note that the jth column of x_ell is the ordered list of values $X_j(w_i)$ for i = 0, ..., d-1.
135		for k in 1..(1 << j).min(d - (1 << j)) {
136			for t in 0..d {
137				x_ell[(t, k + (1 << j))] = x_ell[(t, k)] * x_ell[(t, 1 << j)];
138			}
139		}
140	}
141
142	Ok(x_ell)
143}
144
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{BinaryField32b, Field};
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::single_threaded::SingleThreadedNTT;

	type F = BinaryField32b;

	/// Round trip: novel-basis coefficients -> forward NTT evaluations -> odd interpolation.
	#[test]
	fn test_interpolate_odd() {
		let max_ell = 8;
		let max_d = 10;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<F>::new(max_ell + log2_ceil_usize(max_d)).unwrap();

		for ell in 0..max_ell {
			for d in 0..max_d {
				let expected_novel = repeat_with(|| F::random(&mut rng))
					.take(d << ell)
					.collect::<Vec<_>>();

				let mut ntt_evals = expected_novel.clone();
				// zero-pad to the next power of two to apply the forward transform.
				let next_log_n = log2_ceil_usize(expected_novel.len());
				ntt_evals.resize(1 << next_log_n, F::ZERO);
				// apply forward transform and then run our odd interpolation routine.
				let shape = NTTShape {
					log_y: next_log_n,
					..Default::default()
				};
				ntt.forward_transform(&mut ntt_evals, shape, 0, 0).unwrap();

				let odd_interpolate = OddInterpolate::new(d, ell, &ntt.s_evals).unwrap();
				odd_interpolate
					.inverse_transform(&ntt, &mut ntt_evals[..d << ell])
					.unwrap();

				assert_eq!(expected_novel, &ntt_evals[..d << ell]);
			}
		}
	}

	/// `inverse_transform` must reject data whose length is not d * 2^ell.
	#[test]
	fn test_inverse_transform_rejects_incorrect_length() {
		let ntt = SingleThreadedNTT::<F>::new(6).unwrap();
		let (d, ell) = (3, 2);
		let odd_interpolate = OddInterpolate::new(d, ell, &ntt.s_evals).unwrap();

		let mut data = vec![F::ZERO; (d << ell) + 1];
		let result = odd_interpolate.inverse_transform(&ntt, &mut data);
		assert!(matches!(result, Err(Error::OddInterpolateIncorrectLength { .. })));
	}

	/// `inverse_transform` must reject an NTT whose domain is smaller than ceil(log2 d) + ell.
	#[test]
	fn test_inverse_transform_rejects_small_domain() {
		let ntt = SingleThreadedNTT::<F>::new(6).unwrap();
		// Requires a log domain size of ceil(log2 3) + 2 = 4.
		let (d, ell) = (3, 2);
		let odd_interpolate = OddInterpolate::new(d, ell, &ntt.s_evals).unwrap();

		let small_ntt = SingleThreadedNTT::<F>::new(3).unwrap();
		let mut data = vec![F::ZERO; d << ell];
		let result = odd_interpolate.inverse_transform(&small_ntt, &mut data);
		assert!(matches!(result, Err(Error::DomainTooSmall { .. })));
	}

	/// `new` must reject a twiddle slice shorter than ceil(log2 d) + ell.
	#[test]
	fn test_new_rejects_too_few_twiddles() {
		let ntt = SingleThreadedNTT::<F>::new(6).unwrap();
		// Requires 4 twiddle layers; only 3 are provided.
		let (d, ell) = (3, 2);
		let result = OddInterpolate::new(d, ell, &ntt.s_evals[..3]);
		assert!(matches!(result, Err(Error::DomainTooSmall { .. })));
	}
}