1use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_utils::rayon::get_log_max_threads;
6
7use super::{
8 additive_ntt::{AdditiveNTT, NTTShape},
9 error::Error,
10 multithreaded::MultithreadedNTT,
11 single_threaded::SingleThreadedNTT,
12 twiddle::PrecomputedTwiddleAccess,
13};
14
/// How many threads an NTT implementation may use.
///
/// Thread counts are expressed through their base-2 logarithm (see
/// `ThreadingSettings::log_threads_count`).
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThreadingSettings {
	/// Run all computations on a single thread. This is the `Default`.
	#[default]
	SingleThreaded,
	/// Use the thread count reported by `get_log_max_threads()` at the time the
	/// settings are resolved.
	MultithreadedDefault,
	/// Use an explicitly chosen thread count, given as its base-2 logarithm.
	/// A value of `0` means single-threaded.
	ExplicitThreadsCount { log_threads: usize },
}
26
27impl ThreadingSettings {
28 pub fn log_threads_count(&self) -> usize {
30 match self {
31 Self::SingleThreaded => 0,
32 Self::MultithreadedDefault => get_log_max_threads(),
33 Self::ExplicitThreadsCount { log_threads } => *log_threads,
34 }
35 }
36
37 pub const fn is_multithreaded(&self) -> bool {
39 match self {
40 Self::SingleThreaded => false,
41 Self::MultithreadedDefault => true,
42 Self::ExplicitThreadsCount { log_threads } => *log_threads > 0,
43 }
44 }
45}
46
47#[derive(Default)]
48pub struct NTTOptions {
49 pub precompute_twiddles: bool,
50 pub thread_settings: ThreadingSettings,
51}
52
53#[derive(Debug)]
55pub enum DynamicDispatchNTT<F: BinaryField> {
56 SingleThreaded(SingleThreadedNTT<F>),
57 SingleThreadedPrecompute(SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>>),
58 MultiThreaded(MultithreadedNTT<F>),
59 MultiThreadedPrecompute(MultithreadedNTT<F, PrecomputedTwiddleAccess<F>>),
60}
61
62impl<F: BinaryField> DynamicDispatchNTT<F> {
63 pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result<Self, Error> {
65 let log_threads = options.thread_settings.log_threads_count();
66 let result = match (options.precompute_twiddles, log_threads) {
67 (false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?),
68 (true, 0) => Self::SingleThreadedPrecompute(
69 SingleThreadedNTT::new(log_domain_size)?.precompute_twiddles(),
70 ),
71 (false, _) => Self::MultiThreaded(
72 SingleThreadedNTT::new(log_domain_size)?
73 .multithreaded_with_max_threads(log_threads),
74 ),
75 (true, _) => Self::MultiThreadedPrecompute(
76 SingleThreadedNTT::new(log_domain_size)?
77 .precompute_twiddles()
78 .multithreaded_with_max_threads(log_threads),
79 ),
80 };
81
82 Ok(result)
83 }
84}
85
86impl<F: BinaryField> AdditiveNTT<F> for DynamicDispatchNTT<F> {
87 fn log_domain_size(&self) -> usize {
88 match self {
89 Self::SingleThreaded(ntt) => ntt.log_domain_size(),
90 Self::SingleThreadedPrecompute(ntt) => ntt.log_domain_size(),
91 Self::MultiThreaded(ntt) => ntt.log_domain_size(),
92 Self::MultiThreadedPrecompute(ntt) => ntt.log_domain_size(),
93 }
94 }
95
96 fn subspace(&self, i: usize) -> BinarySubspace<F> {
97 match self {
98 Self::SingleThreaded(ntt) => ntt.subspace(i),
99 Self::SingleThreadedPrecompute(ntt) => ntt.subspace(i),
100 Self::MultiThreaded(ntt) => ntt.subspace(i),
101 Self::MultiThreadedPrecompute(ntt) => ntt.subspace(i),
102 }
103 }
104
105 fn get_subspace_eval(&self, i: usize, j: usize) -> F {
106 match self {
107 Self::SingleThreaded(ntt) => ntt.get_subspace_eval(i, j),
108 Self::SingleThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
109 Self::MultiThreaded(ntt) => ntt.get_subspace_eval(i, j),
110 Self::MultiThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
111 }
112 }
113
114 fn forward_transform<P: PackedField<Scalar = F>>(
115 &self,
116 data: &mut [P],
117 shape: NTTShape,
118 coset: u32,
119 skip_rounds: usize,
120 ) -> Result<(), Error> {
121 match self {
122 Self::SingleThreaded(ntt) => ntt.forward_transform(data, shape, coset, skip_rounds),
123 Self::SingleThreadedPrecompute(ntt) => {
124 ntt.forward_transform(data, shape, coset, skip_rounds)
125 }
126 Self::MultiThreaded(ntt) => ntt.forward_transform(data, shape, coset, skip_rounds),
127 Self::MultiThreadedPrecompute(ntt) => {
128 ntt.forward_transform(data, shape, coset, skip_rounds)
129 }
130 }
131 }
132
133 fn inverse_transform<P: PackedField<Scalar = F>>(
134 &self,
135 data: &mut [P],
136 shape: NTTShape,
137 coset: u32,
138 skip_rounds: usize,
139 ) -> Result<(), Error> {
140 match self {
141 Self::SingleThreaded(ntt) => ntt.inverse_transform(data, shape, coset, skip_rounds),
142 Self::SingleThreadedPrecompute(ntt) => {
143 ntt.inverse_transform(data, shape, coset, skip_rounds)
144 }
145 Self::MultiThreaded(ntt) => ntt.inverse_transform(data, shape, coset, skip_rounds),
146 Self::MultiThreadedPrecompute(ntt) => {
147 ntt.inverse_transform(data, shape, coset, skip_rounds)
148 }
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
	use binius_field::BinaryField8b;

	use super::*;

	/// Domain size (as passed to `DynamicDispatchNTT::new`) used by every construction here.
	const LOG_DOMAIN_SIZE: usize = 6;

	/// Returns a label naming the backend wrapped by `ntt`, so that expected variants
	/// can be listed in a table instead of one `matches!` per case.
	fn variant_label(ntt: &DynamicDispatchNTT<BinaryField8b>) -> &'static str {
		match ntt {
			DynamicDispatchNTT::SingleThreaded(_) => "SingleThreaded",
			DynamicDispatchNTT::SingleThreadedPrecompute(_) => "SingleThreadedPrecompute",
			DynamicDispatchNTT::MultiThreaded(_) => "MultiThreaded",
			DynamicDispatchNTT::MultiThreadedPrecompute(_) => "MultiThreadedPrecompute",
		}
	}

	#[test]
	fn test_creation() {
		// `MultithreadedDefault` resolves to `get_log_max_threads()`, which may be 0 on a
		// single-core host. In that case the single-threaded variants are expected.
		let (default_plain, default_precompute) = if get_log_max_threads() > 0 {
			("MultiThreaded", "MultiThreadedPrecompute")
		} else {
			("SingleThreaded", "SingleThreadedPrecompute")
		};

		let cases = [
			(false, ThreadingSettings::SingleThreaded, "SingleThreaded"),
			(true, ThreadingSettings::SingleThreaded, "SingleThreadedPrecompute"),
			(false, ThreadingSettings::MultithreadedDefault, default_plain),
			(true, ThreadingSettings::MultithreadedDefault, default_precompute),
			(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }, "MultiThreaded"),
			(
				true,
				ThreadingSettings::ExplicitThreadsCount { log_threads: 2 },
				"MultiThreadedPrecompute",
			),
			(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, "SingleThreaded"),
			(
				true,
				ThreadingSettings::ExplicitThreadsCount { log_threads: 0 },
				"SingleThreadedPrecompute",
			),
		];

		for &(precompute_twiddles, thread_settings, expected) in &cases {
			let options = NTTOptions {
				precompute_twiddles,
				thread_settings,
			};
			let ntt = DynamicDispatchNTT::<BinaryField8b>::new(LOG_DOMAIN_SIZE, &options).unwrap();
			assert_eq!(
				variant_label(&ntt),
				expected,
				"precompute_twiddles={precompute_twiddles}, thread_settings={thread_settings:?}"
			);
		}
	}

	#[test]
	fn test_threading_settings_helpers() {
		assert_eq!(ThreadingSettings::SingleThreaded.log_threads_count(), 0);
		assert!(!ThreadingSettings::SingleThreaded.is_multithreaded());

		assert_eq!(
			ThreadingSettings::MultithreadedDefault.log_threads_count(),
			get_log_max_threads()
		);
		assert!(ThreadingSettings::MultithreadedDefault.is_multithreaded());

		let explicit_zero = ThreadingSettings::ExplicitThreadsCount { log_threads: 0 };
		assert_eq!(explicit_zero.log_threads_count(), 0);
		assert!(!explicit_zero.is_multithreaded());

		let explicit_three = ThreadingSettings::ExplicitThreadsCount { log_threads: 3 };
		assert_eq!(explicit_three.log_threads_count(), 3);
		assert!(explicit_three.is_multithreaded());
	}
}