binius_ntt/multithreaded.rs

// Copyright 2024-2025 Irreducible Inc.

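//! Multithreaded implementation of the additive NTT.
//!
//! [`MultithreadedNTT`] wraps a [`SingleThreadedNTT`] and evaluates the transform in two
//! phases, a strided column-wise phase and a row-wise phase, parallelized with Rayon via
//! `binius_maybe_rayon`.
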
use binius_field::{BinaryField, PackedField};
use binius_math::BinarySubspace;
use binius_maybe_rayon::prelude::*;
use binius_utils::rayon::get_log_max_threads;

use super::{
	additive_ntt::{AdditiveNTT, NTTShape},
	error::Error,
	single_threaded::{self, check_batch_transform_inputs_and_params, SingleThreadedNTT},
	strided_array::StridedArray2DViewMut,
	twiddle::TwiddleAccess,
};
use crate::twiddle::OnTheFlyTwiddleAccess;

/// Implementation of `AdditiveNTT` that performs the computation using multiple threads.
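///
/// A minimal usage sketch (the `SingleThreadedNTT` constructor shown here is an assumption;
/// see `single_threaded.rs` for the actual construction API):
///
/// ```ignore
/// let ntt = SingleThreadedNTT::<BinaryField32b>::new(log_domain_size)?.multithreaded();
/// ntt.forward_transform(&mut data, shape, coset, skip_rounds)?;
/// ```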
#[derive(Debug)]
pub struct MultithreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F, Vec<F>>>
{
	single_threaded: SingleThreadedNTT<F, TA>,
	log_max_threads: usize,
}

impl<F: BinaryField, TA: TwiddleAccess<F> + Sync> SingleThreadedNTT<F, TA> {
	/// Returns a multithreaded NTT implementation that uses the default number of threads.
	pub fn multithreaded(self) -> MultithreadedNTT<F, TA> {
		let log_max_threads = get_log_max_threads();
		self.multithreaded_with_max_threads(log_max_threads as _)
	}

	/// Returns a multithreaded NTT implementation that uses up to `1 << log_max_threads` threads.
	pub const fn multithreaded_with_max_threads(
		self,
		log_max_threads: usize,
	) -> MultithreadedNTT<F, TA> {
		MultithreadedNTT {
			single_threaded: self,
			log_max_threads,
		}
	}
}

impl<F, TA> AdditiveNTT<F> for MultithreadedNTT<F, TA>
where
	F: BinaryField,
	TA: TwiddleAccess<F> + Sync,
{
	fn log_domain_size(&self) -> usize {
		self.single_threaded.log_domain_size()
	}

	fn subspace(&self, i: usize) -> BinarySubspace<F> {
		self.single_threaded.subspace(i)
	}

	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
		self.single_threaded.get_subspace_eval(i, j)
	}

	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: u32,
		skip_rounds: usize,
	) -> Result<(), Error> {
		forward_transform(
			self.log_domain_size(),
			self.single_threaded.twiddles(),
			data,
			shape,
			coset,
			skip_rounds,
			self.log_max_threads,
		)
	}

	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: u32,
		skip_rounds: usize,
	) -> Result<(), Error> {
		inverse_transform(
			self.log_domain_size(),
			self.single_threaded.twiddles(),
			data,
			shape,
			coset,
			skip_rounds,
			self.log_max_threads,
		)
	}
}

#[allow(clippy::too_many_arguments)]
fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F> + Sync],
	data: &mut [P],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
	log_max_threads: usize,
) -> Result<(), Error> {
	check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::WIDTH {
				1 => Ok(()),
				_ => single_threaded::forward_transform(
					log_domain_size,
					s_evals,
					data,
					shape,
					coset,
					skip_rounds,
				),
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Determine the optimal log_height and log_width, measured in packed field elements. log_width
	// must be at least 1, so that the row-wise NTTs have pairs of butterfly blocks to work with.
	// Furthermore, we want the width to be at least `1 << log_x` scalars, so that each row-wise
	// NTT operates on full batches. Subject to the minimum width requirement, we want to set the
	// height to the lowest value that takes advantage of all available threads.
	let min_log_width_scalars = (log_w + 1).max(log_x);

	// The subtraction must not underflow because data is checked to have at least 2 packed
	// elements at the top of the function.
	let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
	let log_width = log_x + log_y + log_z - (log_w + log_height);
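	// Worked example with illustrative numbers (not taken from any particular call site): for
	// log_w = 2, log_x = 0, log_y = 12, log_z = 0 and log_max_threads = 3, we get
	// min_log_width_scalars = max(2 + 1, 0) = 3, log_height = min(12 - 3, 3) = 3 and
	// log_width = 12 - (2 + 3) = 7, i.e. the 2^12-scalar input is viewed as an 8 x 128 matrix
	// of packed field elements (8 rows of 512 scalars each).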

	// The NTT algorithm shapes the tensor into a 2D matrix and runs two phases:
	//
	// 1. The first phase performs strided column-wise NTTs.
	// 2. The second phase performs the row-wise NTTs.

	// par_rounds is the number of phase 1 strided rounds.
	let par_rounds = log_height.saturating_sub(log_z);

	// Perform the column-wise NTTs in parallel over vertical strides of the matrix.
	{
		let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
			.expect("dimensions are correct");

		let log_strides = log_max_threads.min(log_width);
		let log_stride_len = log_width - log_strides;

		matrix
			.into_par_strides(1 << log_stride_len)
			.for_each(|mut stride| {
				// i indexes the layer of the NTT network, also the binary subspace.
				for i in (0..par_rounds.saturating_sub(skip_rounds)).rev() {
					let s_evals_par_i = &s_evals[log_y - par_rounds + i];
					let coset_offset = (coset as usize) << (par_rounds - 1 - i);

					// j indexes the outer Z tensor axis.
					for j in 0..1 << log_z {
						// k indexes the block within the layer. Each block performs butterfly operations with
						// the same twiddle factor.
						for k in 0..1 << (par_rounds - 1 - i) {
							let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
							// l indexes parallel stride columns
							for l in 0..1 << i {
								for m in 0..1 << log_stride_len {
									let idx0 = j << par_rounds | k << (i + 1) | l;
									let idx1 = idx0 | 1 << i;

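									// Butterfly: u' = u + twiddle * v, then v' = v + u' = u + (twiddle + 1) * v.
									// In a binary field, addition is XOR.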
									let mut u = stride[(idx0, m)];
									let mut v = stride[(idx1, m)];
									u += v * twiddle;
									v += u;
									stride[(idx0, m)] = u;
									stride[(idx1, m)] = v;
								}
							}
						}
					}
				}
			});
	}

	let log_row_z = log_z.saturating_sub(log_height);
	let single_thread_log_y = log_width + log_w - log_x - log_row_z;

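	// Phase 2: perform the remaining row-wise NTT rounds in parallel. Each row of the matrix is
	// an independent smaller NTT; its coset index extends the outer coset with the row index
	// within its group of `1 << par_rounds` rows.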
	data.par_chunks_mut(1 << (log_width + par_rounds))
		.flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
		.try_for_each(|(inner_coset, chunk)| {
			single_threaded::forward_transform(
				log_domain_size,
				&s_evals[0..log_y - par_rounds],
				chunk,
				NTTShape {
					log_x,
					log_y: single_thread_log_y,
					log_z: log_row_z,
				},
				coset << par_rounds | inner_coset as u32,
				skip_rounds.saturating_sub(par_rounds),
			)
		})?;

	Ok(())
}

#[allow(clippy::too_many_arguments)]
fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F> + Sync],
	data: &mut [P],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
	log_max_threads: usize,
) -> Result<(), Error> {
	check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::WIDTH {
				1 => Ok(()),
				_ => single_threaded::inverse_transform(
					log_domain_size,
					s_evals,
					data,
					shape,
					coset,
					skip_rounds,
				),
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Determine the optimal log_height and log_width, measured in packed field elements. log_width
	// must be at least 1, so that the row-wise NTTs have pairs of butterfly blocks to work with.
	// Furthermore, we want the width to be at least `1 << log_x` scalars, so that each row-wise
	// NTT operates on full batches. Subject to the minimum width requirement, we want to set the
	// height to the lowest value that takes advantage of all available threads.
	let min_log_width_scalars = (log_w + 1).max(log_x);

	// The subtraction must not underflow because data is checked to have at least 2 packed
	// elements at the top of the function.
	let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
	let log_width = log_x + log_y + log_z - (log_w + log_height);

	// The iNTT algorithm shapes the tensor into a 2D matrix and runs two phases:
	//
	// 1. The first phase performs the row-wise iNTTs.
	// 2. The second phase performs strided column-wise iNTTs.

	// par_rounds is the number of phase 2 strided rounds.
	let par_rounds = log_height.saturating_sub(log_z);

	let log_row_z = log_z.saturating_sub(log_height);
	let single_thread_log_y = log_width + log_w - log_x - log_row_z;

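	// Phase 1: perform the row-wise iNTT rounds in parallel, mirroring phase 2 of the forward
	// transform. Each row of the matrix is an independent smaller iNTT; its coset index extends
	// the outer coset with the row index within its group of `1 << par_rounds` rows.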
	data.par_chunks_mut(1 << (log_width + par_rounds))
		.flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
		.try_for_each(|(inner_coset, chunk)| {
			single_threaded::inverse_transform(
				log_domain_size,
				&s_evals[0..log_y - par_rounds],
				chunk,
				NTTShape {
					log_x,
					log_y: single_thread_log_y,
					log_z: log_row_z,
				},
				coset << par_rounds | inner_coset as u32,
				skip_rounds.saturating_sub(par_rounds),
			)
		})?;

	// Perform the column-wise iNTTs in parallel over vertical strides of the matrix.
	let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
		.expect("dimensions are correct");

	let log_strides = log_max_threads.min(log_width);
	let log_stride_len = log_width - log_strides;

	matrix
		.into_par_strides(1 << log_stride_len)
		.for_each(|mut stride| {
			// i indexes the layer of the NTT network, also the binary subspace.
			for i in 0..par_rounds.saturating_sub(skip_rounds) {
				let s_evals_par_i = &s_evals[log_y - par_rounds + i];
				let coset_offset = (coset as usize) << (par_rounds - 1 - i);

				// j indexes the outer Z tensor axis.
				for j in 0..1 << log_z {
					// k indexes the block within the layer. Each block performs butterfly operations with
					// the same twiddle factor.
					for k in 0..1 << (par_rounds - 1 - i) {
						let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
						// l indexes parallel stride columns
						for l in 0..1 << i {
							for m in 0..1 << log_stride_len {
								let idx0 = j << par_rounds | k << (i + 1) | l;
								let idx1 = idx0 | 1 << i;

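								// Inverse butterfly: v' = v + u, then u' = u + twiddle * v'. Since addition
								// is XOR, this exactly undoes the forward butterfly.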
								let mut u = stride[(idx0, m)];
								let mut v = stride[(idx1, m)];
								v += u;
								u += v * twiddle;
								stride[(idx0, m)] = u;
								stride[(idx1, m)] = v;
							}
						}
					}
				}
			}
		});

	Ok(())
}