binius_ntt/
dynamic_dispatch.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_utils::rayon::get_log_max_threads;
6
7use super::{
8	additive_ntt::{AdditiveNTT, NTTShape},
9	error::Error,
10	multithreaded::MultithreadedNTT,
11	single_threaded::SingleThreadedNTT,
12	twiddle::PrecomputedTwiddleAccess,
13};
14
15/// How many threads to use (threads number is a power of 2).
/// Threading policy for an NTT.
///
/// Thread counts are always powers of two, so they are expressed through their base-2
/// logarithm (see `ThreadingSettings::log_threads_count`).
#[derive(Debug, Default, Clone, Copy)]
pub enum ThreadingSettings {
	/// Use a single thread for calculations (the default).
	#[default]
	SingleThreaded,
	/// Derive the thread count from the number of cores.
	MultithreadedDefault,
	/// Request `2^log_threads` threads explicitly.
	ExplicitThreadsCount {
		/// Base-2 logarithm of the requested thread count.
		log_threads: usize,
	},
}
26
27impl ThreadingSettings {
28	/// Get the log2 of the number of threads to use.
29	pub fn log_threads_count(&self) -> usize {
30		match self {
31			Self::SingleThreaded => 0,
32			Self::MultithreadedDefault => get_log_max_threads(),
33			Self::ExplicitThreadsCount { log_threads } => *log_threads,
34		}
35	}
36
37	/// Check if settings imply multithreading.
38	pub const fn is_multithreaded(&self) -> bool {
39		match self {
40			Self::SingleThreaded => false,
41			Self::MultithreadedDefault => true,
42			Self::ExplicitThreadsCount { log_threads } => *log_threads > 0,
43		}
44	}
45}
46
47#[derive(Default)]
48pub struct NTTOptions {
49	pub precompute_twiddles: bool,
50	pub thread_settings: ThreadingSettings,
51}
52
53/// An enum that can be used to switch between different NTT implementations without passing AdditiveNTT as a type parameter.
54#[derive(Debug)]
55pub enum DynamicDispatchNTT<F: BinaryField> {
56	SingleThreaded(SingleThreadedNTT<F>),
57	SingleThreadedPrecompute(SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>>),
58	MultiThreaded(MultithreadedNTT<F>),
59	MultiThreadedPrecompute(MultithreadedNTT<F, PrecomputedTwiddleAccess<F>>),
60}
61
62impl<F: BinaryField> DynamicDispatchNTT<F> {
63	/// Create a new AdditiveNTT based on the given settings.
64	pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result<Self, Error> {
65		let log_threads = options.thread_settings.log_threads_count();
66		let result = match (options.precompute_twiddles, log_threads) {
67			(false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?),
68			(true, 0) => Self::SingleThreadedPrecompute(
69				SingleThreadedNTT::new(log_domain_size)?.precompute_twiddles(),
70			),
71			(false, _) => Self::MultiThreaded(
72				SingleThreadedNTT::new(log_domain_size)?
73					.multithreaded_with_max_threads(log_threads),
74			),
75			(true, _) => Self::MultiThreadedPrecompute(
76				SingleThreadedNTT::new(log_domain_size)?
77					.precompute_twiddles()
78					.multithreaded_with_max_threads(log_threads),
79			),
80		};
81
82		Ok(result)
83	}
84}
85
86impl<F: BinaryField> AdditiveNTT<F> for DynamicDispatchNTT<F> {
87	fn log_domain_size(&self) -> usize {
88		match self {
89			Self::SingleThreaded(ntt) => ntt.log_domain_size(),
90			Self::SingleThreadedPrecompute(ntt) => ntt.log_domain_size(),
91			Self::MultiThreaded(ntt) => ntt.log_domain_size(),
92			Self::MultiThreadedPrecompute(ntt) => ntt.log_domain_size(),
93		}
94	}
95
96	fn subspace(&self, i: usize) -> BinarySubspace<F> {
97		match self {
98			Self::SingleThreaded(ntt) => ntt.subspace(i),
99			Self::SingleThreadedPrecompute(ntt) => ntt.subspace(i),
100			Self::MultiThreaded(ntt) => ntt.subspace(i),
101			Self::MultiThreadedPrecompute(ntt) => ntt.subspace(i),
102		}
103	}
104
105	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
106		match self {
107			Self::SingleThreaded(ntt) => ntt.get_subspace_eval(i, j),
108			Self::SingleThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
109			Self::MultiThreaded(ntt) => ntt.get_subspace_eval(i, j),
110			Self::MultiThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
111		}
112	}
113
114	fn forward_transform<P: PackedField<Scalar = F>>(
115		&self,
116		data: &mut [P],
117		shape: NTTShape,
118		coset: u32,
119		skip_rounds: usize,
120	) -> Result<(), Error> {
121		match self {
122			Self::SingleThreaded(ntt) => ntt.forward_transform(data, shape, coset, skip_rounds),
123			Self::SingleThreadedPrecompute(ntt) => {
124				ntt.forward_transform(data, shape, coset, skip_rounds)
125			}
126			Self::MultiThreaded(ntt) => ntt.forward_transform(data, shape, coset, skip_rounds),
127			Self::MultiThreadedPrecompute(ntt) => {
128				ntt.forward_transform(data, shape, coset, skip_rounds)
129			}
130		}
131	}
132
133	fn inverse_transform<P: PackedField<Scalar = F>>(
134		&self,
135		data: &mut [P],
136		shape: NTTShape,
137		coset: u32,
138		skip_rounds: usize,
139	) -> Result<(), Error> {
140		match self {
141			Self::SingleThreaded(ntt) => ntt.inverse_transform(data, shape, coset, skip_rounds),
142			Self::SingleThreadedPrecompute(ntt) => {
143				ntt.inverse_transform(data, shape, coset, skip_rounds)
144			}
145			Self::MultiThreaded(ntt) => ntt.inverse_transform(data, shape, coset, skip_rounds),
146			Self::MultiThreadedPrecompute(ntt) => {
147				ntt.inverse_transform(data, shape, coset, skip_rounds)
148			}
149		}
150	}
151}
152
#[cfg(test)]
mod tests {
	use binius_field::BinaryField8b;

