1use std::{cmp, marker::PhantomData};
4
5use binius_field::{BinaryField, PackedField, TowerField};
6use binius_math::BinarySubspace;
7use binius_utils::bail;
8
9use super::{
10 additive_ntt::{AdditiveNTT, NTTShape},
11 error::Error,
12 twiddle::TwiddleAccess,
13};
14use crate::twiddle::{OnTheFlyTwiddleAccess, PrecomputedTwiddleAccess, expand_subspace_evals};
15
/// Additive NTT implementation that performs transforms on a single thread.
///
/// `TA` selects the twiddle-access strategy; by default twiddles are computed
/// on the fly ([`OnTheFlyTwiddleAccess`]).
#[derive(Debug)]
pub struct SingleThreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F>> {
    // One twiddle accessor per NTT round; the number of accessors equals the
    // log2 of the domain size (see `log_domain_size`).
    pub(super) s_evals: Vec<TA>,
    _marker: PhantomData<F>,
}
23
24impl<F: BinaryField> SingleThreadedNTT<F> {
25 pub fn new(log_domain_size: usize) -> Result<Self, Error> {
28 let subspace = BinarySubspace::with_dim(log_domain_size)?;
29 Self::with_subspace(&subspace)
30 }
31
32 pub fn with_domain_field<FDomain>(log_domain_size: usize) -> Result<Self, Error>
35 where
36 FDomain: BinaryField,
37 F: From<FDomain>,
38 {
39 let subspace = BinarySubspace::<FDomain>::with_dim(log_domain_size)?.isomorphic();
40 Self::with_subspace(&subspace)
41 }
42
43 pub fn with_subspace(subspace: &BinarySubspace<F>) -> Result<Self, Error> {
44 let twiddle_access = OnTheFlyTwiddleAccess::generate(subspace)?;
45 Ok(Self::with_twiddle_access(twiddle_access))
46 }
47
48 pub fn precompute_twiddles(&self) -> SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>> {
49 SingleThreadedNTT::with_twiddle_access(expand_subspace_evals(&self.s_evals))
50 }
51}
52
impl<F: TowerField> SingleThreadedNTT<F> {
    /// Constructs an NTT whose domain is built in the tower's canonical field
    /// (`F::Canonical`) and embedded into `F`; delegates to `with_domain_field`.
    pub fn with_canonical_field(log_domain_size: usize) -> Result<Self, Error> {
        Self::with_domain_field::<F::Canonical>(log_domain_size)
    }
}
60
61impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
62 const fn with_twiddle_access(twiddle_access: Vec<TA>) -> Self {
63 Self {
64 s_evals: twiddle_access,
65 _marker: PhantomData,
66 }
67 }
68}
69
70impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
71 pub fn twiddles(&self) -> &[TA] {
72 &self.s_evals
73 }
74}
75
impl<F, TA> AdditiveNTT<F> for SingleThreadedNTT<F, TA>
where
    F: BinaryField,
    TA: TwiddleAccess<F>,
{
    fn log_domain_size(&self) -> usize {
        // One twiddle accessor is stored per round, so the accessor count is
        // exactly the log2 of the domain size.
        self.s_evals.len()
    }

    fn subspace(&self, i: usize) -> BinarySubspace<F> {
        // The round subspaces are linear (unshifted) by construction, so the
        // affine shift component must always be zero.
        let (subspace, shift) = self.s_evals[i].affine_subspace();
        debug_assert_eq!(shift, F::ZERO, "s_evals subspaces must be linear by construction");
        subspace
    }

    fn get_subspace_eval(&self, i: usize, j: usize) -> F {
        // j-th evaluation from the round-i twiddle accessor.
        self.s_evals[i].get(j)
    }

    // Delegates to the free-standing kernel so the same code path can also be
    // invoked directly with a borrowed twiddle table.
    fn forward_transform<P: PackedField<Scalar = F>>(
        &self,
        data: &mut [P],
        shape: NTTShape,
        coset: usize,
        coset_bits: usize,
        skip_rounds: usize,
    ) -> Result<(), Error> {
        forward_transform(
            self.log_domain_size(),
            &self.s_evals,
            data,
            shape,
            coset,
            coset_bits,
            skip_rounds,
        )
    }

    // Delegates to the free-standing inverse kernel.
    fn inverse_transform<P: PackedField<Scalar = F>>(
        &self,
        data: &mut [P],
        shape: NTTShape,
        coset: usize,
        coset_bits: usize,
        skip_rounds: usize,
    ) -> Result<(), Error> {
        inverse_transform(
            self.log_domain_size(),
            &self.s_evals,
            data,
            shape,
            coset,
            coset_bits,
            skip_rounds,
        )
    }
}
133
/// Forward additive NTT over a batch of packed field elements.
///
/// `shape` gives the batch layout as logs: `log_x` interleaved batches
/// (innermost index bits), `log_y` the dimension the transform runs over, and
/// `log_z` stacked batches (outermost index bits). `coset` (an index of
/// `coset_bits` bits) selects the evaluation coset, and the first
/// `skip_rounds` rounds of the transform are skipped.
///
/// # Errors
///
/// Returns an error if the inputs fail `check_batch_transform_inputs_and_params`.
pub fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
    log_domain_size: usize,
    s_evals: &[impl TwiddleAccess<F>],
    data: &mut [P],
    shape: NTTShape,
    coset: usize,
    coset_bits: usize,
    skip_rounds: usize,
) -> Result<(), Error> {
    check_batch_transform_inputs_and_params(
        log_domain_size,
        data,
        shape,
        coset,
        coset_bits,
        skip_rounds,
    )?;

    match data.len() {
        0 => return Ok(()),
        1 => {
            return match P::LOG_WIDTH {
                0 => Ok(()),
                _ => {
                    // The packed kernels below pair up two packed elements for
                    // interleaving, so a single-element input is padded with a
                    // zero scratch element; only buffer[0] is meaningful after
                    // the recursive call.
                    let mut buffer = [data[0], P::zero()];
                    forward_transform(
                        log_domain_size,
                        s_evals,
                        &mut buffer,
                        shape,
                        coset,
                        coset_bits,
                        skip_rounds,
                    )?;
                    data[0] = buffer[0];
                    Ok(())
                }
            };
        }
        _ => {}
    };

    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    let log_w = P::LOG_WIDTH;

    // Rounds with index below `cutoff` have butterfly blocks smaller than one
    // packed element; they are handled by the interleave-based loop below.
    let cutoff = log_w.saturating_sub(log_x);

    // Re-index so that `s_evals[i]` is the twiddle accessor for round `i` of a
    // domain of size 2^(log_y + coset_bits).
    let s_evals = &s_evals[log_domain_size - (log_y + coset_bits)..];

    // Rounds whose butterfly blocks span whole packed elements, iterated from
    // the largest block size downward, skipping the first `skip_rounds` rounds.
    for i in (cutoff..(log_y - skip_rounds)).rev() {
        let s_evals_i = &s_evals[i];
        let coset_offset = coset << (log_y - 1 - i);

        for j in 0..1 << log_z {
            for k in 0..1 << (log_y - 1 - i) {
                let twiddle = s_evals_i.get(coset_offset | k);
                for l in 0..1 << (i + log_x - log_w) {
                    // idx0/idx1 address the two halves of a butterfly block:
                    // z bits | y block bits | in-block offset.
                    let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
                    let idx1 = idx0 | 1 << (log_x + i - log_w);
                    // Additive-NTT butterfly (field addition in characteristic 2 is XOR).
                    data[idx0] += data[idx1] * twiddle;
                    data[idx1] += data[idx0];
                }
            }
        }
    }

    // Rounds whose butterfly blocks are smaller than one packed element: pairs
    // of packed elements are interleaved at block granularity so the butterfly
    // can be applied with per-lane twiddles.
    for i in (0..cmp::min(cutoff, log_y - skip_rounds)).rev() {
        let s_evals_i = &s_evals[i];
        let coset_offset = coset << (log_y - 1 - i);

        // Per-lane twiddle component that varies within a packed element.
        let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

        let log_block_len = i + log_x;
        let log_packed_count = (log_y - 1).saturating_sub(cutoff);
        for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
            for k in 0..1 << log_packed_count {
                // Broadcast twiddle (constant over the packed element) plus the
                // per-lane block twiddle.
                let twiddle =
                    P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
                let index = k << 1 | j << (log_packed_count + 1);
                let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
                u += v * twiddle;
                v += u;
                (data[index], data[index | 1]) = u.interleave(v, log_block_len);
            }
        }
    }

    Ok(())
}
247
/// Inverse additive NTT over a batch of packed field elements.
///
/// Mirrors `forward_transform`: same `shape`/`coset`/`skip_rounds` semantics,
/// but the rounds run in ascending order and each butterfly applies the
/// forward round's two updates in reverse, undoing it.
///
/// # Errors
///
/// Returns an error if the inputs fail `check_batch_transform_inputs_and_params`.
pub fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
    log_domain_size: usize,
    s_evals: &[impl TwiddleAccess<F>],
    data: &mut [P],
    shape: NTTShape,
    coset: usize,
    coset_bits: usize,
    skip_rounds: usize,
) -> Result<(), Error> {
    check_batch_transform_inputs_and_params(
        log_domain_size,
        data,
        shape,
        coset,
        coset_bits,
        skip_rounds,
    )?;

    match data.len() {
        0 => return Ok(()),
        1 => {
            return match P::LOG_WIDTH {
                0 => Ok(()),
                _ => {
                    // The packed kernels below pair up two packed elements for
                    // interleaving, so a single-element input is padded with a
                    // zero scratch element; only buffer[0] is meaningful after
                    // the recursive call.
                    let mut buffer = [data[0], P::zero()];
                    inverse_transform(
                        log_domain_size,
                        s_evals,
                        &mut buffer,
                        shape,
                        coset,
                        coset_bits,
                        skip_rounds,
                    )?;
                    data[0] = buffer[0];
                    Ok(())
                }
            };
        }
        _ => {}
    };

    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    let log_w = P::LOG_WIDTH;

    // Rounds with index below `cutoff` have butterfly blocks smaller than one
    // packed element and are handled by the interleave-based loop first.
    let cutoff = log_w.saturating_sub(log_x);

    // Re-index so that `s_evals[i]` is the twiddle accessor for round `i` of a
    // domain of size 2^(log_y + coset_bits).
    let s_evals = &s_evals[log_domain_size - (log_y + coset_bits)..];

    // Sub-packed-element rounds first (smallest blocks), the reverse of the
    // forward transform's round order.
    #[allow(clippy::needless_range_loop)]
    for i in 0..cutoff.min(log_y - skip_rounds) {
        let s_evals_i = &s_evals[i];
        let coset_offset = coset << (log_y - 1 - i);

        // Per-lane twiddle component that varies within a packed element.
        let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

        let log_block_len = i + log_x;
        let log_packed_count = (log_y - 1).saturating_sub(cutoff);
        for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
            for k in 0..1 << log_packed_count {
                let twiddle =
                    P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
                let index = k << 1 | j << (log_packed_count + 1);
                let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
                // Inverse butterfly: the forward round's updates in reverse order.
                v += u;
                u += v * twiddle;
                (data[index], data[index | 1]) = u.interleave(v, log_block_len);
            }
        }
    }

    // Rounds whose butterfly blocks span whole packed elements, ascending.
    #[allow(clippy::needless_range_loop)]
    for i in cutoff..(log_y - skip_rounds) {
        let s_evals_i = &s_evals[i];
        let coset_offset = coset << (log_y - 1 - i);

        for j in 0..1 << log_z {
            for k in 0..1 << (log_y - 1 - i) {
                let twiddle = s_evals_i.get(coset_offset | k);
                for l in 0..1 << (i + log_x - log_w) {
                    // idx0/idx1 address the two halves of a butterfly block:
                    // z bits | y block bits | in-block offset.
                    let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
                    let idx1 = idx0 | 1 << (log_x + i - log_w);
                    // Inverse butterfly: the forward round's updates in reverse order.
                    data[idx1] += data[idx0];
                    data[idx0] += data[idx1] * twiddle;
                }
            }
        }
    }

    Ok(())
}
363
364pub fn check_batch_transform_inputs_and_params<PB: PackedField>(
365 log_domain_size: usize,
366 data: &[PB],
367 shape: NTTShape,
368 coset: usize,
369 coset_bits: usize,
370 skip_rounds: usize,
371) -> Result<(), Error> {
372 let NTTShape {
373 log_x,
374 log_y,
375 log_z,
376 } = shape;
377
378 if !data.len().is_power_of_two() {
379 bail!(Error::PowerOfTwoLengthRequired);
380 }
381 if skip_rounds > log_y {
382 bail!(Error::SkipRoundsTooLarge);
383 }
384
385 let full_sized_y = (data.len() * PB::WIDTH) >> (log_x + log_z);
386
387 if (1 << log_y != full_sized_y && data.len() > 2) || (1 << log_y > full_sized_y) {
390 bail!(Error::BatchTooLarge);
391 }
392
393 if coset >= (1 << coset_bits) {
394 bail!(Error::CosetIndexOutOfBounds { coset, coset_bits });
395 }
396
397 let log_required_domain_size = log_y + coset_bits;
399 if log_required_domain_size > log_domain_size {
400 bail!(Error::DomainTooSmall {
401 log_required_domain_size
402 });
403 }
404
405 Ok(())
406}
407
/// Builds a packed element whose scalar lanes hold the per-lane twiddle values
/// for a butterfly round whose block length is smaller than the packed width.
///
/// `ntt_round` is the round index `i`; all lanes within one block of length
/// `2^(ntt_round + log_x)` receive the same twiddle value.
#[inline]
fn calculate_packed_additive_twiddle<P>(
    s_evals: &impl TwiddleAccess<P::Scalar>,
    shape: NTTShape,
    ntt_round: usize,
) -> P
where
    P: PackedField<Scalar: BinaryField>,
{
    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;
    debug_assert!(log_y > 0);

    // Only valid for rounds whose butterfly block fits inside one packed element.
    let log_block_len = ntt_round + log_x;
    debug_assert!(log_block_len < P::LOG_WIDTH);

    // Number of (x, y, z) index bits that fall within a single packed element.
    let packed_log_len = (log_x + log_y + log_z).min(P::LOG_WIDTH);
    let log_blocks_count = packed_log_len.saturating_sub(log_block_len + 1);

    // Split the in-packed-element bits back into their z and y components.
    let packed_log_z = packed_log_len.saturating_sub(log_x + log_y);
    let packed_log_y = packed_log_len - packed_log_z - log_x;

    // Number of distinct twiddle-index bits that vary within one packed element.
    let twiddle_stride = P::LOG_WIDTH
        .saturating_sub(log_x)
        .min(log_blocks_count - packed_log_z);

    let mut twiddle = P::default();
    for i in 0..1 << (log_blocks_count - twiddle_stride) {
        for j in 0..1 << twiddle_stride {
            // When the whole y dimension fits inside the packed element, both
            // sub-blocks of the pair share one twiddle; otherwise fetch the two
            // twiddles of the pair distinguished by bit `twiddle_stride`.
            // NOTE(review): relies on `TwiddleAccess::get_pair(twiddle_stride, j)`
            // returning exactly that pair — confirm against the trait's docs.
            let (subblock_twiddle_0, subblock_twiddle_1) = if packed_log_y == log_y {
                let same_twiddle = s_evals.get(j);
                (same_twiddle, same_twiddle)
            } else {
                s_evals.get_pair(twiddle_stride, j)
            };
            // idx0/idx1 locate the two halves of the butterfly pair within the
            // packed element.
            let idx0 = j << (log_block_len + 1) | i << (log_block_len + twiddle_stride + 1);
            let idx1 = idx0 | 1 << log_block_len;

            // Broadcast each twiddle across its whole block of scalar lanes.
            for k in 0..1 << log_block_len {
                twiddle.set(idx0 | k, subblock_twiddle_0);
                twiddle.set(idx1 | k, subblock_twiddle_1);
            }
        }
    }
    twiddle
}
457
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use assert_matches::assert_matches;
    use binius_field::{
        BinaryField8b, BinaryField16b, Field, PackedBinaryField8x16b, PackedFieldIndexable,
    };
    use binius_math::Error as MathError;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    #[test]
    fn test_additive_ntt_fails_with_field_too_small() {
        // BinaryField8b has 2^8 elements — too few for a 2^10 domain.
        assert_matches!(
            SingleThreadedNTT::<BinaryField8b>::new(10),
            Err(Error::MathError(MathError::DomainSizeTooLarge))
        );
    }

    #[test]
    fn test_subspace_size_agrees_with_domain_size() {
        let ntt = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size 10 fits in BinaryField16b");
        assert_eq!(ntt.subspace(0).dim(), 10);
        assert_eq!(ntt.subspace(9).dim(), 1);
    }

    // A message zero-padded into every 4th position of a 4x larger domain must
    // transform to each smaller-transform output repeated 4 times.
    #[test]
    fn test_repetition_property() {
        let log_len = 8;
        let ntt = SingleThreadedNTT::<BinaryField16b>::new(log_len + 2).unwrap();

        let mut rng = StdRng::seed_from_u64(0);
        let msg = repeat_with(|| <BinaryField16b as Field>::random(&mut rng))
            .take(1 << log_len)
            .collect::<Vec<_>>();

        let mut msg_padded = vec![BinaryField16b::ZERO; 1 << (log_len + 2)];
        for i in 0..1 << log_len {
            msg_padded[i << 2] = msg[i];
        }

        let mut out = msg;
        ntt.forward_transform(
            &mut out,
            NTTShape {
                log_y: log_len,
                ..Default::default()
            },
            0,
            0,
            0,
        )
        .unwrap();
        let mut out_rep = msg_padded;
        ntt.forward_transform(
            &mut out_rep,
            NTTShape {
                log_y: log_len + 2,
                ..Default::default()
            },
            0,
            0,
            0,
        )
        .unwrap();
        for i in 0..1 << (log_len + 2) {
            assert_eq!(out_rep[i], out[i >> 2]);
        }
    }

    // Packed and scalar code paths must agree on a single packed element.
    // NOTE: the transform Results are now unwrapped — previously they were
    // discarded with `let _ =`, so a failing transform would leave both buffers
    // untouched and the equality asserts would pass vacuously.
    #[test]
    fn one_packed_field_forward() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size 10 fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

        let shape = NTTShape {
            log_x: 0,
            log_y: 3,
            log_z: 0,
        };
        s.forward_transform(&mut packed, shape, 3, 2, 0)
            .expect("valid transform parameters");
        s.forward_transform(unpacked, shape, 3, 2, 0)
            .expect("valid transform parameters");

        for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    // Inverse-transform counterpart of `one_packed_field_forward`.
    #[test]
    fn one_packed_field_inverse() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size 10 fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

        let shape = NTTShape {
            log_x: 0,
            log_y: 3,
            log_z: 0,
        };
        s.inverse_transform(&mut packed, shape, 3, 2, 0)
            .expect("valid transform parameters");
        s.inverse_transform(unpacked, shape, 3, 2, 0)
            .expect("valid transform parameters");

        for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    // A batch smaller than one packed element (log_y = 2 in an 8-wide packed
    // field) must agree with the scalar path on the meaningful prefix.
    #[test]
    fn smaller_embedded_batch_forward() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size 10 fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

        let shape = NTTShape {
            log_x: 0,
            log_y: 2,
            log_z: 0,
        };
        forward_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 2, 0)
            .expect("valid transform parameters");
        s.forward_transform(unpacked, shape, 3, 2, 0)
            .expect("valid transform parameters");

        for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }

    // Inverse-transform counterpart of `smaller_embedded_batch_forward`.
    #[test]
    fn smaller_embedded_batch_inverse() {
        let s = SingleThreadedNTT::<BinaryField16b>::new(10)
            .expect("log domain size 10 fits in BinaryField16b");
        let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

        let mut packed_copy = packed;

        let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

        let shape = NTTShape {
            log_x: 0,
            log_y: 2,
            log_z: 0,
        };
        inverse_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 2, 0)
            .expect("valid transform parameters");
        s.inverse_transform(unpacked, shape, 3, 2, 0)
            .expect("valid transform parameters");

        for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
            assert_eq!(packed[0].get(i), *unpacked_item);
        }
    }
}