binius_ntt/
dynamic_dispatch.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_utils::rayon::get_log_max_threads;
6
7use super::{
8	additive_ntt::AdditiveNTT, error::Error, multithreaded::MultithreadedNTT,
9	single_threaded::SingleThreadedNTT, twiddle::PrecomputedTwiddleAccess,
10};
11
/// How many threads to use (threads number is a power of 2).
///
/// The thread count is always expressed as a base-2 logarithm; see
/// [`ThreadingSettings::log_threads_count`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThreadingSettings {
	/// Use a single thread for calculations.
	#[default]
	SingleThreaded,
	/// Use the default number of threads based on the number of cores.
	MultithreadedDefault,
	/// Explicitly set the logarithm of the number of threads.
	///
	/// `log_threads == 0` means a single thread.
	ExplicitThreadsCount { log_threads: usize },
}
23
24impl ThreadingSettings {
25	/// Get the log2 of the number of threads to use.
26	pub fn log_threads_count(&self) -> usize {
27		match self {
28			Self::SingleThreaded => 0,
29			Self::MultithreadedDefault => get_log_max_threads(),
30			Self::ExplicitThreadsCount { log_threads } => *log_threads,
31		}
32	}
33
34	/// Check if settings imply multithreading.
35	pub const fn is_multithreaded(&self) -> bool {
36		match self {
37			Self::SingleThreaded => false,
38			Self::MultithreadedDefault => true,
39			Self::ExplicitThreadsCount { log_threads } => *log_threads > 0,
40		}
41	}
42}
43
44#[derive(Default)]
45pub struct NTTOptions {
46	pub precompute_twiddles: bool,
47	pub thread_settings: ThreadingSettings,
48}
49
50/// An enum that can be used to switch between different NTT implementations without passing AdditiveNTT as a type parameter.
51#[derive(Debug)]
52pub enum DynamicDispatchNTT<F: BinaryField> {
53	SingleThreaded(SingleThreadedNTT<F>),
54	SingleThreadedPrecompute(SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>>),
55	MultiThreaded(MultithreadedNTT<F>),
56	MultiThreadedPrecompute(MultithreadedNTT<F, PrecomputedTwiddleAccess<F>>),
57}
58
59impl<F: BinaryField> DynamicDispatchNTT<F> {
60	/// Create a new AdditiveNTT based on the given settings.
61	pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result<Self, Error> {
62		let log_threads = options.thread_settings.log_threads_count();
63		let result = match (options.precompute_twiddles, log_threads) {
64			(false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?),
65			(true, 0) => Self::SingleThreadedPrecompute(
66				SingleThreadedNTT::new(log_domain_size)?.precompute_twiddles(),
67			),
68			(false, _) => Self::MultiThreaded(
69				SingleThreadedNTT::new(log_domain_size)?
70					.multithreaded_with_max_threads(log_threads),
71			),
72			(true, _) => Self::MultiThreadedPrecompute(
73				SingleThreadedNTT::new(log_domain_size)?
74					.precompute_twiddles()
75					.multithreaded_with_max_threads(log_threads),
76			),
77		};
78
79		Ok(result)
80	}
81}
82
83impl<F: BinaryField> AdditiveNTT<F> for DynamicDispatchNTT<F> {
84	fn log_domain_size(&self) -> usize {
85		match self {
86			Self::SingleThreaded(ntt) => ntt.log_domain_size(),
87			Self::SingleThreadedPrecompute(ntt) => ntt.log_domain_size(),
88			Self::MultiThreaded(ntt) => ntt.log_domain_size(),
89			Self::MultiThreadedPrecompute(ntt) => ntt.log_domain_size(),
90		}
91	}
92
93	fn subspace(&self, i: usize) -> BinarySubspace<F> {
94		match self {
95			Self::SingleThreaded(ntt) => ntt.subspace(i),
96			Self::SingleThreadedPrecompute(ntt) => ntt.subspace(i),
97			Self::MultiThreaded(ntt) => ntt.subspace(i),
98			Self::MultiThreadedPrecompute(ntt) => ntt.subspace(i),
99		}
100	}
101
102	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
103		match self {
104			Self::SingleThreaded(ntt) => ntt.get_subspace_eval(i, j),
105			Self::SingleThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
106			Self::MultiThreaded(ntt) => ntt.get_subspace_eval(i, j),
107			Self::MultiThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
108		}
109	}
110
111	fn forward_transform<P: PackedField<Scalar = F>>(
112		&self,
113		data: &mut [P],
114		coset: u32,
115		log_batch_size: usize,
116		log_n: usize,
117	) -> Result<(), Error> {
118		match self {
119			Self::SingleThreaded(ntt) => ntt.forward_transform(data, coset, log_batch_size, log_n),
120			Self::SingleThreadedPrecompute(ntt) => {
121				ntt.forward_transform(data, coset, log_batch_size, log_n)
122			}
123			Self::MultiThreaded(ntt) => ntt.forward_transform(data, coset, log_batch_size, log_n),
124			Self::MultiThreadedPrecompute(ntt) => {
125				ntt.forward_transform(data, coset, log_batch_size, log_n)
126			}
127		}
128	}
129
130	fn inverse_transform<P: PackedField<Scalar = F>>(
131		&self,
132		data: &mut [P],
133		coset: u32,
134		log_batch_size: usize,
135		log_n: usize,
136	) -> Result<(), Error> {
137		match self {
138			Self::SingleThreaded(ntt) => ntt.inverse_transform(data, coset, log_batch_size, log_n),
139			Self::SingleThreadedPrecompute(ntt) => {
140				ntt.inverse_transform(data, coset, log_batch_size, log_n)
141			}
142			Self::MultiThreaded(ntt) => ntt.inverse_transform(data, coset, log_batch_size, log_n),
143			Self::MultiThreadedPrecompute(ntt) => {
144				ntt.inverse_transform(data, coset, log_batch_size, log_n)
145			}
146		}
147	}
148}
149
#[cfg(test)]
mod tests {
	use binius_field::BinaryField8b;

