binius_ntt/multithreaded.rs

// Copyright 2024-2025 Irreducible Inc.

use binius_field::{BinaryField, PackedField};
use binius_math::BinarySubspace;
use binius_maybe_rayon::prelude::*;
use binius_utils::{rayon::get_log_max_threads, strided_array::StridedArray2DViewMut};

use super::{
	additive_ntt::{AdditiveNTT, NTTShape},
	error::Error,
	single_threaded::{SingleThreadedNTT, check_batch_transform_inputs_and_params},
	twiddle::TwiddleAccess,
};
use crate::twiddle::OnTheFlyTwiddleAccess;

/// Implementation of `AdditiveNTT` that performs the computation using multiple threads.
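///
/// A minimal usage sketch (illustrative only; it assumes a concrete field such as
/// `BinaryField32b`, the `SingleThreadedNTT::new` constructor, and caller-provided `data`,
/// `shape`, and coset arguments, and it elides error handling):
///
/// ```ignore
/// let ntt = SingleThreadedNTT::<BinaryField32b>::new(16)?.multithreaded();
/// ntt.forward_transform(&mut data, shape, coset, coset_bits, 0)?;
/// ```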
#[derive(Debug)]
pub struct MultithreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F, Vec<F>>>
{
	single_threaded: SingleThreadedNTT<F, TA>,
	log_max_threads: usize,
}

impl<F: BinaryField, TA: TwiddleAccess<F> + Sync> SingleThreadedNTT<F, TA> {
	/// Returns a multithreaded NTT implementation that uses the default number of threads.
	pub fn multithreaded(self) -> MultithreadedNTT<F, TA> {
		let log_max_threads = get_log_max_threads();
		self.multithreaded_with_max_threads(log_max_threads as _)
	}

	/// Returns a multithreaded NTT implementation that uses `1 << log_max_threads` threads.
	pub const fn multithreaded_with_max_threads(
		self,
		log_max_threads: usize,
	) -> MultithreadedNTT<F, TA> {
		MultithreadedNTT {
			single_threaded: self,
			log_max_threads,
		}
	}
}

impl<F, TA> AdditiveNTT<F> for MultithreadedNTT<F, TA>
where
	F: BinaryField,
	TA: TwiddleAccess<F> + Sync,
{
	fn log_domain_size(&self) -> usize {
		self.single_threaded.log_domain_size()
	}

	fn subspace(&self, i: usize) -> BinarySubspace<F> {
		self.single_threaded.subspace(i)
	}

	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
		self.single_threaded.get_subspace_eval(i, j)
	}

	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: usize,
		coset_bits: usize,
		skip_rounds: usize,
	) -> Result<(), Error> {
		forward_transform(
			&self.single_threaded,
			data,
			shape,
			coset,
			coset_bits,
			skip_rounds,
			self.log_max_threads,
		)
	}

	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: usize,
		coset_bits: usize,
		skip_rounds: usize,
	) -> Result<(), Error> {
		inverse_transform(
			&self.single_threaded,
			data,
			shape,
			coset,
			coset_bits,
			skip_rounds,
			self.log_max_threads,
		)
	}
}

#[allow(clippy::too_many_arguments)]
fn forward_transform<F, P, TA>(
	subntt: &SingleThreadedNTT<F, TA>,
	data: &mut [P],
	shape: NTTShape,
	coset: usize,
	coset_bits: usize,
	skip_rounds: usize,
	log_max_threads: usize,
) -> Result<(), Error>
where
	F: BinaryField,
	P: PackedField<Scalar = F>,
	TA: TwiddleAccess<F> + Sync,
{
	check_batch_transform_inputs_and_params(
		subntt.log_domain_size(),
		data,
		shape,
		coset,
		coset_bits,
		skip_rounds,
	)?;

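	// Handle trivial sizes up front: an empty slice is a no-op, a single scalar needs no
	// transform, and a single packed element with more than one scalar cannot be split across
	// threads, so it is delegated to the single-threaded implementation.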
	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::WIDTH {
				1 => Ok(()),
				_ => subntt.forward_transform(data, shape, coset, coset_bits, skip_rounds),
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Determine the optimal log_height and log_width, measured in packed field elements. log_width
	// must be at least 1, so that the row-wise NTTs have pairs of butterfly blocks to work with.
	// Furthermore, the width must hold at least `1 << log_x` scalars, so that each row-wise NTT
	// operates on full batches. Subject to the minimum width requirement, we set the height to the
	// lowest value that takes advantage of all available threads.
	let min_log_width_scalars = (log_w + 1).max(log_x);

	// The subtraction must not underflow because data is checked to have at least 2 packed
	// elements at the top of the function.
	let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
	let log_width = log_x + log_y + log_z - (log_w + log_height);
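	// A worked example with hypothetical parameters: for P::LOG_WIDTH = 2, log_x = 0, log_y = 12,
	// log_z = 0, and log_max_threads = 3, min_log_width_scalars = 3, so
	// log_height = min(12 - 3, 3) = 3 and log_width = 12 - (2 + 3) = 7: an 8-row matrix with
	// 128 packed elements (512 scalars) per row.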

	// The NTT algorithm shapes the tensor into a 2D matrix and runs two phases:
	//
	// 1. The first phase performs strided column-wise NTTs.
	// 2. The second phase performs the row-wise NTTs.

	// par_rounds is the number of phase 1 strided rounds.
	let par_rounds = log_height.saturating_sub(log_z);
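	// Continuing the example above (log_height = 3, log_z = 0), par_rounds = 3: the three
	// highest-order rounds of the transform are carried out column-wise across threads, and the
	// remaining rounds are finished row-wise in phase 2.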

	// Perform the column-wise NTTs in parallel over vertical strides of the matrix.
	{
		let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
			.expect("dimensions are correct");

		let log_strides = log_max_threads.min(log_width);
		let log_stride_len = log_width - log_strides;
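		// Split the matrix into `1 << log_strides` vertical strides of `1 << log_stride_len`
		// packed elements each, so separate threads can process separate strides.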

		let log_domain_size = subntt.log_domain_size();
		let s_evals = &subntt.twiddles()[log_domain_size - (par_rounds + coset_bits)..];

		matrix
			.into_par_strides(1 << log_stride_len)
			.for_each(|mut stride| {
				// i indexes the layer of the NTT network, also the binary subspace.
				for i in (0..par_rounds.saturating_sub(skip_rounds)).rev() {
					let s_evals_par_i = &s_evals[i];
					let coset_offset = coset << (par_rounds - 1 - i);

					// j indexes the outer Z tensor axis.
					for j in 0..1 << log_z {
						// k indexes the block within the layer. Each block performs butterfly
						// operations with the same twiddle factor.
						for k in 0..1 << (par_rounds - 1 - i) {
							let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
							// l indexes butterfly pairs within the block; m indexes columns of
							// this stride.
							for l in 0..1 << i {
								for m in 0..1 << log_stride_len {
									let idx0 = j << par_rounds | k << (i + 1) | l;
									let idx1 = idx0 | 1 << i;

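									// Forward butterfly with twiddle t:
									// u' = u + t*v, v' = v + u' = u + (t + 1)*v.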
									let mut u = stride[(idx0, m)];
									let mut v = stride[(idx1, m)];
									u += v * twiddle;
									v += u;
									stride[(idx0, m)] = u;
									stride[(idx1, m)] = v;
								}
							}
						}
					}
				}
			});
	}

	let log_row_z = log_z.saturating_sub(log_height);
	let single_thread_log_y = log_width + log_w - log_x - log_row_z;

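	// Phase 2: finish with row-wise NTTs. Rows are processed in groups of `1 << par_rounds`; the
	// row index within a group becomes an inner coset, since the column-wise rounds above have
	// already split the evaluation domain into that many cosets.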
	data.par_chunks_mut(1 << (log_width + par_rounds))
		.flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
		.try_for_each(|(inner_coset, chunk)| {
			subntt.forward_transform(
				chunk,
				NTTShape {
					log_x,
					log_y: single_thread_log_y,
					log_z: log_row_z,
				},
				coset << par_rounds | inner_coset,
				coset_bits + par_rounds,
				skip_rounds.saturating_sub(par_rounds),
			)
		})?;

	Ok(())
}

#[allow(clippy::too_many_arguments)]
fn inverse_transform<F, P, TA>(
	subntt: &SingleThreadedNTT<F, TA>,
	data: &mut [P],
	shape: NTTShape,
	coset: usize,
	coset_bits: usize,
	skip_rounds: usize,
	log_max_threads: usize,
) -> Result<(), Error>
where
	F: BinaryField,
	P: PackedField<Scalar = F>,
	TA: TwiddleAccess<F> + Sync,
{
	check_batch_transform_inputs_and_params(
		subntt.log_domain_size(),
		data,
		shape,
		coset,
		coset_bits,
		skip_rounds,
	)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::WIDTH {
				1 => Ok(()),
				_ => subntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds),
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Determine the optimal log_height and log_width, measured in packed field elements. log_width
	// must be at least 1, so that the row-wise NTTs have pairs of butterfly blocks to work with.
	// Furthermore, the width must hold at least `1 << log_x` scalars, so that each row-wise NTT
	// operates on full batches. Subject to the minimum width requirement, we set the height to the
	// lowest value that takes advantage of all available threads.
	let min_log_width_scalars = (log_w + 1).max(log_x);

	// The subtraction must not underflow because data is checked to have at least 2 packed
	// elements at the top of the function.
	let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
	let log_width = log_x + log_y + log_z - (log_w + log_height);

	// The iNTT algorithm shapes the tensor into a 2D matrix and runs two phases:
	//
	// 1. The first phase performs the row-wise iNTTs.
	// 2. The second phase performs strided column-wise iNTTs.

	// par_rounds is the number of phase 2 strided rounds.
	let par_rounds = log_height.saturating_sub(log_z);

	let log_row_z = log_z.saturating_sub(log_height);
	let single_thread_log_y = log_width + log_w - log_x - log_row_z;

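	// Phase 1: undo the row-wise NTTs first. This mirrors phase 2 of the forward transform: the
	// row index within each group of `1 << par_rounds` rows selects the inner coset.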
	data.par_chunks_mut(1 << (log_width + par_rounds))
		.flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
		.try_for_each(|(inner_coset, chunk)| {
			subntt.inverse_transform(
				chunk,
				NTTShape {
					log_x,
					log_y: single_thread_log_y,
					log_z: log_row_z,
				},
				coset << par_rounds | inner_coset,
				coset_bits + par_rounds,
				skip_rounds.saturating_sub(par_rounds),
			)
		})?;

	// Perform the column-wise NTTs in parallel over vertical strides of the matrix.
	let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
		.expect("dimensions are correct");

	let log_strides = log_max_threads.min(log_width);
	let log_stride_len = log_width - log_strides;

	let log_domain_size = subntt.log_domain_size();
	let s_evals = &subntt.twiddles()[log_domain_size - (par_rounds + coset_bits)..];

	matrix
		.into_par_strides(1 << log_stride_len)
		.for_each(|mut stride| {
			// i indexes the layer of the NTT network, also the binary subspace.
			#[allow(clippy::needless_range_loop)]
			for i in 0..par_rounds.saturating_sub(skip_rounds) {
				let s_evals_par_i = &s_evals[i];
				let coset_offset = coset << (par_rounds - 1 - i);

				// j indexes the outer Z tensor axis.
				for j in 0..1 << log_z {
					// k indexes the block within the layer. Each block performs butterfly
					// operations with the same twiddle factor.
					for k in 0..1 << (par_rounds - 1 - i) {
						let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
						// l indexes butterfly pairs within the block; m indexes columns of
						// this stride.
						for l in 0..1 << i {
							for m in 0..1 << log_stride_len {
								let idx0 = j << par_rounds | k << (i + 1) | l;
								let idx1 = idx0 | 1 << i;

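								// Inverse butterfly with twiddle t: v' = v + u, u' = u + t*v',
								// which undoes the forward butterfly.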
								let mut u = stride[(idx0, m)];
								let mut v = stride[(idx1, m)];
								v += u;
								u += v * twiddle;
								stride[(idx0, m)] = u;
								stride[(idx1, m)] = v;
							}
						}
					}
				}
			}
		});

	Ok(())
}