1use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_utils::rayon::get_log_max_threads;
6
7use super::{
8 additive_ntt::{AdditiveNTT, NTTShape},
9 error::Error,
10 multithreaded::MultithreadedNTT,
11 single_threaded::SingleThreadedNTT,
12 twiddle::PrecomputedTwiddleAccess,
13};
14
/// Threading configuration for an NTT instance.
///
/// The derived [`Default`] is [`ThreadingSettings::SingleThreaded`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadingSettings {
    /// Run on a single thread; this is the default.
    #[default]
    SingleThreaded,
    /// Run multithreaded, taking the thread budget from `get_log_max_threads()`.
    MultithreadedDefault,
    /// Run with a caller-chosen thread budget.
    ExplicitThreadsCount {
        /// Thread budget in logarithmic form; `0` resolves to the single-threaded path.
        // NOTE(review): presumably a base-2 logarithm (`1 << log_threads` threads) - confirm.
        log_threads: usize,
    },
}
26
27impl ThreadingSettings {
28 pub fn log_threads_count(&self) -> usize {
30 match self {
31 Self::SingleThreaded => 0,
32 Self::MultithreadedDefault => get_log_max_threads(),
33 Self::ExplicitThreadsCount { log_threads } => *log_threads,
34 }
35 }
36
37 pub const fn is_multithreaded(&self) -> bool {
39 match self {
40 Self::SingleThreaded => false,
41 Self::MultithreadedDefault => true,
42 Self::ExplicitThreadsCount { log_threads } => *log_threads > 0,
43 }
44 }
45}
46
47#[derive(Default)]
48pub struct NTTOptions {
49 pub precompute_twiddles: bool,
50 pub thread_settings: ThreadingSettings,
51}
52
53#[derive(Debug)]
56pub enum DynamicDispatchNTT<F: BinaryField> {
57 SingleThreaded(SingleThreadedNTT<F>),
58 SingleThreadedPrecompute(SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>>),
59 MultiThreaded(MultithreadedNTT<F>),
60 MultiThreadedPrecompute(MultithreadedNTT<F, PrecomputedTwiddleAccess<F>>),
61}
62
63impl<F: BinaryField> DynamicDispatchNTT<F> {
64 pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result<Self, Error> {
66 let log_threads = options.thread_settings.log_threads_count();
67 let result = match (options.precompute_twiddles, log_threads) {
68 (false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?),
69 (true, 0) => Self::SingleThreadedPrecompute(
70 SingleThreadedNTT::new(log_domain_size)?.precompute_twiddles(),
71 ),
72 (false, _) => Self::MultiThreaded(
73 SingleThreadedNTT::new(log_domain_size)?
74 .multithreaded_with_max_threads(log_threads),
75 ),
76 (true, _) => Self::MultiThreadedPrecompute(
77 SingleThreadedNTT::new(log_domain_size)?
78 .precompute_twiddles()
79 .multithreaded_with_max_threads(log_threads),
80 ),
81 };
82
83 Ok(result)
84 }
85}
86
87impl<F: BinaryField> AdditiveNTT<F> for DynamicDispatchNTT<F> {
88 fn log_domain_size(&self) -> usize {
89 match self {
90 Self::SingleThreaded(ntt) => ntt.log_domain_size(),
91 Self::SingleThreadedPrecompute(ntt) => ntt.log_domain_size(),
92 Self::MultiThreaded(ntt) => ntt.log_domain_size(),
93 Self::MultiThreadedPrecompute(ntt) => ntt.log_domain_size(),
94 }
95 }
96
97 fn subspace(&self, i: usize) -> BinarySubspace<F> {
98 match self {
99 Self::SingleThreaded(ntt) => ntt.subspace(i),
100 Self::SingleThreadedPrecompute(ntt) => ntt.subspace(i),
101 Self::MultiThreaded(ntt) => ntt.subspace(i),
102 Self::MultiThreadedPrecompute(ntt) => ntt.subspace(i),
103 }
104 }
105
106 fn get_subspace_eval(&self, i: usize, j: usize) -> F {
107 match self {
108 Self::SingleThreaded(ntt) => ntt.get_subspace_eval(i, j),
109 Self::SingleThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
110 Self::MultiThreaded(ntt) => ntt.get_subspace_eval(i, j),
111 Self::MultiThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
112 }
113 }
114
115 fn forward_transform<P: PackedField<Scalar = F>>(
116 &self,
117 data: &mut [P],
118 shape: NTTShape,
119 coset: usize,
120 coset_bits: usize,
121 skip_rounds: usize,
122 ) -> Result<(), Error> {
123 match self {
124 Self::SingleThreaded(ntt) => {
125 ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
126 }
127 Self::SingleThreadedPrecompute(ntt) => {
128 ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
129 }
130 Self::MultiThreaded(ntt) => {
131 ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
132 }
133 Self::MultiThreadedPrecompute(ntt) => {
134 ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
135 }
136 }
137 }
138
139 fn inverse_transform<P: PackedField<Scalar = F>>(
140 &self,
141 data: &mut [P],
142 shape: NTTShape,
143 coset: usize,
144 coset_bits: usize,
145 skip_rounds: usize,
146 ) -> Result<(), Error> {
147 match self {
148 Self::SingleThreaded(ntt) => {
149 ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
150 }
151 Self::SingleThreadedPrecompute(ntt) => {
152 ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
153 }
154 Self::MultiThreaded(ntt) => {
155 ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
156 }
157 Self::MultiThreadedPrecompute(ntt) => {
158 ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use binius_field::BinaryField8b;

    use super::*;

    /// Builds a dispatching NTT over `BinaryField8b` with a log domain size of 6.
    fn make_ntt(
        precompute_twiddles: bool,
        thread_settings: ThreadingSettings,
    ) -> DynamicDispatchNTT<BinaryField8b> {
        let options = NTTOptions {
            precompute_twiddles,
            thread_settings,
        };
        DynamicDispatchNTT::new(6, &options).unwrap()
    }

    /// Checks that every combination of options selects the expected backend variant.
    #[test]
    fn test_creation() {
        // Explicitly single-threaded settings.
        let ntt = make_ntt(false, ThreadingSettings::SingleThreaded);
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));

        let ntt = make_ntt(true, ThreadingSettings::SingleThreaded);
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

        // The default thread budget may resolve to zero on the test machine, in which case
        // the constructor falls back to the single-threaded variants.
        let multithreaded = get_log_max_threads() > 0;

        let ntt = make_ntt(false, ThreadingSettings::MultithreadedDefault);
        if multithreaded {
            assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));
        } else {
            assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
        }

        let ntt = make_ntt(true, ThreadingSettings::MultithreadedDefault);
        if multithreaded {
            assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));
        } else {
            assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));
        }

        // A non-zero explicit budget always selects a multi-threaded variant.
        let explicit_multi = ThreadingSettings::ExplicitThreadsCount { log_threads: 2 };

        let ntt = make_ntt(false, explicit_multi);
        assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));

        let ntt = make_ntt(true, explicit_multi);
        assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));

        // An explicit budget of zero behaves like `SingleThreaded`.
        let explicit_single = ThreadingSettings::ExplicitThreadsCount { log_threads: 0 };

        let ntt = make_ntt(true, explicit_single);
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

        let ntt = make_ntt(false, explicit_single);
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
    }
}