	use super::*;

	const LOG_DOMAIN_SIZE: usize = 6;

	fn make_ntt(options: &NTTOptions) -> DynamicDispatchNTT<BinaryField8b> {
		DynamicDispatchNTT::<BinaryField8b>::new(LOG_DOMAIN_SIZE, options).unwrap()
	}

	#[test]
	fn test_creation() {
		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::SingleThreaded,
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::SingleThreaded,
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

		// `MultithreadedDefault` only yields a multithreaded NTT when more than one thread is
		// available.
		let multithreaded = get_log_max_threads() > 0;
		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::MultithreadedDefault,
		});
		if multithreaded {
			assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));
		} else {
			assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
		}

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::MultithreadedDefault,
		});
		if multithreaded {
			assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));
		} else {
			assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));
		}

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
	}

	#[test]
	fn test_threading_settings() {
		assert!(matches!(ThreadingSettings::default(), ThreadingSettings::SingleThreaded));

		assert_eq!(ThreadingSettings::SingleThreaded.log_threads_count(), 0);
		assert!(!ThreadingSettings::SingleThreaded.is_multithreaded());

		let explicit = ThreadingSettings::ExplicitThreadsCount { log_threads: 3 };
		assert_eq!(explicit.log_threads_count(), 3);
		assert!(explicit.is_multithreaded());

		let explicit_single = ThreadingSettings::ExplicitThreadsCount { log_threads: 0 };
		assert_eq!(explicit_single.log_threads_count(), 0);
		assert!(!explicit_single.is_multithreaded());

		assert_eq!(
			ThreadingSettings::MultithreadedDefault.log_threads_count(),
			get_log_max_threads()
		);
		assert!(ThreadingSettings::MultithreadedDefault.is_multithreaded());
	}

	#[test]
	fn test_variants_agree() {
		// Every combination of twiddle precomputation and (explicit) threading.
		let configs = [
			(false, ThreadingSettings::SingleThreaded),
			(true, ThreadingSettings::SingleThreaded),
			(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }),
			(true, ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }),
		];
		let input: Vec<BinaryField8b> = (0..1usize << LOG_DOMAIN_SIZE)
			.map(|i| BinaryField8b::new(i as u8))
			.collect();

		let outputs: Vec<Vec<BinaryField8b>> = configs
			.iter()
			.map(|&(precompute_twiddles, thread_settings)| {
				let ntt = make_ntt(&NTTOptions {
					precompute_twiddles,
					thread_settings,
				});

				let mut data = input.clone();
				ntt.forward_transform(&mut data, 0, 0, LOG_DOMAIN_SIZE)
					.unwrap();
				let forward = data.clone();

				// The inverse transform must undo the forward transform for every variant.
				ntt.inverse_transform(&mut data, 0, 0, LOG_DOMAIN_SIZE)
					.unwrap();
				assert_eq!(data, input);

				forward
			})
			.collect();

		// All implementations must compute the same transform.
		for output in &outputs[1..] {
			assert_eq!(output, &outputs[0]);
		}
	}
}