binius_ntt/
dynamic_dispatch.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_utils::rayon::get_log_max_threads;
6
7use super::{
8	additive_ntt::{AdditiveNTT, NTTShape},
9	error::Error,
10	multithreaded::MultithreadedNTT,
11	single_threaded::SingleThreadedNTT,
12	twiddle::PrecomputedTwiddleAccess,
13};
14
/// Threading policy for an NTT: how many threads to use.
///
/// The number of threads is always a power of two, so it is expressed through its base-2
/// logarithm (see [`ThreadingSettings::log_threads_count`]).
///
/// The type is plain data, so it is cheap to copy and can be compared for equality.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadingSettings {
	/// Use a single thread for calculations.
	#[default]
	SingleThreaded,
	/// Use the default number of threads based on the number of cores.
	MultithreadedDefault,
	/// Explicitly set the logarithm of number of threads.
	ExplicitThreadsCount {
		/// Base-2 logarithm of the requested number of threads (`0` means one thread).
		log_threads: usize,
	},
}
26
27impl ThreadingSettings {
28	/// Get the log2 of the number of threads to use.
29	pub fn log_threads_count(&self) -> usize {
30		match self {
31			Self::SingleThreaded => 0,
32			Self::MultithreadedDefault => get_log_max_threads(),
33			Self::ExplicitThreadsCount { log_threads } => *log_threads,
34		}
35	}
36
37	/// Check if settings imply multithreading.
38	pub const fn is_multithreaded(&self) -> bool {
39		match self {
40			Self::SingleThreaded => false,
41			Self::MultithreadedDefault => true,
42			Self::ExplicitThreadsCount { log_threads } => *log_threads > 0,
43		}
44	}
45}
46
47#[derive(Default)]
48pub struct NTTOptions {
49	pub precompute_twiddles: bool,
50	pub thread_settings: ThreadingSettings,
51}
52
53/// An enum that can be used to switch between different NTT implementations without passing
54/// AdditiveNTT as a type parameter.
55#[derive(Debug)]
56pub enum DynamicDispatchNTT<F: BinaryField> {
57	SingleThreaded(SingleThreadedNTT<F>),
58	SingleThreadedPrecompute(SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>>),
59	MultiThreaded(MultithreadedNTT<F>),
60	MultiThreadedPrecompute(MultithreadedNTT<F, PrecomputedTwiddleAccess<F>>),
61}
62
63impl<F: BinaryField> DynamicDispatchNTT<F> {
64	/// Create a new AdditiveNTT based on the given settings.
65	pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result<Self, Error> {
66		let log_threads = options.thread_settings.log_threads_count();
67		let result = match (options.precompute_twiddles, log_threads) {
68			(false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?),
69			(true, 0) => Self::SingleThreadedPrecompute(
70				SingleThreadedNTT::new(log_domain_size)?.precompute_twiddles(),
71			),
72			(false, _) => Self::MultiThreaded(
73				SingleThreadedNTT::new(log_domain_size)?
74					.multithreaded_with_max_threads(log_threads),
75			),
76			(true, _) => Self::MultiThreadedPrecompute(
77				SingleThreadedNTT::new(log_domain_size)?
78					.precompute_twiddles()
79					.multithreaded_with_max_threads(log_threads),
80			),
81		};
82
83		Ok(result)
84	}
85}
86
87impl<F: BinaryField> AdditiveNTT<F> for DynamicDispatchNTT<F> {
88	fn log_domain_size(&self) -> usize {
89		match self {
90			Self::SingleThreaded(ntt) => ntt.log_domain_size(),
91			Self::SingleThreadedPrecompute(ntt) => ntt.log_domain_size(),
92			Self::MultiThreaded(ntt) => ntt.log_domain_size(),
93			Self::MultiThreadedPrecompute(ntt) => ntt.log_domain_size(),
94		}
95	}
96
97	fn subspace(&self, i: usize) -> BinarySubspace<F> {
98		match self {
99			Self::SingleThreaded(ntt) => ntt.subspace(i),
100			Self::SingleThreadedPrecompute(ntt) => ntt.subspace(i),
101			Self::MultiThreaded(ntt) => ntt.subspace(i),
102			Self::MultiThreadedPrecompute(ntt) => ntt.subspace(i),
103		}
104	}
105
106	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
107		match self {
108			Self::SingleThreaded(ntt) => ntt.get_subspace_eval(i, j),
109			Self::SingleThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
110			Self::MultiThreaded(ntt) => ntt.get_subspace_eval(i, j),
111			Self::MultiThreadedPrecompute(ntt) => ntt.get_subspace_eval(i, j),
112		}
113	}
114
115	fn forward_transform<P: PackedField<Scalar = F>>(
116		&self,
117		data: &mut [P],
118		shape: NTTShape,
119		coset: usize,
120		coset_bits: usize,
121		skip_rounds: usize,
122	) -> Result<(), Error> {
123		match self {
124			Self::SingleThreaded(ntt) => {
125				ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
126			}
127			Self::SingleThreadedPrecompute(ntt) => {
128				ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
129			}
130			Self::MultiThreaded(ntt) => {
131				ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
132			}
133			Self::MultiThreadedPrecompute(ntt) => {
134				ntt.forward_transform(data, shape, coset, coset_bits, skip_rounds)
135			}
136		}
137	}
138
139	fn inverse_transform<P: PackedField<Scalar = F>>(
140		&self,
141		data: &mut [P],
142		shape: NTTShape,
143		coset: usize,
144		coset_bits: usize,
145		skip_rounds: usize,
146	) -> Result<(), Error> {
147		match self {
148			Self::SingleThreaded(ntt) => {
149				ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
150			}
151			Self::SingleThreadedPrecompute(ntt) => {
152				ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
153			}
154			Self::MultiThreaded(ntt) => {
155				ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
156			}
157			Self::MultiThreadedPrecompute(ntt) => {
158				ntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds)
159			}
160		}
161	}
162}
163
#[cfg(test)]
mod tests {
	use binius_field::BinaryField8b;

