// binius_ntt/multithreaded.rs

// Copyright 2024-2025 Irreducible Inc.

use binius_field::{BinaryField, PackedField};
use binius_math::BinarySubspace;
use binius_maybe_rayon::prelude::*;
use binius_utils::rayon::get_log_max_threads;

use super::{
	error::Error,
	single_threaded::{self, check_batch_transform_inputs_and_params},
	strided_array::StridedArray2DViewMut,
	twiddle::TwiddleAccess,
	AdditiveNTT, SingleThreadedNTT,
};
use crate::twiddle::OnTheFlyTwiddleAccess;

/// Implementation of `AdditiveNTT` that performs the computation multithreaded.
#[derive(Debug)]
pub struct MultithreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F, Vec<F>>>
{
	// Wrapped single-threaded NTT; supplies the twiddle tables and performs the
	// per-chunk transforms that the multithreaded driver dispatches.
	single_threaded: SingleThreadedNTT<F, TA>,
	// Base-2 logarithm of the maximum number of threads to parallelize over.
	log_max_threads: usize,
}

25impl<F: BinaryField, TA: TwiddleAccess<F> + Sync> SingleThreadedNTT<F, TA> {
26	/// Returns multithreaded NTT implementation which uses default number of threads.
27	pub fn multithreaded(self) -> MultithreadedNTT<F, TA> {
28		let log_max_threads = get_log_max_threads();
29		self.multithreaded_with_max_threads(log_max_threads as _)
30	}
31
32	/// Returns multithreaded NTT implementation which uses `1 << log_max_threads` threads.
33	pub const fn multithreaded_with_max_threads(
34		self,
35		log_max_threads: usize,
36	) -> MultithreadedNTT<F, TA> {
37		MultithreadedNTT {
38			single_threaded: self,
39			log_max_threads,
40		}
41	}
42}
43
44impl<F, TA> AdditiveNTT<F> for MultithreadedNTT<F, TA>
45where
46	F: BinaryField,
47	TA: TwiddleAccess<F> + Sync,
48{
49	fn log_domain_size(&self) -> usize {
50		self.single_threaded.log_domain_size()
51	}
52
53	fn subspace(&self, i: usize) -> BinarySubspace<F> {
54		self.single_threaded.subspace(i)
55	}
56
57	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
58		self.single_threaded.get_subspace_eval(i, j)
59	}
60
61	fn forward_transform<P: PackedField<Scalar = F>>(
62		&self,
63		data: &mut [P],
64		coset: u32,
65		log_batch_size: usize,
66		log_n: usize,
67	) -> Result<(), Error> {
68		forward_transform(
69			self.log_domain_size(),
70			self.single_threaded.twiddles(),
71			data,
72			coset,
73			log_batch_size,
74			log_n,
75			self.log_max_threads,
76		)
77	}
78
79	fn inverse_transform<P: PackedField<Scalar = F>>(
80		&self,
81		data: &mut [P],
82		coset: u32,
83		log_batch_size: usize,
84		log_n: usize,
85	) -> Result<(), Error> {
86		inverse_transform(
87			self.log_domain_size(),
88			self.single_threaded.twiddles(),
89			data,
90			coset,
91			log_batch_size,
92			log_n,
93			self.log_max_threads,
94		)
95	}
96}
97
/// Forward additive NTT over `data`, parallelized across up to `1 << log_max_threads` threads.
///
/// `data` is interpreted as a batch of `1 << log_batch_size` interleaved transforms of size
/// `1 << log_n` each (see `check_batch_transform_inputs_and_params` for the exact invariants).
///
/// The computation runs in two phases:
/// 1. The leading `par_rounds` butterfly rounds (the largest strides) are applied column-wise
///    on a `2^log_height x 2^log_width` matrix view of `data`, parallelized over vertical
///    strides of columns.
/// 2. The remaining rounds are delegated to the single-threaded transform on independent
///    chunks of `data`, in parallel, with each chunk's index extending the coset.
fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F> + Sync],
	data: &mut [P],
	coset: u32,
	log_batch_size: usize,
	log_n: usize,
	log_max_threads: usize,
) -> Result<(), Error> {
	match data.len() {
		// Empty input: nothing to transform.
		0 => return Ok(()),
		1 => {
			return match P::WIDTH {
				// One packed element of width 1 is a single scalar; the size-1 transform
				// is the identity.
				1 => Ok(()),
				// A single packed element is too small to split across threads; defer to
				// the single-threaded implementation.
				_ => single_threaded::forward_transform(
					log_domain_size,
					s_evals,
					data,
					coset,
					log_batch_size,
					log_n,
				),
			};
		}
		_ => {}
	};

	let log_w = P::LOG_WIDTH;

	let log_b = log_batch_size;

	check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_b, log_n)?;

	// Cutoff is the stage of the NTT where the butterfly units are contained within
	// packed base field elements.
	let cutoff = log_w.saturating_sub(log_b);

	// Number of leading rounds executed in the parallel, column-wise phase; capped by the
	// thread budget.
	let par_rounds = (log_n - cutoff).min(log_max_threads);
	// Matrix view dimensions: 2^log_height rows of 2^log_width packed elements each.
	let log_height = par_rounds;
	let log_width = log_n + log_b - log_w - log_height;

	// Perform the column-wise NTTs in parallel over vertical strides of the matrix.
	{
		let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
			.expect("dimensions are correct");

		// Split the columns into up to 2^log_max_threads strides of 2^log_stride_len
		// columns; each stride is processed by one task.
		let log_strides = log_max_threads.min(log_width);
		let log_stride_len = log_width - log_strides;

		matrix
			.into_par_strides(1 << log_stride_len)
			.for_each(|mut stride| {
				// Rounds run from the largest butterfly distance (i = par_rounds - 1) down.
				for i in (0..par_rounds).rev() {
					// Twiddles for round `log_n - par_rounds + i`, restricted to the
					// requested coset.
					let coset_twiddle = s_evals[log_n - par_rounds + i]
						.coset(log_domain_size - log_n, coset as usize);

					for j in 0..1 << (par_rounds - 1 - i) {
						let twiddle = P::broadcast(coset_twiddle.get(j));
						for k in 0..1 << i {
							for l in 0..1 << log_stride_len {
								// Rows idx0 and idx1 sit 2^i apart: one butterfly pair.
								let idx0 = j << (i + 1) | k;
								let idx1 = idx0 | 1 << i;

								// Forward butterfly: u' = u + v * twiddle, v' = v + u'.
								let mut u = stride[(idx0, l)];
								let mut v = stride[(idx1, l)];
								u += v * twiddle;
								v += u;
								stride[(idx0, l)] = u;
								stride[(idx1, l)] = v;
							}
						}
					}
				}
			});
	}

	// log2 size of each remaining per-chunk transform, in scalars per batch element.
	let single_thread_log_n = log_width + P::LOG_WIDTH - log_batch_size;

	// Each chunk of 2^log_width packed elements (one matrix row) is an independent
	// sub-transform; its index supplies the low bits of the effective coset.
	data.par_chunks_mut(1 << log_width)
		.enumerate()
		.try_for_each(|(inner_coset, chunk)| {
			single_threaded::forward_transform(
				log_domain_size,
				&s_evals[0..log_n - par_rounds],
				chunk,
				coset << par_rounds | (inner_coset as u32),
				log_batch_size,
				single_thread_log_n,
			)
		})?;

	Ok(())
}

