1use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_maybe_rayon::prelude::*;
6use binius_utils::rayon::get_log_max_threads;
7
8use super::{
9 additive_ntt::{AdditiveNTT, NTTShape},
10 error::Error,
11 single_threaded::{self, check_batch_transform_inputs_and_params, SingleThreadedNTT},
12 strided_array::StridedArray2DViewMut,
13 twiddle::TwiddleAccess,
14};
15use crate::twiddle::OnTheFlyTwiddleAccess;
16
/// An [`AdditiveNTT`] implementation that splits transforms across multiple threads.
///
/// Wraps a [`SingleThreadedNTT`] and parallelizes its forward/inverse transforms
/// over up to `2^log_max_threads` rayon tasks.
#[derive(Debug)]
pub struct MultithreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F, Vec<F>>>
{
	// The underlying single-threaded NTT; also used directly for inputs too
	// small to be worth splitting.
	single_threaded: SingleThreadedNTT<F, TA>,
	// Base-2 logarithm of the maximum number of parallel tasks to spawn.
	log_max_threads: usize,
}
24
25impl<F: BinaryField, TA: TwiddleAccess<F> + Sync> SingleThreadedNTT<F, TA> {
26 pub fn multithreaded(self) -> MultithreadedNTT<F, TA> {
28 let log_max_threads = get_log_max_threads();
29 self.multithreaded_with_max_threads(log_max_threads as _)
30 }
31
32 pub const fn multithreaded_with_max_threads(
34 self,
35 log_max_threads: usize,
36 ) -> MultithreadedNTT<F, TA> {
37 MultithreadedNTT {
38 single_threaded: self,
39 log_max_threads,
40 }
41 }
42}
43
impl<F, TA> AdditiveNTT<F> for MultithreadedNTT<F, TA>
where
	F: BinaryField,
	TA: TwiddleAccess<F> + Sync,
{
	// Read-only domain queries delegate directly to the wrapped
	// single-threaded NTT; only the transforms themselves are parallelized.
	fn log_domain_size(&self) -> usize {
		self.single_threaded.log_domain_size()
	}

	fn subspace(&self, i: usize) -> BinarySubspace<F> {
		self.single_threaded.subspace(i)
	}

	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
		self.single_threaded.get_subspace_eval(i, j)
	}

	/// Forward transform, dispatched to the module-level parallel
	/// implementation using this NTT's precomputed twiddles and thread budget.
	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: u32,
		skip_rounds: usize,
	) -> Result<(), Error> {
		forward_transform(
			self.log_domain_size(),
			self.single_threaded.twiddles(),
			data,
			shape,
			coset,
			skip_rounds,
			self.log_max_threads,
		)
	}

	/// Inverse transform, dispatched to the module-level parallel
	/// implementation using this NTT's precomputed twiddles and thread budget.
	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: u32,
		skip_rounds: usize,
	) -> Result<(), Error> {
		inverse_transform(
			self.log_domain_size(),
			self.single_threaded.twiddles(),
			data,
			shape,
			coset,
			skip_rounds,
			self.log_max_threads,
		)
	}
}
97
#[allow(clippy::too_many_arguments)]
/// Forward additive NTT over a batch of packed elements, parallelized over up
/// to `2^log_max_threads` rayon tasks.
///
/// The work is split into two phases:
/// 1. `data` is viewed as a `2^log_height x 2^log_width` matrix and the top
///    `par_rounds` butterfly rounds are applied to parallel column strides.
/// 2. Each `2^log_width`-element chunk is then an independent sub-transform
///    and is finished with the single-threaded implementation.
fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F> + Sync],
	data: &mut [P],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
	log_max_threads: usize,
) -> Result<(), Error> {
	// Validate data length against shape, coset, and skip_rounds up front.
	check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

	match data.len() {
		// Empty input: nothing to transform.
		0 => return Ok(()),
		1 => {
			// A single packed element cannot be split across threads.
			return match P::WIDTH {
				// One scalar total: the transform is a no-op.
				1 => Ok(()),
				_ => single_threaded::forward_transform(
					log_domain_size,
					s_evals,
					data,
					shape,
					coset,
					skip_rounds,
				),
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Minimum log2 number of scalars per matrix row: at least two packed
	// elements, and at least one full x-stride (so a row never splits the
	// innermost batch dimension).
	let min_log_width_scalars = (log_w + 1).max(log_x);

	// log2 number of matrix rows — the parallelism grain, capped by the
	// thread budget.
	let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
	// log2 number of packed elements per matrix row.
	let log_width = log_x + log_y + log_z - (log_w + log_height);

	// Number of butterfly rounds executed in the parallel phase. Splitting
	// along the outer z (batch) dimension needs no cross-row butterflies,
	// hence the saturating subtraction.
	let par_rounds = log_height.saturating_sub(log_z);

	{
		let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
			.expect("dimensions are correct");

		// Each parallel task owns a contiguous group of columns across all rows.
		let log_strides = log_max_threads.min(log_width);
		let log_stride_len = log_width - log_strides;

		matrix
			.into_par_strides(1 << log_stride_len)
			.for_each(|mut stride| {
				// Forward rounds run from high to low; the highest
				// `skip_rounds` rounds are omitted.
				for i in (0..par_rounds.saturating_sub(skip_rounds)).rev() {
					// Twiddle table for this round of the y-dimension transform.
					let s_evals_par_i = &s_evals[log_y - par_rounds + i];
					// The coset selects a sub-range of twiddle indices.
					let coset_offset = (coset as usize) << (par_rounds - 1 - i);

					for j in 0..1 << log_z {
						for k in 0..1 << (par_rounds - 1 - i) {
							// One twiddle per butterfly block, broadcast to all packed lanes.
							let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
							for l in 0..1 << i {
								for m in 0..1 << log_stride_len {
									// Rows paired by bit i within batch j's 2^par_rounds row group.
									let idx0 = j << par_rounds | k << (i + 1) | l;
									let idx1 = idx0 | 1 << i;

									// Forward butterfly: u' = u + v*t; v' = v + u'.
									let mut u = stride[(idx0, m)];
									let mut v = stride[(idx1, m)];
									u += v * twiddle;
									v += u;
									stride[(idx0, m)] = u;
									stride[(idx1, m)] = v;
								}
							}
						}
					}
				}
			});
	}

	// Remaining rounds: each 2^log_width-element chunk is an independent
	// sub-transform. Chunks within a 2^par_rounds group get distinct inner
	// cosets appended below the outer coset bits.
	let log_row_z = log_z.saturating_sub(log_height);
	let single_thread_log_y = log_width + log_w - log_x - log_row_z;

	data.par_chunks_mut(1 << (log_width + par_rounds))
		.flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
		.try_for_each(|(inner_coset, chunk)| {
			single_threaded::forward_transform(
				log_domain_size,
				// Only the rounds not already done in the parallel phase.
				&s_evals[0..log_y - par_rounds],
				chunk,
				NTTShape {
					log_x,
					log_y: single_thread_log_y,
					log_z: log_row_z,
				},
				coset << par_rounds | inner_coset as u32,
				// Skips already consumed by the parallel phase don't repeat here.
				skip_rounds.saturating_sub(par_rounds),
			)
		})?;

	Ok(())
}
220
#[allow(clippy::too_many_arguments)]
/// Inverse additive NTT over a batch of packed elements, parallelized over up
/// to `2^log_max_threads` rayon tasks.
///
/// Mirrors `forward_transform` with the two phases in reverse order: each
/// `2^log_width`-element chunk is first inverse-transformed single-threaded,
/// then the top `par_rounds` inverse butterfly rounds are applied to parallel
/// column strides of the matrix view.
fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F> + Sync],
	data: &mut [P],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
	log_max_threads: usize,
) -> Result<(), Error> {
	// Validate data length against shape, coset, and skip_rounds up front.
	check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

	match data.len() {
		// Empty input: nothing to transform.
		0 => return Ok(()),
		1 => {
			// A single packed element cannot be split across threads.
			return match P::WIDTH {
				// One scalar total: the transform is a no-op.
				1 => Ok(()),
				_ => single_threaded::inverse_transform(
					log_domain_size,
					s_evals,
					data,
					shape,
					coset,
					skip_rounds,
				),
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Minimum log2 number of scalars per matrix row: at least two packed
	// elements, and at least one full x-stride (same derivation as the
	// forward transform).
	let min_log_width_scalars = (log_w + 1).max(log_x);

	// log2 number of matrix rows — the parallelism grain, capped by the
	// thread budget.
	let log_height = (log_x + log_y + log_z - min_log_width_scalars).min(log_max_threads);
	// log2 number of packed elements per matrix row.
	let log_width = log_x + log_y + log_z - (log_w + log_height);

	// Number of butterfly rounds executed in the parallel phase; splitting
	// along the outer z (batch) dimension needs no cross-row butterflies.
	let par_rounds = log_height.saturating_sub(log_z);

	// Phase 1 (inverse order): finish each chunk's low rounds single-threaded.
	let log_row_z = log_z.saturating_sub(log_height);
	let single_thread_log_y = log_width + log_w - log_x - log_row_z;

	data.par_chunks_mut(1 << (log_width + par_rounds))
		.flat_map(|large_chunk| large_chunk.par_chunks_mut(1 << log_width).enumerate())
		.try_for_each(|(inner_coset, chunk)| {
			single_threaded::inverse_transform(
				log_domain_size,
				// Only the rounds not handled by the parallel phase below.
				&s_evals[0..log_y - par_rounds],
				chunk,
				NTTShape {
					log_x,
					log_y: single_thread_log_y,
					log_z: log_row_z,
				},
				coset << par_rounds | inner_coset as u32,
				skip_rounds.saturating_sub(par_rounds),
			)
		})?;

	// Phase 2: the parallel rounds, in ascending order (reverse of forward).
	let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
		.expect("dimensions are correct");

	// Each parallel task owns a contiguous group of columns across all rows.
	let log_strides = log_max_threads.min(log_width);
	let log_stride_len = log_width - log_strides;

	matrix
		.into_par_strides(1 << log_stride_len)
		.for_each(|mut stride| {
			// Inverse rounds run low to high; the highest `skip_rounds`
			// rounds are omitted.
			for i in 0..par_rounds.saturating_sub(skip_rounds) {
				// Twiddle table for this round of the y-dimension transform.
				let s_evals_par_i = &s_evals[log_y - par_rounds + i];
				// The coset selects a sub-range of twiddle indices.
				let coset_offset = (coset as usize) << (par_rounds - 1 - i);

				for j in 0..1 << log_z {
					for k in 0..1 << (par_rounds - 1 - i) {
						// One twiddle per butterfly block, broadcast to all packed lanes.
						let twiddle = P::broadcast(s_evals_par_i.get(coset_offset | k));
						for l in 0..1 << i {
							for m in 0..1 << log_stride_len {
								// Rows paired by bit i within batch j's 2^par_rounds row group.
								let idx0 = j << par_rounds | k << (i + 1) | l;
								let idx1 = idx0 | 1 << i;

								// Inverse butterfly: v' = v + u; u' = u + v'*t —
								// undoes the forward butterfly exactly.
								let mut u = stride[(idx0, m)];
								let mut v = stride[(idx1, m)];
								v += u;
								u += v * twiddle;
								stride[(idx0, m)] = u;
								stride[(idx1, m)] = v;
							}
						}
					}
				}
			}
		});

	Ok(())
}