1use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_utils::rayon::get_log_max_threads;
6
7use super::{
8 additive_ntt::AdditiveNTT, error::Error, multithreaded::MultithreadedNTT,
9 single_threaded::SingleThreadedNTT, twiddle::PrecomputedTwiddleAccess,
10};
11
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
/// Selects how many threads an NTT instance is allowed to use.
///
/// The type is a small `Copy` value and supports equality comparison, so settings can be
/// compared directly (for example in tests or when caching configured instances).
pub enum ThreadingSettings {
    /// Run on a single thread. This is the default.
    #[default]
    SingleThreaded,
    /// Use the thread count reported by `get_log_max_threads()`.
    MultithreadedDefault,
    /// Use an explicitly chosen thread count, given as a logarithm.
    ///
    /// `log_threads == 0` is treated as single-threaded.
    ExplicitThreadsCount { log_threads: usize },
}
23
24impl ThreadingSettings {
25 pub fn log_threads_count(&self) -> usize {
27 match self {
28 Self::SingleThreaded => 0,
29 Self::MultithreadedDefault => get_log_max_threads(),
30 Self::ExplicitThreadsCount { log_threads } => *log_threads,
31 }
32 }
33
34 pub const fn is_multithreaded(&self) -> bool {
36 match self {
37 Self::SingleThreaded => false,
38 Self::MultithreadedDefault => true,
39 Self::ExplicitThreadsCount { log_threads } => *log_threads > 0,
40 }
41 }
42}
43
44#[derive(Default)]
45pub struct NTTOptions {
46 pub precompute_twiddles: bool,
47 pub thread_settings: ThreadingSettings,
48}
49
#[derive(Debug)]
/// An NTT whose concrete implementation is chosen at runtime from `NTTOptions`.
///
/// Each variant wraps one combination of threading mode and twiddle access. The
/// `AdditiveNTT` implementation for this type forwards every call to the wrapped value.
pub enum DynamicDispatchNTT<F: BinaryField> {
    /// Single-threaded NTT without precomputed twiddles.
    SingleThreaded(SingleThreadedNTT<F>),
    /// Single-threaded NTT using `PrecomputedTwiddleAccess`.
    SingleThreadedPrecompute(SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>>),
    /// Multithreaded NTT without precomputed twiddles.
    MultiThreaded(MultithreadedNTT<F>),
    /// Multithreaded NTT using `PrecomputedTwiddleAccess`.
    MultiThreadedPrecompute(MultithreadedNTT<F, PrecomputedTwiddleAccess<F>>),
}
58
59impl<F: BinaryField> DynamicDispatchNTT<F> {
60 pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result<Self, Error> {
62 let log_threads = options.thread_settings.log_threads_count();
63 let result = match (options.precompute_twiddles, log_threads) {
64 (false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?),
65 (true, 0) => Self::SingleThreadedPrecompute(
66 SingleThreadedNTT::new(log_domain_size)?.precompute_twiddles(),
67 ),
68 (false, _) => Self::MultiThreaded(
69 SingleThreadedNTT::new(log_domain_size)?
70 .multithreaded_with_max_threads(log_threads),
71 ),
72 (true, _) => Self::MultiThreadedPrecompute(
73 SingleThreadedNTT::new(log_domain_size)?
74 .precompute_twiddles()
75 .multithreaded_with_max_threads(log_threads),
76 ),
77 };
78
79 Ok(result)
80 }
81}
82
83impl<F: BinaryField> AdditiveNTT<F> for DynamicDispatchNTT<F> {
84 fn log_domain_size(&self) -> usize {
85 match self {
86 Self::SingleThreaded(ntt) => ntt.log_domain_size(),
87 Self::SingleThreadedPrecompute(ntt) => ntt.log_domain_size(),
88 Self::MultiThreaded(ntt) => ntt.log_domain_size(),
89 Self::MultiThreadedPrecompute(ntt) => ntt.log_domain_size(),
90 }
91 }
92
93 fn subspace(&self, i: usize) -> BinarySubspace<F> {
94 match self {
95 Self::SingleThreaded(ntt) => ntt.subspace(i),
96 Self::SingleThreadedPrecompute(ntt) => ntt.subspace(i),
97 Self::MultiThreaded(ntt) => ntt.subspace(i),
98 Self::MultiThreadedPrecompute(ntt) => ntt.subspace(i),
99 }
100 }
101
102 fn get_subspace_eval(&self, i: usize, j: usize) -> F {
103 match self {
104 Self::SingleThreaded(ntt) => ntt.get_subspace_eval(i, j),
105 Self::SingleThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
106 Self::MultiThreaded(ntt) => ntt.get_subspace_eval(i, j),
107 Self::MultiThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
108 }
109 }
110
111 fn forward_transform<P: PackedField<Scalar = F>>(
112 &self,
113 data: &mut [P],
114 coset: u32,
115 log_batch_size: usize,
116 log_n: usize,
117 ) -> Result<(), Error> {
118 match self {
119 Self::SingleThreaded(ntt) => ntt.forward_transform(data, coset, log_batch_size, log_n),
120 Self::SingleThreadedPrecompute(ntt) => {
121 ntt.forward_transform(data, coset, log_batch_size, log_n)
122 }
123 Self::MultiThreaded(ntt) => ntt.forward_transform(data, coset, log_batch_size, log_n),
124 Self::MultiThreadedPrecompute(ntt) => {
125 ntt.forward_transform(data, coset, log_batch_size, log_n)
126 }
127 }
128 }
129
130 fn inverse_transform<P: PackedField<Scalar = F>>(
131 &self,
132 data: &mut [P],
133 coset: u32,
134 log_batch_size: usize,
135 log_n: usize,
136 ) -> Result<(), Error> {
137 match self {
138 Self::SingleThreaded(ntt) => ntt.inverse_transform(data, coset, log_batch_size, log_n),
139 Self::SingleThreadedPrecompute(ntt) => {
140 ntt.inverse_transform(data, coset, log_batch_size, log_n)
141 }
142 Self::MultiThreaded(ntt) => ntt.inverse_transform(data, coset, log_batch_size, log_n),
143 Self::MultiThreadedPrecompute(ntt) => {
144 ntt.inverse_transform(data, coset, log_batch_size, log_n)
145 }
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use binius_field::BinaryField8b;

    use super::*;

    /// Builds a `DynamicDispatchNTT` over `BinaryField8b` with `log_domain_size = 6` from the
    /// given option values.
    fn make_ntt(
        precompute_twiddles: bool,
        thread_settings: ThreadingSettings,
    ) -> DynamicDispatchNTT<BinaryField8b> {
        let options = NTTOptions {
            precompute_twiddles,
            thread_settings,
        };
        DynamicDispatchNTT::new(6, &options).unwrap()
    }

    /// Checks that every combination of options yields the expected variant.
    #[test]
    fn test_creation() {
        // Explicit single-threaded settings.
        let ntt = make_ntt(false, ThreadingSettings::SingleThreaded);
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));

        let ntt = make_ntt(true, ThreadingSettings::SingleThreaded);
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

        // The default multithreaded setting depends on the machine: with a single available
        // thread, `log_threads_count()` is 0 and the single-threaded variants are selected.
        let multithreaded = get_log_max_threads() > 0;

        let ntt = make_ntt(false, ThreadingSettings::MultithreadedDefault);
        if multithreaded {
            assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));
        } else {
            assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
        }

        let ntt = make_ntt(true, ThreadingSettings::MultithreadedDefault);
        if multithreaded {
            assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));
        } else {
            assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));
        }

        // An explicit non-zero thread count selects the multithreaded variants on any machine.
        let ntt = make_ntt(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 2 });
        assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));

        // Without this case, `MultiThreadedPrecompute` is only exercised when the host reports
        // more than one thread.
        let ntt = make_ntt(true, ThreadingSettings::ExplicitThreadsCount { log_threads: 2 });
        assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));

        // An explicit zero thread count is equivalent to single-threaded.
        let ntt = make_ntt(true, ThreadingSettings::ExplicitThreadsCount { log_threads: 0 });
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

        let ntt = make_ntt(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 0 });
        assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
    }
}