// binius_ntt/odd_interpolate.rs
// Copyright 2024 Irreducible Inc.
use crate::{additive_ntt::AdditiveNTT, error::Error, twiddle::TwiddleAccess};
use binius_field::BinaryField;
use binius_math::Matrix;
use binius_utils::bail;
use p3_util::log2_ceil_usize;
/// Interpolator into the novel polynomial basis for evaluation domains of size
/// $d \times 2^{\ell}$, where $d$ need not be a power of two.
///
/// Built by [`OddInterpolate::new`]; applied with [`OddInterpolate::inverse_transform`].
pub struct OddInterpolate<F: BinaryField> {
    /// $\ell$: the log size of each power-of-two chunk of the domain.
    ell: usize,
    /// Inverse of the $d \times d$ novel-basis "Vandermonde" matrix; its dimension is $d$.
    vandermonde_inverse: Matrix<F>,
}
impl<F: BinaryField> OddInterpolate<F> {
    /// Create a new odd interpolator into the novel polynomial basis for domains of size
    /// $d \times 2^{\ell}$.
    ///
    /// Takes a reference to the NTT twiddle factors to seed the "Vandermonde" matrix, then
    /// computes its inverse. Time complexity is $\mathcal{O}(d^3)$.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DomainTooSmall`] if `twiddle_access` cannot cover a domain of size
    /// $2^{\lceil \log_2 d \rceil + \ell}$. Also propagates any error from inverting the
    /// Vandermonde matrix.
    pub fn new<TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Self, Error>
    where
        TA: TwiddleAccess<F>,
    {
        let vandermonde = novel_vandermonde(d, ell, twiddle_access)?;
        let mut vandermonde_inverse = Matrix::zeros(d, d);
        vandermonde.inverse_into(&mut vandermonde_inverse)?;
        Ok(Self {
            vandermonde_inverse,
            ell,
        })
    }