	use super::*;

	/// Builds a small (`log_domain_size = 6`) NTT over `BinaryField8b` with `options`.
	fn make_ntt(options: &NTTOptions) -> DynamicDispatchNTT<BinaryField8b> {
		DynamicDispatchNTT::<BinaryField8b>::new(6, options).unwrap()
	}

	#[test]
	fn test_creation() {
		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::SingleThreaded,
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::SingleThreaded,
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

		// `MultithreadedDefault` only yields a multithreaded NTT when the machine reports
		// more than one thread.
		let multithreaded = get_log_max_threads() > 0;
		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::MultithreadedDefault,
		});
		if multithreaded {
			assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));
		} else {
			assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
		}

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::MultithreadedDefault,
		});
		if multithreaded {
			assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));
		} else {
			assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));
		}

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_)));

		// Explicit multithreading combined with precomputation must select the
		// precomputed multithreaded variant regardless of the host's core count.
		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::MultiThreadedPrecompute(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: true,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_)));

		let ntt = make_ntt(&NTTOptions {
			precompute_twiddles: false,
			thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 },
		});
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
	}

	#[test]
	fn test_default_options() {
		let ntt = make_ntt(&NTTOptions::default());
		assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_)));
	}

	#[test]
	fn test_threading_settings() {
		assert!(matches!(ThreadingSettings::default(), ThreadingSettings::SingleThreaded));

		assert_eq!(ThreadingSettings::SingleThreaded.log_threads_count(), 0);
		assert!(!ThreadingSettings::SingleThreaded.is_multithreaded());

		let explicit = ThreadingSettings::ExplicitThreadsCount { log_threads: 3 };
		assert_eq!(explicit.log_threads_count(), 3);
		assert!(explicit.is_multithreaded());

		let explicit_zero = ThreadingSettings::ExplicitThreadsCount { log_threads: 0 };
		assert_eq!(explicit_zero.log_threads_count(), 0);
		assert!(!explicit_zero.is_multithreaded());

		// `MultithreadedDefault` defers to the host, but always claims multithreading.
		assert_eq!(
			ThreadingSettings::MultithreadedDefault.log_threads_count(),
			get_log_max_threads()
		);
		assert!(ThreadingSettings::MultithreadedDefault.is_multithreaded());
	}
}