1use std::{cmp, marker::PhantomData};
4
5use binius_field::{BinaryField, PackedField, TowerField};
6use binius_math::BinarySubspace;
7
8use super::{additive_ntt::AdditiveNTT, error::Error, twiddle::TwiddleAccess};
9use crate::twiddle::{expand_subspace_evals, OnTheFlyTwiddleAccess, PrecomputedTwiddleAccess};
10
/// An additive NTT that executes entirely on the calling thread.
///
/// `TA` selects the twiddle-factor strategy: the default
/// [`OnTheFlyTwiddleAccess`] computes twiddles as needed, while
/// [`PrecomputedTwiddleAccess`] (via `precompute_twiddles`) stores expanded
/// tables.
#[derive(Debug)]
pub struct SingleThreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F>> {
    // One twiddle accessor per NTT round; `s_evals.len()` is the log domain
    // size, and round `i` covers a subspace of dimension `len - i` (see the
    // `subspace` accessor and its unit test).
    pub(super) s_evals: Vec<TA>,
    // Zero-sized marker binding the struct to the scalar field `F`.
    _marker: PhantomData<F>,
}
18
19impl<F: BinaryField> SingleThreadedNTT<F> {
20 pub fn new(log_domain_size: usize) -> Result<Self, Error> {
23 let subspace = BinarySubspace::with_dim(log_domain_size)?;
24 let twiddle_access = OnTheFlyTwiddleAccess::generate(&subspace)?;
25 Ok(Self::with_twiddle_access(twiddle_access))
26 }
27
28 pub fn with_domain_field<FDomain>(log_domain_size: usize) -> Result<Self, Error>
31 where
32 FDomain: BinaryField,
33 F: From<FDomain>,
34 {
35 let subspace = BinarySubspace::<FDomain>::with_dim(log_domain_size)?.isomorphic();
36 let twiddle_access = OnTheFlyTwiddleAccess::generate(&subspace)?;
37 Ok(Self::with_twiddle_access(twiddle_access))
38 }
39
40 pub fn precompute_twiddles(&self) -> SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>> {
41 SingleThreadedNTT::with_twiddle_access(expand_subspace_evals(&self.s_evals))
42 }
43}
44
impl<F: TowerField> SingleThreadedNTT<F> {
    /// Constructs an NTT whose domain is generated in the canonical tower
    /// representation of `F` and mapped into `F` isomorphically.
    ///
    /// # Errors
    /// Propagates any error from [`Self::with_domain_field`].
    pub fn with_canonical_field(log_domain_size: usize) -> Result<Self, Error> {
        Self::with_domain_field::<F::Canonical>(log_domain_size)
    }
}
51
impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
    // Internal constructor: wraps an already-built twiddle table without any
    // validation. Callers are responsible for providing consistent rounds.
    const fn with_twiddle_access(twiddle_access: Vec<TA>) -> Self {
        Self {
            s_evals: twiddle_access,
            _marker: PhantomData,
        }
    }
}
60
impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
    /// Returns the per-round twiddle accessors backing this NTT.
    pub fn twiddles(&self) -> &[TA] {
        &self.s_evals
    }
}
66
impl<F, TA> AdditiveNTT<F> for SingleThreadedNTT<F, TA>
where
    F: BinaryField,
    TA: TwiddleAccess<F>,
{
    // The log domain size equals the number of twiddle rounds stored.
    fn log_domain_size(&self) -> usize {
        self.s_evals.len()
    }

    // Returns the (linear) subspace associated with round `i`. The shift of
    // the affine subspace is asserted to be zero in debug builds.
    fn subspace(&self, i: usize) -> BinarySubspace<F> {
        let (subspace, shift) = self.s_evals[i].affine_subspace();
        debug_assert_eq!(shift, F::ZERO, "s_evals subspaces must be linear by construction");
        subspace
    }

    // Looks up the `j`-th twiddle evaluation for round `i`.
    fn get_subspace_eval(&self, i: usize, j: usize) -> F {
        self.s_evals[i].get(j)
    }

    // Delegates to the free-standing [`forward_transform`] with this NTT's
    // twiddle tables.
    fn forward_transform<P: PackedField<Scalar = F>>(
        &self,
        data: &mut [P],
        coset: u32,
        log_batch_size: usize,
        log_n: usize,
    ) -> Result<(), Error> {
        forward_transform(self.log_domain_size(), &self.s_evals, data, coset, log_batch_size, log_n)
    }

    // Delegates to the free-standing [`inverse_transform`] with this NTT's
    // twiddle tables.
    fn inverse_transform<P: PackedField<Scalar = F>>(
        &self,
        data: &mut [P],
        coset: u32,
        log_batch_size: usize,
        log_n: usize,
    ) -> Result<(), Error> {
        inverse_transform(self.log_domain_size(), &self.s_evals, data, coset, log_batch_size, log_n)
    }
}
106
/// Forward additive NTT over packed field elements, in place.
///
/// * `log_domain_size` - number of twiddle rounds available in `s_evals`
/// * `data` - the packed evaluations to transform, mutated in place
/// * `coset` - index of the domain coset to evaluate over
/// * `log_batch_size` - log2 of the number of batched transforms
///   (NOTE(review): batch elements appear to be interleaved at scalar
///   granularity — confirm against callers)
/// * `log_n` - log2 of the transform length
///
/// # Errors
/// Returns an error when the inputs fail
/// [`check_batch_transform_inputs_and_params`].
pub fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
    log_domain_size: usize,
    s_evals: &[impl TwiddleAccess<F>],
    data: &mut [P],
    coset: u32,
    log_batch_size: usize,
    log_n: usize,
) -> Result<(), Error> {
    match data.len() {
        // Empty input: nothing to transform.
        0 => return Ok(()),
        1 => {
            return match P::LOG_WIDTH {
                // A single scalar: nothing to pair, return as-is.
                0 => Ok(()),
                _ => {
                    // A lone packed element can't be split by the two-element
                    // butterfly indexing below, so pad with a zero element,
                    // transform the pair, and keep only the first half.
                    let mut buffer = [data[0], P::zero()];

                    forward_transform(
                        log_domain_size,
                        s_evals,
                        &mut buffer,
                        coset,
                        log_batch_size,
                        log_n,
                    )?;

                    data[0] = buffer[0];

                    Ok(())
                }
            };
        }
        _ => {}
    };

    let log_b = log_batch_size;

    let log_w = P::LOG_WIDTH;

    check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_batch_size, log_n)?;

    // Rounds with index >= cutoff have butterfly halves spanning whole packed
    // elements; rounds below cutoff operate inside a packed element.
    let cutoff = log_w.saturating_sub(log_b);

    // Stage 1 (high rounds, descending): butterflies between whole packed
    // elements. Each block `j` shares one broadcast twiddle.
    for i in (cutoff..log_n).rev() {
        let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

        for j in 0..1 << (log_n - 1 - i) {
            let twiddle = P::broadcast(coset_twiddle.get(j));
            for k in 0..1 << (i + log_b - log_w) {
                // idx0/idx1 address the low/high halves of butterfly block j.
                let idx0 = j << (i + log_b - log_w + 1) | k;
                let idx1 = idx0 | 1 << (i + log_b - log_w);
                data[idx0] += data[idx1] * twiddle;
                data[idx1] += data[idx0];
            }
        }
    }

    // Stage 2 (low rounds, descending): butterfly spans are smaller than a
    // packed element. `interleave` regroups scalars so adjacent sub-blocks
    // pair up, and `block_twiddle` supplies the per-sub-block twiddle offsets.
    for i in (0..cmp::min(cutoff, log_n)).rev() {
        let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

        let log_block_len = i + log_b;
        let block_twiddle = calculate_twiddle::<P>(
            &s_evals[i].coset(log_domain_size - 1 - cutoff, 0),
            log_block_len,
        );

        for j in 0..data.len() / 2 {
            let twiddle = P::broadcast(coset_twiddle.get(j << (cutoff - i))) + block_twiddle;
            let (mut u, mut v) = data[j << 1].interleave(data[j << 1 | 1], log_block_len);
            u += v * twiddle;
            v += u;
            (data[j << 1], data[j << 1 | 1]) = u.interleave(v, log_block_len);
        }
    }

    Ok(())
}
196
/// Inverse additive NTT over packed field elements, in place.
///
/// Mirrors [`forward_transform`]: the round loops run in the opposite order
/// (ascending instead of descending) and each butterfly applies its two
/// additions in reverse, undoing the forward pass.
///
/// * `log_domain_size` - number of twiddle rounds available in `s_evals`
/// * `data` - the packed values to transform, mutated in place
/// * `coset` - index of the domain coset
/// * `log_batch_size` - log2 of the number of batched transforms
/// * `log_n` - log2 of the transform length
///
/// # Errors
/// Returns an error when the inputs fail
/// [`check_batch_transform_inputs_and_params`].
pub fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
    log_domain_size: usize,
    s_evals: &[impl TwiddleAccess<F>],
    data: &mut [P],
    coset: u32,
    log_batch_size: usize,
    log_n: usize,
) -> Result<(), Error> {
    match data.len() {
        // Empty input: nothing to transform.
        0 => return Ok(()),
        1 => {
            return match P::LOG_WIDTH {
                // A single scalar: nothing to pair, return as-is.
                0 => Ok(()),
                _ => {
                    // Same padding trick as forward_transform: pad the lone
                    // packed element, transform the pair, keep the first half.
                    let mut buffer = [data[0], P::zero()];

                    inverse_transform(
                        log_domain_size,
                        s_evals,
                        &mut buffer,
                        coset,
                        log_batch_size,
                        log_n,
                    )?;

                    data[0] = buffer[0];
                    Ok(())
                }
            };
        }
        _ => {}
    };

    let log_w = P::LOG_WIDTH;

    let log_b = log_batch_size;

    check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_batch_size, log_n)?;

    // Rounds with index >= cutoff span whole packed elements; rounds below
    // operate within a packed element (same split as forward_transform).
    let cutoff = log_w.saturating_sub(log_b);

    // Stage 1 (low rounds, ascending): intra-packed-element butterflies,
    // with the add order reversed relative to the forward pass.
    #[allow(clippy::needless_range_loop)]
    for i in 0..cmp::min(cutoff, log_n) {
        let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

        let log_block_len = i + log_b;
        let block_twiddle = calculate_twiddle::<P>(
            &s_evals[i].coset(log_domain_size - 1 - cutoff, 0),
            log_block_len,
        );

        for j in 0..data.len() / 2 {
            let twiddle = P::broadcast(coset_twiddle.get(j << (cutoff - i))) + block_twiddle;
            let (mut u, mut v) = data[j << 1].interleave(data[j << 1 | 1], log_block_len);
            v += u;
            u += v * twiddle;
            (data[j << 1], data[j << 1 | 1]) = u.interleave(v, log_block_len);
        }
    }

    // Stage 2 (high rounds, ascending): butterflies between whole packed
    // elements, again with the add order reversed.
    #[allow(clippy::needless_range_loop)]
    for i in cutoff..log_n {
        let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

        for j in 0..1 << (log_n - 1 - i) {
            let twiddle = P::broadcast(coset_twiddle.get(j));
            for k in 0..1 << (i + log_b - log_w) {
                // idx0/idx1 address the low/high halves of butterfly block j.
                let idx0 = j << (i + log_b - log_w + 1) | k;
                let idx1 = idx0 | 1 << (i + log_b - log_w);
                data[idx1] += data[idx0];
                data[idx0] += data[idx1] * twiddle;
            }
        }
    }

    Ok(())
}
287
288pub fn check_batch_transform_inputs_and_params<PB: PackedField>(
289 log_domain_size: usize,
290 data: &[PB],
291 coset: u32,
292 log_batch_size: usize,
293 log_n: usize,
294) -> Result<(), Error> {
295 if !data.len().is_power_of_two() {
296 return Err(Error::PowerOfTwoLengthRequired);
297 }
298 if !PB::WIDTH.is_power_of_two() {
299 return Err(Error::PackingWidthMustDivideDimension);
300 }
301
302 let full_sized_n = (data.len() * PB::WIDTH) >> log_batch_size;
303
304 if (1 << log_n != full_sized_n && data.len() > 2) || (1 << log_n > full_sized_n) {
306 return Err(Error::BatchTooLarge);
307 }
308
309 let coset_bits = 32 - coset.leading_zeros() as usize;
310
311 let log_required_domain_size =
317 (log_n + coset_bits).max((PB::LOG_WIDTH + 1).saturating_sub(log_batch_size));
318 if log_required_domain_size > log_domain_size {
319 return Err(Error::DomainTooSmall {
320 log_required_domain_size,
321 });
322 }
323
324 Ok(())
325}
326
327#[inline]
328fn calculate_twiddle<P>(s_evals: &impl TwiddleAccess<P::Scalar>, log_block_len: usize) -> P
329where
330 P: PackedField<Scalar: BinaryField>,
331{
332 let log_blocks_count = P::LOG_WIDTH - log_block_len - 1;
333
334 let mut twiddle = P::default();
335 for k in 0..1 << log_blocks_count {
336 let (subblock_twiddle_0, subblock_twiddle_1) = s_evals.get_pair(log_blocks_count, k);
337 let idx0 = k << (log_block_len + 1);
338 let idx1 = idx0 | 1 << log_block_len;
339
340 for l in 0..1 << log_block_len {
341 twiddle.set(idx0 | l, subblock_twiddle_0);
342 twiddle.set(idx1 | l, subblock_twiddle_1);
343 }
344 }
345 twiddle
346}
347
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{
        BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedFieldIndexable,
    };
    use binius_math::Error as MathError;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    // An 8-bit field cannot host a dimension-10 subspace; construction must
    // surface the math error.
    #[test]
    fn test_additive_ntt_fails_with_field_too_small() {
        assert_matches!(
            SingleThreadedNTT::<BinaryField8b>::new(10),
            Err(Error::MathError(MathError::DomainSizeTooLarge))
        );
    }

    // Round i's subspace has dimension log_domain_size - i.
    #[test]
    fn test_subspace_size_agrees_with_domain_size() {
        let ntt = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
        assert_eq!(ntt.subspace(0).dim(), 10);
        assert_eq!(ntt.subspace(9).dim(), 1);
    }

    // The single-packed-element padding path of forward_transform must agree
    // with transforming the same scalars as an unpacked slice.
    #[test]
    fn one_packed_field_forward() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

        let _ = s.forward_transform(&mut packed, 3, 0, 3);

        let _ = s.forward_transform(unpacked, 3, 0, 3);

        for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    // Same agreement check for the inverse transform's padding path.
    #[test]
    fn one_packed_field_inverse() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

        let _ = s.inverse_transform(&mut packed, 3, 0, 3);

        let _ = s.inverse_transform(unpacked, 3, 0, 3);

        for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    // A transform shorter than the packed width (log_n = 2 over a width-8
    // element) must match the same transform on the first 4 unpacked scalars.
    #[test]
    fn smaller_embedded_batch_forward() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

        let _ = forward_transform(s.log_domain_size(), &s.s_evals, &mut packed, 3, 0, 2);

        let _ = s.forward_transform(unpacked, 3, 0, 2);

        for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    // Same partially-filled check for the inverse transform.
    #[test]
    fn smaller_embedded_batch_inverse() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

        let _ = inverse_transform(s.log_domain_size(), &s.s_evals, &mut packed, 3, 0, 2);

        let _ = s.inverse_transform(unpacked, 3, 0, 2);

        for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

}