    /// Let $L/\mathbb F_2$ be a binary field, and fix an $\mathbb F_2$-basis
    /// $1=:\beta_0,\ldots, \beta_{r-1}$ as usual. Let $d\geq 1$ be an odd integer and let
    /// $\ell\geq 0$ be an integer. Let $[a_0,\ldots, a_{d\times 2^{\ell} - 1}]$ be a list of
    /// elements of $L$. There is a unique univariate polynomial $P(X)\in L\[X\]$ of degree less
    /// than $d\times 2^{\ell}$ whose *evaluations* on the "first" $d\times 2^{\ell}$ elements
    /// of $L$ are precisely $a_0,\ldots, a_{d\times 2^{\ell} - 1}$. Here "first" means in
    /// little-Endian binary counting order with respect to the basis
    /// $\beta_0,\ldots, \beta_{r-1}$.
    ///
    /// We efficiently compute the coefficients of $P(X)$ with respect to the Novel Polynomial
    /// Basis, itself taken with respect to the ordered list $\beta_0,\ldots, \beta_{r-1}$.
    /// The transform works in place. On input, `data` holds the evaluations; on output, it
    /// holds the novel-basis coefficients.
    ///
    /// The implementation does not check that $d$ is odd, and the unit tests also exercise
    /// even $d$.
    ///
    /// Time complexity is $\mathcal{O}(d^2\times 2^{\ell} + d\,\ell\, 2^{\ell})$: one
    /// $d \times d$ matrix-vector product per stride, plus $d$ inverse NTTs of size
    /// $2^{\ell}$. This routine is therefore intended for small values of $d$.
    ///
    /// # Errors
    ///
    /// - [`Error::OddInterpolateIncorrectLength`] if `data.len()` is not $d \times 2^{\ell}$.
    /// - [`Error::DomainTooSmall`] if `ntt.log_domain_size()` is less than
    ///   $\lceil \log_2 d \rceil + \ell$.
    /// - Any error returned by the NTT's `inverse_transform`.
    pub fn inverse_transform<NTT>(&self, ntt: &NTT, data: &mut [F]) -> Result<(), Error>
    where
        // REVIEW: generalize this to any P: PackedField<Scalar=F>
        NTT: AdditiveNTT<F>,
    {
        let d = self.vandermonde_inverse.m();
        let ell = self.ell;
        let chunk_len = 1usize << ell;

        if data.len() != d << ell {
            bail!(Error::OddInterpolateIncorrectLength {
                expected_len: d << ell
            });
        }

        let log_required_domain_size = log2_ceil_usize(d) + ell;
        if ntt.log_domain_size() < log_required_domain_size {
            bail!(Error::DomainTooSmall {
                log_required_domain_size
            });
        }

        // Step 1: inverse-transform each contiguous chunk of 2^ell evaluations, passing the
        // chunk index `i` (< d) as the NTT's coset argument.
        for (i, chunk) in data.chunks_exact_mut(chunk_len).enumerate() {
            ntt.inverse_transform(chunk, i as u32, 0)?;
        }

        // Step 2: the "strided product". Let M be the d x d matrix and v the length d * 2^ell
        // vector. For each s in 0..=2^ell - 1, let v_s be the subvector of entries whose
        // index is congruent to s mod 2^ell. Replace v_s by w_s = M v_s in place, which
        // interleaves the results w_s back into v.
        // TODO: use `Matrix::mul_into`, implement when data is a slice of type `P: PackedField<Scalar=F>`.
        let mut bases = vec![F::ZERO; d];
        let mut novel = vec![F::ZERO; d];
        for stride in 0..chunk_len {
            // `skip` + `step_by` visits exactly the d indices congruent to `stride`. Unlike
            // slicing, it stays panic-free when `data` is empty (d == 0).
            for (base, &value) in bases
                .iter_mut()
                .zip(data.iter().skip(stride).step_by(chunk_len))
            {
                *base = value;
            }
            self.vandermonde_inverse.mul_vec_into(&bases, &mut novel);
            for (slot, &coeff) in data
                .iter_mut()
                .skip(stride)
                .step_by(chunk_len)
                .zip(&novel)
            {
                *slot = coeff;
            }
        }
        Ok(())
    }
}
/// Compute the $d \times d$ novel-basis "Vandermonde" matrix at level $\ell$.
///
/// Entry $(i, j)$ is $X^{(\ell)}_j(w^{(\ell)}_i)$:
/// - Row $i$ corresponds to the point $w^{(\ell)}_i$. This is the $i^{\text{th}}$ field
///   element in little-Endian binary counting order with respect to the basis
///   $\beta^{(\ell)}_0, \beta^{(\ell)}_1, \ldots$.
/// - Column $j$ corresponds to the $j^{\text{th}}$ novel basis polynomial $X^{(\ell)}_j$.
///
/// The key trick is that
/// $\widehat{W}^{(\ell)}_i(\beta^{(\ell)}_j) = \widehat{W}_{i+\ell}(\beta_{j+\ell})$.
/// The needed evaluations are therefore read from the twiddle layers
/// `twiddle_access[ell..ell + ceil(log2(d))]`.
///
/// # Errors
///
/// Returns [`Error::DomainTooSmall`] in either of two cases:
/// - `twiddle_access` has fewer than $\lceil \log_2 d \rceil + \ell$ layers.
/// - One of the used layers has too few entries to be indexed for every row $i < d$.
fn novel_vandermonde<F, TA>(d: usize, ell: usize, twiddle_access: &[TA]) -> Result<Matrix<F>, Error>
where
    F: BinaryField,
    TA: TwiddleAccess<F>,
{
    if d == 0 {
        return Ok(Matrix::zeros(0, 0));
    }

    let log_d = log2_ceil_usize(d);
    let log_required_domain_size = log_d + ell;
    // Validate the twiddle layers before doing any work.
    if twiddle_access.len() < log_required_domain_size {
        bail!(Error::DomainTooSmall {
            log_required_domain_size
        });
    }

    // x_ell[(i, j)] holds $X^{(\ell)}_j(w^{(\ell)}_i)$. As usual, indexing goes from 0..d-1.
    let mut x_ell = Matrix::zeros(d, d);
    // $X_0$ is the constant function 1.
    for row in 0..d {
        x_ell[(row, 0)] = F::ONE;
    }

    for (j, twiddle_layer) in twiddle_access[ell..ell + log_d].iter().enumerate() {
        // Rows i < d read entry `i >> (j + 1)`, so the layer needs at least
        // 2^(log_d - 1 - j) entries. Reject inconsistent twiddles instead of panicking.
        if twiddle_layer.log_n() < log_d - 1 - j {
            bail!(Error::DomainTooSmall {
                log_required_domain_size
            });
        }

        // Since j < log_d = ceil(log2(d)), col = 2^j < d, so this column exists.
        let col = 1usize << j;
        // Column 2^j at row i: the twiddle entry selected by the bits of i above position j,
        // plus bit j of i.
        for i in 0..d {
            let bit = if (i >> j) & 1 == 1 { F::ONE } else { F::ZERO };
            x_ell[(i, col)] = twiddle_layer.get(i >> (j + 1)) + bit;
        }

        // Columns col + k for 0 < k < col (and col + k < d) are products of already-filled
        // columns: X_{col + k} = X_k * X_col.
        for k in 1..col.min(d - col) {
            for t in 0..d {
                x_ell[(t, k + col)] = x_ell[(t, k)] * x_ell[(t, col)];
            }
        }
    }

    Ok(x_ell)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::single_threaded::SingleThreadedNTT;
    use binius_field::{BinaryField32b, Field};
    use rand::{rngs::StdRng, SeedableRng};
    use std::iter::repeat_with;

    /// Round trip: random novel-basis coefficients -> forward NTT evaluations -> odd
    /// interpolation must recover the original coefficients, for every
    /// `ell <= max_ell` and `d <= max_d`.
    #[test]
    fn test_interpolate_odd() {
        type F = BinaryField32b;
        let max_ell = 8;
        let max_d = 10;
        let mut rng = StdRng::seed_from_u64(0);
        // Sized so that the largest case (d = max_d, ell = max_ell) fills the domain exactly.
        let ntt = SingleThreadedNTT::<F>::new(max_ell + log2_ceil_usize(max_d)).unwrap();
        for ell in 0..=max_ell {
            for d in 0..=max_d {
                let expected_novel = repeat_with(|| F::random(&mut rng))
                    .take(d << ell)
                    .collect::<Vec<_>>();
                let mut ntt_evals = expected_novel.clone();
                // Zero-pad to the next power of two to apply the forward transform.
                let next_power_of_two = 1 << log2_ceil_usize(expected_novel.len());
                ntt_evals.resize(next_power_of_two, F::ZERO);
                // Apply the forward transform, then run our odd interpolation routine.
                ntt.forward_transform(&mut ntt_evals, 0, 0).unwrap();
                let odd_interpolate = OddInterpolate::new(d, ell, &ntt.s_evals).unwrap();
                odd_interpolate
                    .inverse_transform(&ntt, &mut ntt_evals[..d << ell])
                    .unwrap();
                assert_eq!(expected_novel, &ntt_evals[..d << ell]);
            }
        }
    }

    /// A data slice whose length is not d * 2^ell must be rejected.
    #[test]
    fn test_interpolate_odd_incorrect_length() {
        type F = BinaryField32b;
        let ntt = SingleThreadedNTT::<F>::new(6).unwrap();
        let odd_interpolate = OddInterpolate::new(3, 2, &ntt.s_evals).unwrap();
        // d << ell = 12, so 11 elements must be rejected.
        let mut data = vec![F::ZERO; 11];
        let result = odd_interpolate.inverse_transform(&ntt, &mut data);
        assert!(matches!(
            result,
            Err(Error::OddInterpolateIncorrectLength { expected_len: 12 })
        ));
    }

    /// An NTT whose domain is smaller than ceil(log2(d)) + ell must be rejected.
    #[test]
    fn test_interpolate_odd_domain_too_small() {
        type F = BinaryField32b;
        let large_ntt = SingleThreadedNTT::<F>::new(6).unwrap();
        let small_ntt = SingleThreadedNTT::<F>::new(4).unwrap();
        // d = 3, ell = 3 requires a log domain size of log2_ceil(3) + 3 = 5 > 4.
        let odd_interpolate = OddInterpolate::new(3, 3, &large_ntt.s_evals).unwrap();
        let mut data = vec![F::ZERO; 3 << 3];
        let result = odd_interpolate.inverse_transform(&small_ntt, &mut data);
        assert!(matches!(
            result,
            Err(Error::DomainTooSmall {
                log_required_domain_size: 5
            })
        ));
    }
}