	use super::*;

	/// Names the variant held by `ntt`, so that expectations can be listed in a table.
	fn variant_name(ntt: &DynamicDispatchNTT<BinaryField8b>) -> &'static str {
		match ntt {
			DynamicDispatchNTT::SingleThreaded(_) => "SingleThreaded",
			DynamicDispatchNTT::SingleThreadedPrecompute(_) => "SingleThreadedPrecompute",
			DynamicDispatchNTT::MultiThreaded(_) => "MultiThreaded",
			DynamicDispatchNTT::MultiThreadedPrecompute(_) => "MultiThreadedPrecompute",
		}
	}

	/// Checks that `DynamicDispatchNTT::new` picks the variant implied by the options.
	#[test]
	fn test_creation() {
		// With `MultithreadedDefault` the outcome depends on the host: a log thread count of
		// zero falls back to the single-threaded variants.
		let (default_plain, default_precompute) = if get_log_max_threads() > 0 {
			("MultiThreaded", "MultiThreadedPrecompute")
		} else {
			("SingleThreaded", "SingleThreadedPrecompute")
		};

		// (precompute_twiddles, thread_settings, expected variant)
		let cases = [
			(false, ThreadingSettings::SingleThreaded, "SingleThreaded"),
			(true, ThreadingSettings::SingleThreaded, "SingleThreadedPrecompute"),
			(false, ThreadingSettings::MultithreadedDefault, default_plain),
			(true, ThreadingSettings::MultithreadedDefault, default_precompute),
			(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }, "MultiThreaded"),
			(
				true,
				ThreadingSettings::ExplicitThreadsCount { log_threads: 0 },
				"SingleThreadedPrecompute",
			),
			(false, ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, "SingleThreaded"),
		];

		for (precompute_twiddles, thread_settings, expected) in cases {
			let options = NTTOptions {
				precompute_twiddles,
				thread_settings,
			};
			let ntt = DynamicDispatchNTT::<BinaryField8b>::new(6, &options).unwrap();
			assert_eq!(
				variant_name(&ntt),
				expected,
				"precompute_twiddles={precompute_twiddles}, thread_settings={thread_settings:?}"
			);
		}
	}
}