binius_math/ntt/
reference.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::PackedField;
4
5use crate::{
6	field_buffer::FieldSliceMut,
7	ntt::{AdditiveNTT, DomainContext},
8};
9
/// Straightforward, unoptimized implementation of [`AdditiveNTT`], meant as a reference.
///
/// Butterflies are applied one scalar pair at a time. In the forward direction, layer `0`
/// pairs up elements that are half the buffer apart, and every following layer halves that
/// distance, so adjacent elements ("neighbors") are combined in the last layer.
///
/// This is slow. Do not use in production.
pub struct NeighborsLastReference<DC> {
	/// Supplies the per-layer, per-block twiddle factors and bounds how many layers can be
	/// computed.
	pub domain_context: DC,
}
16
17impl<DC: DomainContext> AdditiveNTT for NeighborsLastReference<DC> {
18	type Field = DC::Field;
19
20	fn forward_transform<P: PackedField<Scalar = Self::Field>>(
21		&self,
22		mut data: FieldSliceMut<P>,
23		skip_early: usize,
24		skip_late: usize,
25	) {
26		let log_d = data.log_len();
27		input_check(&self.domain_context, log_d, skip_early, skip_late);
28
29		for layer in skip_early..(log_d - skip_late) {
30			let num_blocks = 1 << layer;
31			let block_size_half = 1 << (log_d - layer - 1);
32			for block in 0..num_blocks {
33				let twiddle = self.domain_context.twiddle(layer, block);
34				let block_start = block << (log_d - layer);
35				for idx0 in block_start..(block_start + block_size_half) {
36					let idx1 = block_size_half | idx0;
37					// perform butterfly
38					let mut u = data.get_checked(idx0).unwrap();
39					let mut v = data.get_checked(idx1).unwrap();
40					u += v * twiddle;
41					v += u;
42					data.set_checked(idx0, u).unwrap();
43					data.set_checked(idx1, v).unwrap();
44				}
45			}
46		}
47	}
48
49	fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
50		&self,
51		mut data: FieldSliceMut<P>,
52		skip_early: usize,
53		skip_late: usize,
54	) {
55		let log_d = data.log_len();
56		input_check(&self.domain_context, log_d, skip_early, skip_late);
57
58		for layer in (skip_early..(log_d - skip_late)).rev() {
59			let num_blocks = 1 << layer;
60			let block_size_half = 1 << (log_d - layer - 1);
61			for block in 0..num_blocks {
62				let twiddle = self.domain_context.twiddle(layer, block);
63				let block_start = block << (log_d - layer);
64				for idx0 in block_start..(block_start + block_size_half) {
65					let idx1 = block_size_half | idx0;
66					// perform butterfly
67					let mut u = data.get_checked(idx0).unwrap();
68					let mut v = data.get_checked(idx1).unwrap();
69					v += u;
70					u += v * twiddle;
71					data.set_checked(idx0, u).unwrap();
72					data.set_checked(idx1, v).unwrap();
73				}
74			}
75		}
76	}
77
78	fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
79		&self.domain_context
80	}
81}
82
83/// Checks for the preconditions of the [`AdditiveNTT`] transforms.
84///
85/// ## Preconditions
86///
87/// - `skip_early + skip_late <= log_d`
88/// - `log_d - skip_late <= domain_context.log_domain_size()`
89pub fn input_check(
90	domain_context: &impl DomainContext,
91	log_d: usize,
92	skip_early: usize,
93	skip_late: usize,
94) {
95	// we can't "double-skip" layers
96	assert!(skip_early + skip_late <= log_d);
97
98	// we need enough twiddles in `domain_context`
99	assert!(log_d - skip_late <= domain_context.log_domain_size());
100}