/// Inverse additive NTT over `data`, parallelized across up to `1 << log_max_threads` threads.
///
/// Exact mirror of `forward_transform`: the two phases run in the opposite order (per-chunk
/// single-threaded rounds first, then the column-wise parallel rounds), the parallel rounds
/// iterate from the smallest butterfly distance up, and the butterfly itself is inverted.
fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F> + Sync],
	data: &mut [P],
	coset: u32,
	log_batch_size: usize,
	log_n: usize,
	log_max_threads: usize,
) -> Result<(), Error> {
	match data.len() {
		// Empty input: nothing to transform.
		0 => return Ok(()),
		1 => {
			return match P::WIDTH {
				// One packed element of width 1 is a single scalar; the size-1 transform
				// is the identity.
				1 => Ok(()),
				// A single packed element is too small to split across threads; defer to
				// the single-threaded implementation.
				_ => single_threaded::inverse_transform(
					log_domain_size,
					s_evals,
					data,
					coset,
					log_batch_size,
					log_n,
				),
			};
		}
		_ => {}
	};

	let log_w = P::LOG_WIDTH;

	let log_b = log_batch_size;

	check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_b, log_n)?;

	// Cutoff is the stage of the NTT where the butterfly units are contained within
	// packed base field elements.
	let cutoff = log_w.saturating_sub(log_b);

	// Same matrix decomposition as in the forward direction.
	let par_rounds = (log_n - cutoff).min(log_max_threads);
	let log_height = par_rounds;
	let log_width = log_n + log_b - log_w - log_height;
	let single_thread_log_n = log_width + P::LOG_WIDTH - log_batch_size;

	// Phase 1 (inverse order of the forward transform): undo the trailing rounds on
	// independent chunks; each chunk's index supplies the low bits of the effective coset.
	data.par_chunks_mut(1 << log_width)
		.enumerate()
		.try_for_each(|(inner_coset, chunk)| {
			single_threaded::inverse_transform(
				log_domain_size,
				&s_evals[0..log_n - par_rounds],
				chunk,
				coset << par_rounds | (inner_coset as u32),
				log_batch_size,
				single_thread_log_n,
			)
		})?;

	// Perform the column-wise NTTs in parallel over vertical strides of the matrix.
	let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
		.expect("dimensions are correct");

	// Split the columns into up to 2^log_max_threads strides of 2^log_stride_len columns.
	let log_strides = log_max_threads.min(log_width);
	let log_stride_len = log_width - log_strides;

	matrix
		.into_par_strides(1 << log_stride_len)
		.for_each(|mut stride| {
			// Rounds run from the smallest butterfly distance up — reverse of the forward
			// phase's ordering.
			for i in 0..par_rounds {
				// Twiddles for round `log_n - par_rounds + i`, restricted to the
				// requested coset.
				let coset_twiddle =
					s_evals[log_n - par_rounds + i].coset(log_domain_size - log_n, coset as usize);

				for j in 0..1 << (par_rounds - 1 - i) {
					let twiddle = P::broadcast(coset_twiddle.get(j));
					for k in 0..1 << i {
						for l in 0..1 << log_stride_len {
							// Rows idx0 and idx1 sit 2^i apart: one butterfly pair.
							let idx0 = j << (i + 1) | k;
							let idx1 = idx0 | 1 << i;

							// Inverse butterfly (undoes the forward one in char 2):
							// v' = v + u, then u' = u + v' * twiddle.
							let mut u = stride[(idx0, l)];
							let mut v = stride[(idx1, l)];
							v += u;
							u += v * twiddle;
							stride[(idx0, l)] = u;
							stride[(idx1, l)] = v;
						}
					}
				}
			}
		});

	Ok(())
}