1use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_maybe_rayon::prelude::*;
6use binius_utils::{rayon::get_log_max_threads, strided_array::StridedArray2DViewMut};
7
8use super::{
9 additive_ntt::{AdditiveNTT, NTTShape},
10 error::Error,
11 single_threaded::{SingleThreadedNTT, check_batch_transform_inputs_and_params},
12 twiddle::TwiddleAccess,
13};
14use crate::twiddle::OnTheFlyTwiddleAccess;
15
/// Multithreaded NTT implementation.
///
/// Wraps a [`SingleThreadedNTT`] and splits transforms across up to
/// `2^log_max_threads` parallel rayon tasks; per-chunk work is delegated to the
/// wrapped single-threaded implementation.
#[derive(Debug)]
pub struct MultithreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F, Vec<F>>>
{
    // Underlying single-threaded NTT used for sequential sub-transforms.
    single_threaded: SingleThreadedNTT<F, TA>,
    // Base-2 logarithm of the maximum number of parallel tasks to use.
    log_max_threads: usize,
}
23
24impl<F: BinaryField, TA: TwiddleAccess<F> + Sync> SingleThreadedNTT<F, TA> {
25 pub fn multithreaded(self) -> MultithreadedNTT<F, TA> {
27 let log_max_threads = get_log_max_threads();
28 self.multithreaded_with_max_threads(log_max_threads as _)
29 }
30
31 pub const fn multithreaded_with_max_threads(
33 self,
34 log_max_threads: usize,
35 ) -> MultithreadedNTT<F, TA> {
36 MultithreadedNTT {
37 single_threaded: self,
38 log_max_threads,
39 }
40 }
41}
42
43impl<F, TA> AdditiveNTT<F> for MultithreadedNTT<F, TA>
44where
45 F: BinaryField,
46 TA: TwiddleAccess<F> + Sync,
47{
48 fn log_domain_size(&self) -> usize {
49 self.single_threaded.log_domain_size()
50 }
51
52 fn subspace(&self, i: usize) -> BinarySubspace<F> {
53 self.single_threaded.subspace(i)
54 }
55
56 fn get_subspace_eval(&self, i: usize, j: usize) -> F {
57 self.single_threaded.get_subspace_eval(i, j)
58 }
59
60 fn forward_transform<P: PackedField<Scalar = F>>(
61 &self,
62 data: &mut [P],
63 shape: NTTShape,
64 coset: usize,
65 coset_bits: usize,
66 skip_rounds: usize,
67 ) -> Result<(), Error> {
68 forward_transform(
69 &self.single_threaded,
70 data,
71 shape,
72 coset,
73 coset_bits,
74 skip_rounds,
75 self.log_max_threads,
76 )
77 }
78
79 fn inverse_transform<P: PackedField<Scalar = F>>(
80 &self,
81 data: &mut [P],
82 shape: NTTShape,
83 coset: usize,
84 coset_bits: usize,
85 skip_rounds: usize,
86 ) -> Result<(), Error> {
87 inverse_transform(
88 &self.single_threaded,
89 data,
90 shape,
91 coset,
92 coset_bits,
93 skip_rounds,
94 self.log_max_threads,
95 )
96 }
97}
98
/// Multithreaded driver for the forward additive NTT.
///
/// Strategy: view `data` as a matrix of `1 << log_height` rows by `1 << log_width` packed
/// elements. The highest `par_rounds` NTT rounds mix elements across rows, so they are
/// executed here over parallel column strides; all remaining rounds touch only elements
/// within one row-sized chunk and are delegated to `subntt`, one rayon task per chunk.
#[allow(clippy::too_many_arguments)]
fn forward_transform<F, P, TA>(
    subntt: &SingleThreadedNTT<F, TA>,
    data: &mut [P],
    shape: NTTShape,
    coset: usize,
    coset_bits: usize,
    skip_rounds: usize,
    log_max_threads: usize,
) -> Result<(), Error>
where
    F: BinaryField,
    P: PackedField<Scalar = F>,
    TA: TwiddleAccess<F> + Sync,
{
    check_batch_transform_inputs_and_params(
        subntt.log_domain_size(),
        data,
        shape,
        coset,
        coset_bits,
        skip_rounds,
    )?;

    match data.len() {
        // Empty input: nothing to do.
        0 => return Ok(()),
        1 => {
            return match P::WIDTH {
                // One packed element holding a single scalar: transform is trivial.
                1 => Ok(()),
                // Too small to split across threads; run single-threaded.
                _ => subntt.forward_transform(data, shape, coset, coset_bits, skip_rounds),
            };
        }
        _ => {}
    };

    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    let log_w = P::LOG_WIDTH;

    // Minimum log2 scalars per matrix row: at least two packed elements (log_w + 1) so a
    // row-local butterfly is possible, and at least one whole interleaved x-batch (log_x).
    let min_log_width_scalars = (log_w + 1).max(log_x);

    // log2 number of rows (degree of cross-row parallelism), capped by the thread budget.
    let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
    // log2 number of packed elements per row.
    let log_width = log_x + log_y + log_z - (log_w + log_height);

    // Rounds that cross row boundaries and must run in the strided parallel phase.
    // NOTE(review): subtracting log_z assumes the z (batch) index occupies the high bits
    // of the row index, so distinct batches never mix — confirm against NTTShape docs.
    let par_rounds = log_height.saturating_sub(log_z);

    {
        let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
            .expect("dimensions are correct");

        // Partition the columns into parallel strides, one rayon task per stride.
        let log_strides = log_max_threads.min(log_width);
        let log_stride_len = log_width - log_strides;

        let log_domain_size = subntt.log_domain_size();
        // Twiddle tables for the highest `par_rounds` rounds plus the coset selection bits.
        let s_evals = &subntt.twiddles()[log_domain_size - (par_rounds + coset_bits)..];

        matrix
            .into_par_strides(1 << log_stride_len)
            .for_each(|mut stride| {
                // Forward rounds run from high index to low; the `skip_rounds` highest
                // rounds are omitted entirely.
                for i in (0..par_rounds.saturating_sub(skip_rounds)).rev() {
                    let s_evals_par_i = &s_evals[i];
                    let coset_offset = coset << (par_rounds - 1 - i);

                    for j in 0..1 << log_z {
                        for k in 0..1 << (par_rounds - 1 - i) {
                            // One twiddle per butterfly block, broadcast across all
                            // packed lanes.
                            let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
                            for l in 0..1 << i {
                                for m in 0..1 << log_stride_len {
                                    // Butterfly pair: row indices differing only in bit i.
                                    let idx0 = j << par_rounds | k << (i + 1) | l;
                                    let idx1 = idx0 | 1 << i;

                                    // Forward butterfly: u' = u + v*t, v' = v + u'.
                                    // (In a binary field, + is XOR and self-inverse.)
                                    let mut u = stride[(idx0, m)];
                                    let mut v = stride[(idx1, m)];
                                    u += v * twiddle;
                                    v += u;
                                    stride[(idx0, m)] = u;
                                    stride[(idx1, m)] = v;
                                }
                            }
                        }
                    }
                }
            });
    }

    // Remaining rounds are row-chunk-local: hand each chunk to the single-threaded NTT
    // with a correspondingly reduced shape and the row index folded into the coset.
    let log_row_z = log_z.saturating_sub(log_height);
    let single_thread_log_y = log_width + log_w - log_x - log_row_z;

    data.par_chunks_mut(1 << (log_width + par_rounds))
        .flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
        .try_for_each(|(inner_coset, chunk)| {
            subntt.forward_transform(
                chunk,
                NTTShape {
                    log_x,
                    log_y: single_thread_log_y,
                    log_z: log_row_z,
                },
                // Rounds already processed above act as extra coset bits here.
                coset << par_rounds | inner_coset,
                coset_bits + par_rounds,
                skip_rounds.saturating_sub(par_rounds),
            )
        })?;

    Ok(())
}
228
/// Multithreaded driver for the inverse additive NTT.
///
/// Mirror image of the forward driver with the phases reversed: the low, row-chunk-local
/// rounds are delegated to `subntt` first (one rayon task per chunk), then the highest
/// `par_rounds` rounds — which mix elements across matrix rows — run over parallel column
/// strides with the inverse butterfly.
#[allow(clippy::too_many_arguments)]
fn inverse_transform<F, P, TA>(
    subntt: &SingleThreadedNTT<F, TA>,
    data: &mut [P],
    shape: NTTShape,
    coset: usize,
    coset_bits: usize,
    skip_rounds: usize,
    log_max_threads: usize,
) -> Result<(), Error>
where
    F: BinaryField,
    P: PackedField<Scalar = F>,
    TA: TwiddleAccess<F> + Sync,
{
    check_batch_transform_inputs_and_params(
        subntt.log_domain_size(),
        data,
        shape,
        coset,
        coset_bits,
        skip_rounds,
    )?;

    match data.len() {
        // Empty input: nothing to do.
        0 => return Ok(()),
        1 => {
            return match P::WIDTH {
                // One packed element holding a single scalar: transform is trivial.
                1 => Ok(()),
                // Too small to split across threads; run single-threaded.
                _ => subntt.inverse_transform(data, shape, coset, coset_bits, skip_rounds),
            };
        }
        _ => {}
    };

    let NTTShape {
        log_x,
        log_y,
        log_z,
    } = shape;

    let log_w = P::LOG_WIDTH;

    // Minimum log2 scalars per matrix row: at least two packed elements (log_w + 1) so a
    // row-local butterfly is possible, and at least one whole interleaved x-batch (log_x).
    let min_log_width_scalars = (log_w + 1).max(log_x);

    // log2 number of rows (degree of cross-row parallelism), capped by the thread budget.
    let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
    // log2 number of packed elements per row.
    let log_width = log_x + log_y + log_z - (log_w + log_height);

    // Rounds that cross row boundaries and must run in the strided parallel phase.
    // NOTE(review): subtracting log_z assumes the z (batch) index occupies the high bits
    // of the row index, so distinct batches never mix — confirm against NTTShape docs.
    let par_rounds = log_height.saturating_sub(log_z);

    // Phase 1: undo the low rounds inside each row chunk via the single-threaded NTT,
    // with the row index folded into the coset (mirrors the forward driver's phase 2).
    let log_row_z = log_z.saturating_sub(log_height);
    let single_thread_log_y = log_width + log_w - log_x - log_row_z;

    data.par_chunks_mut(1 << (log_width + par_rounds))
        .flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
        .try_for_each(|(inner_coset, chunk)| {
            subntt.inverse_transform(
                chunk,
                NTTShape {
                    log_x,
                    log_y: single_thread_log_y,
                    log_z: log_row_z,
                },
                coset << par_rounds | inner_coset,
                coset_bits + par_rounds,
                skip_rounds.saturating_sub(par_rounds),
            )
        })?;

    // Phase 2: undo the highest `par_rounds` rounds across rows, over parallel column
    // strides.
    let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
        .expect("dimensions are correct");

    // Partition the columns into parallel strides, one rayon task per stride.
    let log_strides = log_max_threads.min(log_width);
    let log_stride_len = log_width - log_strides;

    let log_domain_size = subntt.log_domain_size();
    // Twiddle tables for the highest `par_rounds` rounds plus the coset selection bits.
    let s_evals = &subntt.twiddles()[log_domain_size - (par_rounds + coset_bits)..];

    matrix
        .into_par_strides(1 << log_stride_len)
        .for_each(|mut stride| {
            // Inverse rounds run from low index to high (opposite of the forward pass);
            // the `skip_rounds` highest rounds are omitted entirely.
            #[allow(clippy::needless_range_loop)]
            for i in 0..par_rounds.saturating_sub(skip_rounds) {
                let s_evals_par_i = &s_evals[i];
                let coset_offset = coset << (par_rounds - 1 - i);

                for j in 0..1 << log_z {
                    for k in 0..1 << (par_rounds - 1 - i) {
                        // One twiddle per butterfly block, broadcast across all packed
                        // lanes.
                        let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
                        for l in 0..1 << i {
                            for m in 0..1 << log_stride_len {
                                // Butterfly pair: row indices differing only in bit i.
                                let idx0 = j << par_rounds | k << (i + 1) | l;
                                let idx1 = idx0 | 1 << i;

                                // Inverse butterfly: v = v' + u', then u = u' + v*t —
                                // exactly undoes the forward butterfly (binary-field
                                // addition is XOR and self-inverse).
                                let mut u = stride[(idx0, m)];
                                let mut v = stride[(idx1, m)];
                                v += u;
                                u += v * twiddle;
                                stride[(idx0, m)] = u;
                                stride[(idx1, m)] = v;
                            }
                        }
                    }
                }
            }
        });

    Ok(())
}