1use binius_field::{BinaryField, PackedField};
4use binius_math::BinarySubspace;
5use binius_maybe_rayon::prelude::*;
6use binius_utils::rayon::get_log_max_threads;
7
8use super::{
9 error::Error,
10 single_threaded::{self, check_batch_transform_inputs_and_params},
11 strided_array::StridedArray2DViewMut,
12 twiddle::TwiddleAccess,
13 AdditiveNTT, SingleThreadedNTT,
14};
15use crate::twiddle::OnTheFlyTwiddleAccess;
16
17#[derive(Debug)]
19pub struct MultithreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F, Vec<F>>>
20{
21 single_threaded: SingleThreadedNTT<F, TA>,
22 log_max_threads: usize,
23}
24
25impl<F: BinaryField, TA: TwiddleAccess<F> + Sync> SingleThreadedNTT<F, TA> {
26 pub fn multithreaded(self) -> MultithreadedNTT<F, TA> {
28 let log_max_threads = get_log_max_threads();
29 self.multithreaded_with_max_threads(log_max_threads as _)
30 }
31
32 pub const fn multithreaded_with_max_threads(
34 self,
35 log_max_threads: usize,
36 ) -> MultithreadedNTT<F, TA> {
37 MultithreadedNTT {
38 single_threaded: self,
39 log_max_threads,
40 }
41 }
42}
43
44impl<F, TA> AdditiveNTT<F> for MultithreadedNTT<F, TA>
45where
46 F: BinaryField,
47 TA: TwiddleAccess<F> + Sync,
48{
49 fn log_domain_size(&self) -> usize {
50 self.single_threaded.log_domain_size()
51 }
52
53 fn subspace(&self, i: usize) -> BinarySubspace<F> {
54 self.single_threaded.subspace(i)
55 }
56
57 fn get_subspace_eval(&self, i: usize, j: usize) -> F {
58 self.single_threaded.get_subspace_eval(i, j)
59 }
60
61 fn forward_transform<P: PackedField<Scalar = F>>(
62 &self,
63 data: &mut [P],
64 coset: u32,
65 log_batch_size: usize,
66 log_n: usize,
67 ) -> Result<(), Error> {
68 forward_transform(
69 self.log_domain_size(),
70 self.single_threaded.twiddles(),
71 data,
72 coset,
73 log_batch_size,
74 log_n,
75 self.log_max_threads,
76 )
77 }
78
79 fn inverse_transform<P: PackedField<Scalar = F>>(
80 &self,
81 data: &mut [P],
82 coset: u32,
83 log_batch_size: usize,
84 log_n: usize,
85 ) -> Result<(), Error> {
86 inverse_transform(
87 self.log_domain_size(),
88 self.single_threaded.twiddles(),
89 data,
90 coset,
91 log_batch_size,
92 log_n,
93 self.log_max_threads,
94 )
95 }
96}
97
98fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
99 log_domain_size: usize,
100 s_evals: &[impl TwiddleAccess<F> + Sync],
101 data: &mut [P],
102 coset: u32,
103 log_batch_size: usize,
104 log_n: usize,
105 log_max_threads: usize,
106) -> Result<(), Error> {
107 match data.len() {
108 0 => return Ok(()),
109 1 => {
110 return match P::WIDTH {
111 1 => Ok(()),
112 _ => single_threaded::forward_transform(
113 log_domain_size,
114 s_evals,
115 data,
116 coset,
117 log_batch_size,
118 log_n,
119 ),
120 };
121 }
122 _ => {}
123 };
124
125 let log_w = P::LOG_WIDTH;
126
127 let log_b = log_batch_size;
128
129 check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_b, log_n)?;
130
131 let cutoff = log_w.saturating_sub(log_b);
134
135 let par_rounds = (log_n - cutoff).min(log_max_threads);
136 let log_height = par_rounds;
137 let log_width = log_n + log_b - log_w - log_height;
138
139 {
141 let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
142 .expect("dimensions are correct");
143
144 let log_strides = log_max_threads.min(log_width);
145 let log_stride_len = log_width - log_strides;
146
147 matrix
148 .into_par_strides(1 << log_stride_len)
149 .for_each(|mut stride| {
150 for i in (0..par_rounds).rev() {
151 let coset_twiddle = s_evals[log_n - par_rounds + i]
152 .coset(log_domain_size - log_n, coset as usize);
153
154 for j in 0..1 << (par_rounds - 1 - i) {
155 let twiddle = P::broadcast(coset_twiddle.get(j));
156 for k in 0..1 << i {
157 for l in 0..1 << log_stride_len {
158 let idx0 = j << (i + 1) | k;
159 let idx1 = idx0 | 1 << i;
160
161 let mut u = stride[(idx0, l)];
162 let mut v = stride[(idx1, l)];
163 u += v * twiddle;
164 v += u;
165 stride[(idx0, l)] = u;
166 stride[(idx1, l)] = v;
167 }
168 }
169 }
170 }
171 });
172 }
173
174 let single_thread_log_n = log_width + P::LOG_WIDTH - log_batch_size;
175
176 data.par_chunks_mut(1 << log_width)
177 .enumerate()
178 .try_for_each(|(inner_coset, chunk)| {
179 single_threaded::forward_transform(
180 log_domain_size,
181 &s_evals[0..log_n - par_rounds],
182 chunk,
183 coset << par_rounds | (inner_coset as u32),
184 log_batch_size,
185 single_thread_log_n,
186 )
187 })?;
188
189 Ok(())
190}
191
192fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
193 log_domain_size: usize,
194 s_evals: &[impl TwiddleAccess<F> + Sync],
195 data: &mut [P],
196 coset: u32,
197 log_batch_size: usize,
198 log_n: usize,
199 log_max_threads: usize,
200) -> Result<(), Error> {
201 match data.len() {
202 0 => return Ok(()),
203 1 => {
204 return match P::WIDTH {
205 1 => Ok(()),
206 _ => single_threaded::inverse_transform(
207 log_domain_size,
208 s_evals,
209 data,
210 coset,
211 log_batch_size,
212 log_n,
213 ),
214 };
215 }
216 _ => {}
217 };
218
219 let log_w = P::LOG_WIDTH;
220
221 let log_b = log_batch_size;
222
223 check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_b, log_n)?;
224
225 let cutoff = log_w.saturating_sub(log_b);
228
229 let par_rounds = (log_n - cutoff).min(log_max_threads);
230 let log_height = par_rounds;
231 let log_width = log_n + log_b - log_w - log_height;
232 let single_thread_log_n = log_width + P::LOG_WIDTH - log_batch_size;
233
234 data.par_chunks_mut(1 << log_width)
235 .enumerate()
236 .try_for_each(|(inner_coset, chunk)| {
237 single_threaded::inverse_transform(
238 log_domain_size,
239 &s_evals[0..log_n - par_rounds],
240 chunk,
241 coset << par_rounds | (inner_coset as u32),
242 log_batch_size,
243 single_thread_log_n,
244 )
245 })?;
246
247 let matrix = StridedArray2DViewMut::without_stride(data, 1 << log_height, 1 << log_width)
249 .expect("dimensions are correct");
250
251 let log_strides = log_max_threads.min(log_width);
252 let log_stride_len = log_width - log_strides;
253
254 matrix
255 .into_par_strides(1 << log_stride_len)
256 .for_each(|mut stride| {
257 for i in 0..par_rounds {
258 let coset_twiddle =
259 s_evals[log_n - par_rounds + i].coset(log_domain_size - log_n, coset as usize);
260
261 for j in 0..1 << (par_rounds - 1 - i) {
262 let twiddle = P::broadcast(coset_twiddle.get(j));
263 for k in 0..1 << i {
264 for l in 0..1 << log_stride_len {
265 let idx0 = j << (i + 1) | k;
266 let idx1 = idx0 | 1 << i;
267
268 let mut u = stride[(idx0, l)];
269 let mut v = stride[(idx1, l)];
270 v += u;
271 u += v * twiddle;
272 stride[(idx0, l)] = u;
273 stride[(idx1, l)] = v;
274 }
275 }
276 }
277 }
278 });
279
280 Ok(())
281}