use std::{cmp, marker::PhantomData};

use binius_field::{BinaryField, PackedField, TowerField};
use binius_math::BinarySubspace;

use super::{
    additive_ntt::{AdditiveNTT, NTTShape},
    error::Error,
    twiddle::TwiddleAccess,
};
use crate::twiddle::{expand_subspace_evals, OnTheFlyTwiddleAccess, PrecomputedTwiddleAccess};

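/// Implementation of [`AdditiveNTT`] that runs the transforms single-threaded, using a table of
/// per-layer twiddle factors (`s_evals`) generated from a binary subspace.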
#[derive(Debug)]
pub struct SingleThreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F>> {
    pub(super) s_evals: Vec<TA>,
    _marker: PhantomData<F>,
}

impl<F: BinaryField> SingleThreadedNTT<F> {
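    /// Constructs an NTT over the canonical subspace of the field, with twiddle factors
    /// computed on the fly.
    ///
    /// A minimal round-trip sketch (marked `ignore` because the exact public import paths are
    /// an assumption, not part of this module):
    ///
    /// ```ignore
    /// use binius_field::{BinaryField16b, PackedBinaryField8x16b, PackedField};
    /// use rand::{rngs::StdRng, SeedableRng};
    ///
    /// let ntt = SingleThreadedNTT::<BinaryField16b>::new(10)?;
    /// let mut data = [PackedBinaryField8x16b::random(StdRng::seed_from_u64(0))];
    /// let original = data;
    /// let shape = NTTShape { log_x: 0, log_y: 3, log_z: 0 };
    /// ntt.forward_transform(&mut data, shape, 0, 0)?;
    /// ntt.inverse_transform(&mut data, shape, 0, 0)?;
    /// assert_eq!(data, original); // forward and inverse transforms are mutually inverse
    /// ```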
    pub fn new(log_domain_size: usize) -> Result<Self, Error> {
        let subspace = BinarySubspace::with_dim(log_domain_size)?;
        Self::with_subspace(&subspace)
    }

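    /// Constructs an NTT over a subspace of `F` that is isomorphic to the canonical subspace of
    /// the domain field `FDomain`, with twiddle factors computed on the fly.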
    pub fn with_domain_field<FDomain>(log_domain_size: usize) -> Result<Self, Error>
    where
        FDomain: BinaryField,
        F: From<FDomain>,
    {
        let subspace = BinarySubspace::<FDomain>::with_dim(log_domain_size)?.isomorphic();
        Self::with_subspace(&subspace)
    }

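    /// Constructs an NTT over the given subspace, with twiddle factors computed on the fly.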
    pub fn with_subspace(subspace: &BinarySubspace<F>) -> Result<Self, Error> {
        let twiddle_access = OnTheFlyTwiddleAccess::generate(subspace)?;
        Ok(Self::with_twiddle_access(twiddle_access))
    }

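    /// Returns a new NTT instance over the same domain that uses precomputed twiddle factors,
    /// trading higher memory use for cheaper twiddle lookups during transforms.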
    pub fn precompute_twiddles(&self) -> SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>> {
        SingleThreadedNTT::with_twiddle_access(expand_subspace_evals(&self.s_evals))
    }
}

impl<F: TowerField> SingleThreadedNTT<F> {
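    /// Constructs an NTT over a subspace isomorphic to the canonical subspace of the field's
    /// canonical tower representation (`F::Canonical`).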
    pub fn with_canonical_field(log_domain_size: usize) -> Result<Self, Error> {
        Self::with_domain_field::<F::Canonical>(log_domain_size)
    }
}

impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
    const fn with_twiddle_access(twiddle_access: Vec<TA>) -> Self {
        Self {
            s_evals: twiddle_access,
            _marker: PhantomData,
        }
    }

    /// Returns the per-layer twiddle access tables.
    pub fn twiddles(&self) -> &[TA] {
        &self.s_evals
    }
}

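// The trait implementation delegates to the standalone `forward_transform` and
// `inverse_transform` functions defined below.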
impl<F, TA> AdditiveNTT<F> for SingleThreadedNTT<F, TA>
where
    F: BinaryField,
    TA: TwiddleAccess<F>,
{
    fn log_domain_size(&self) -> usize {
        self.s_evals.len()
    }

    fn subspace(&self, i: usize) -> BinarySubspace<F> {
        let (subspace, shift) = self.s_evals[i].affine_subspace();
        debug_assert_eq!(shift, F::ZERO, "s_evals subspaces must be linear by construction");
        subspace
    }

    fn get_subspace_eval(&self, i: usize, j: usize) -> F {
        self.s_evals[i].get(j)
    }

    fn forward_transform<P: PackedField<Scalar = F>>(
        &self,
        data: &mut [P],
        shape: NTTShape,
        coset: u32,
        skip_rounds: usize,
    ) -> Result<(), Error> {
        forward_transform(self.log_domain_size(), &self.s_evals, data, shape, coset, skip_rounds)
    }

    fn inverse_transform<P: PackedField<Scalar = F>>(
        &self,
        data: &mut [P],
        shape: NTTShape,
        coset: u32,
        skip_rounds: usize,
    ) -> Result<(), Error> {
        inverse_transform(self.log_domain_size(), &self.s_evals, data, shape, coset, skip_rounds)
    }
}

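/// Forward additive NTT over a batch of packed field elements.
///
/// `shape` describes the batch layout by log-dimensions: `log_x` interleaved elements within
/// each transform, `log_y` the transform size, and `log_z` independent outer batches. `coset`
/// selects the evaluation coset of the domain, and the first `skip_rounds` rounds (the
/// largest-stride layers) are skipped.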
pub fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
    log_domain_size: usize,
    s_evals: &[impl TwiddleAccess<F>],
    data: &mut [P],
    shape: NTTShape,
    coset: u32,
    skip_rounds: usize,
) -> Result<(), Error> {
    check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

    match data.len() {
        0 => return Ok(()),
        1 => {
            return match P::LOG_WIDTH {
                0 => Ok(()),
                _ => {
                    // When there is only one packed element, the packed butterfly rounds below
                    // cannot pair it with a second element. Copy it into a two-element buffer
                    // padded with zeros, transform that, and copy the result back.
                    let mut buffer = [data[0], P::zero()];
                    forward_transform(
                        log_domain_size,
                        s_evals,
                        &mut buffer,
                        shape,
                        coset,
                        skip_rounds,
                    )?;
                    data[0] = buffer[0];
                    Ok(())
                }
            };
        }
        _ => {}
    };

    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    let log_w = P::LOG_WIDTH;

    // Cutoff is the round at which each butterfly half spans whole packed elements and can be
    // indexed directly; rounds below the cutoff act within packed elements via interleaving.
    let cutoff = log_w.saturating_sub(log_x);

    for i in (cutoff..(log_y - skip_rounds)).rev() {
        let s_evals_i = &s_evals[i];
        let coset_offset = (coset as usize) << (log_y - 1 - i);

        // j indexes the outer batch dimension.
        for j in 0..1 << log_z {
            // k indexes the butterfly block within the layer; all butterflies in a block share
            // one twiddle factor.
            for k in 0..1 << (log_y - 1 - i) {
                let twiddle = s_evals_i.get(coset_offset | k);
                // l indexes packed elements within one half of the butterfly block.
                for l in 0..1 << (i + log_x - log_w) {
                    let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
                    let idx1 = idx0 | 1 << (log_x + i - log_w);
                    data[idx0] += data[idx1] * twiddle;
                    data[idx1] += data[idx0];
                }
            }
        }
    }

    for i in (0..cmp::min(cutoff, log_y - skip_rounds)).rev() {
        let s_evals_i = &s_evals[i];
        let coset_offset = (coset as usize) << (log_y - 1 - i);

        // In these rounds the butterfly blocks are smaller than a packed element, so several
        // blocks with distinct twiddles live in one packed value. `block_twiddle` spreads the
        // per-block twiddles across the packed lanes.
        let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

        let log_block_len = i + log_x;
        let log_packed_count = (log_y - 1).saturating_sub(cutoff);
        for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
            for k in 0..1 << log_packed_count {
                let twiddle =
                    P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
                let index = k << 1 | j << (log_packed_count + 1);
                let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
                u += v * twiddle;
                v += u;
                (data[index], data[index | 1]) = u.interleave(v, log_block_len);
            }
        }
    }

    Ok(())
}

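/// Inverse additive NTT over a batch of packed field elements; inverts [`forward_transform`]
/// for the same `shape`, `coset`, and `skip_rounds` arguments.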
pub fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
    log_domain_size: usize,
    s_evals: &[impl TwiddleAccess<F>],
    data: &mut [P],
    shape: NTTShape,
    coset: u32,
    skip_rounds: usize,
) -> Result<(), Error> {
    check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

    match data.len() {
        0 => return Ok(()),
        1 => {
            return match P::LOG_WIDTH {
                0 => Ok(()),
                _ => {
                    // Single packed element: pad with a zero element so the packed butterfly
                    // rounds below have a pair to work on, then copy the result back.
                    let mut buffer = [data[0], P::zero()];
                    inverse_transform(
                        log_domain_size,
                        s_evals,
                        &mut buffer,
                        shape,
                        coset,
                        skip_rounds,
                    )?;
                    data[0] = buffer[0];
                    Ok(())
                }
            };
        }
        _ => {}
    };

    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    let log_w = P::LOG_WIDTH;

    // Same cutoff as in the forward transform: rounds below it act within packed elements.
    // The inverse runs the rounds in the opposite order, sub-packed rounds first.
    let cutoff = log_w.saturating_sub(log_x);

    #[allow(clippy::needless_range_loop)]
    for i in 0..cutoff.min(log_y - skip_rounds) {
        let s_evals_i = &s_evals[i];
        let coset_offset = (coset as usize) << (log_y - 1 - i);

        // As in the forward transform, spread the per-block twiddles across packed lanes.
        let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

        let log_block_len = i + log_x;
        let log_packed_count = (log_y - 1).saturating_sub(cutoff);
        for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
            for k in 0..1 << log_packed_count {
                let twiddle =
                    P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
                let index = k << 1 | j << (log_packed_count + 1);
                let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
                // Inverse butterfly: the additions run in the reverse order of the forward pass.
                v += u;
                u += v * twiddle;
                (data[index], data[index | 1]) = u.interleave(v, log_block_len);
            }
        }
    }

    #[allow(clippy::needless_range_loop)]
    for i in cutoff..(log_y - skip_rounds) {
        let s_evals_i = &s_evals[i];
        let coset_offset = (coset as usize) << (log_y - 1 - i);

        // j indexes the outer batch dimension.
        for j in 0..1 << log_z {
            // k indexes the butterfly block within the layer; all butterflies in a block share
            // one twiddle factor.
            for k in 0..1 << (log_y - 1 - i) {
                let twiddle = s_evals_i.get(coset_offset | k);
                for l in 0..1 << (i + log_x - log_w) {
                    let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
                    let idx1 = idx0 | 1 << (log_x + i - log_w);
                    data[idx1] += data[idx0];
                    data[idx0] += data[idx1] * twiddle;
                }
            }
        }
    }

    Ok(())
}

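/// Validates the data length, shape, coset, and skip-round arguments shared by the forward and
/// inverse batch transforms.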
pub fn check_batch_transform_inputs_and_params<PB: PackedField>(
    log_domain_size: usize,
    data: &[PB],
    shape: NTTShape,
    coset: u32,
    skip_rounds: usize,
) -> Result<(), Error> {
    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    if !data.len().is_power_of_two() {
        return Err(Error::PowerOfTwoLengthRequired);
    }
    if skip_rounds > log_y {
        return Err(Error::SkipRoundsTooLarge);
    }

    // Number of y-dimension scalars the buffer can hold, given the x and z extents.
    let full_sized_y = (data.len() * PB::WIDTH) >> (log_x + log_z);

    // The data must be exactly the size of the shape, except that a buffer of at most two
    // packed elements may have spare capacity beyond the shape.
    if (1 << log_y != full_sized_y && data.len() > 2) || (1 << log_y > full_sized_y) {
        return Err(Error::BatchTooLarge);
    }

    // The domain's log-size must cover the transform size plus the significant bits of `coset`.
    let coset_bits = 32 - coset.leading_zeros() as usize;
    let log_required_domain_size = log_y + coset_bits;
    if log_required_domain_size > log_domain_size {
        return Err(Error::DomainTooSmall {
            log_required_domain_size,
        });
    }

    Ok(())
}

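/// Builds a packed element whose lanes hold the per-block twiddle offsets for round
/// `ntt_round`, in the lane order produced by the `interleave`-based butterflies above.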
#[inline]
fn calculate_packed_additive_twiddle<P>(
    s_evals: &impl TwiddleAccess<P::Scalar>,
    shape: NTTShape,
    ntt_round: usize,
) -> P
where
    P: PackedField<Scalar: BinaryField>,
{
    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;
    debug_assert!(log_y > 0);

    let log_block_len = ntt_round + log_x;
    debug_assert!(log_block_len < P::LOG_WIDTH);

    let packed_log_len = (log_x + log_y + log_z).min(P::LOG_WIDTH);
    let log_blocks_count = packed_log_len.saturating_sub(log_block_len + 1);

    let packed_log_z = packed_log_len.saturating_sub(log_x + log_y);
    let packed_log_y = packed_log_len - packed_log_z - log_x;

    let twiddle_stride = P::LOG_WIDTH
        .saturating_sub(log_x)
        .min(log_blocks_count - packed_log_z);

    let mut twiddle = P::default();
    for i in 0..1 << (log_blocks_count - twiddle_stride) {
        for j in 0..1 << twiddle_stride {
            // When the packed element covers the full y extent, both sub-blocks share one
            // twiddle; otherwise fetch the pair of twiddles differing in bit `twiddle_stride`.
            let (subblock_twiddle_0, subblock_twiddle_1) = if packed_log_y == log_y {
                let same_twiddle = s_evals.get(j);
                (same_twiddle, same_twiddle)
            } else {
                s_evals.get_pair(twiddle_stride, j)
            };
            let idx0 = j << (log_block_len + 1) | i << (log_block_len + twiddle_stride + 1);
            let idx1 = idx0 | 1 << log_block_len;

            for k in 0..1 << log_block_len {
                twiddle.set(idx0 | k, subblock_twiddle_0);
                twiddle.set(idx1 | k, subblock_twiddle_1);
            }
        }
    }
    twiddle
}

#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{
        BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedFieldIndexable,
    };
    use binius_math::Error as MathError;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    #[test]
    fn test_additive_ntt_fails_with_field_too_small() {
        assert_matches!(
            SingleThreadedNTT::<BinaryField8b>::new(10),
            Err(Error::MathError(MathError::DomainSizeTooLarge))
        );
    }

    #[test]
    fn test_subspace_size_agrees_with_domain_size() {
        let ntt = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size fits in BinaryField16b");
        assert_eq!(ntt.subspace(0).dim(), 10);
        assert_eq!(ntt.subspace(9).dim(), 1);
    }

    #[test]
    fn one_packed_field_forward() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

        let shape = NTTShape {
            log_x: 0,
            log_y: 3,
            log_z: 0,
        };
        // Transforming the packed element and its unpacked scalar view must agree.
        s.forward_transform(&mut packed, shape, 3, 0).unwrap();
        s.forward_transform(unpacked, shape, 3, 0).unwrap();

        for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    #[test]
    fn one_packed_field_inverse() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

        let shape = NTTShape {
            log_x: 0,
            log_y: 3,
            log_z: 0,
        };
        s.inverse_transform(&mut packed, shape, 3, 0).unwrap();
        s.inverse_transform(unpacked, shape, 3, 0).unwrap();

        for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    #[test]
    fn smaller_embedded_batch_forward() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

        let shape = NTTShape {
            log_x: 0,
            log_y: 2,
            log_z: 0,
        };
        // A 4-element batch embedded in a single 8-wide packed element: the packed transform
        // must agree with the transform of the unpacked prefix.
        forward_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 0).unwrap();
        s.forward_transform(unpacked, shape, 3, 0).unwrap();

        for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    #[test]
    fn smaller_embedded_batch_inverse() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

        let shape = NTTShape {
            log_x: 0,
            log_y: 2,
            log_z: 0,
        };
        inverse_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 0).unwrap();
        s.inverse_transform(unpacked, shape, 3, 0).unwrap();

        for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }
}