use std::{
    cmp::{max, min},
    iter,
    ops::Range,
    slice::from_raw_parts_mut,
};

use binius_field::{BinaryField, PackedField};
use binius_utils::rayon::{
    iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
    slice::ParallelSliceMut,
};

use super::{
    AdditiveNTT, DomainContext,
    reference::{NeighborsLastReference, input_check},
};
use crate::field_buffer::FieldSliceMut;

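/// Default `log_base_len` used by [`NeighborsLastSingleThread::new`] and
/// [`NeighborsLastMultiThread::new`]: the depth-first recursion switches to the breadth-first
/// kernel once a sub-transform covers at most `2^DEFAULT_LOG_BASE_LEN` field elements.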
const DEFAULT_LOG_BASE_LEN: usize = 10;

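/// Recursive (depth-first) forward transform of a single block of the NTT.
///
/// `data` holds the block as `2^(log_d - P::LOG_WIDTH)` packed elements, i.e. `2^log_d` scalars.
/// The block is block number `block` of the (absolute) layer `layer`; only the butterflies of the
/// layers in `layer_range` are applied. Once a sub-block spans at most `2^log_base_len` scalars,
/// the recursion bottoms out in [`forward_breadth_first`].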
fn forward_depth_first<P: PackedField>(
    domain_context: &impl DomainContext<Field = P::Scalar>,
    data: &mut [P],
    log_d: usize,
    layer: usize,
    block: usize,
    mut layer_range: Range<usize>,
    log_base_len: usize,
) {
    debug_assert!(P::LOG_WIDTH < log_d);
    debug_assert_eq!(data.len(), 1 << (log_d - P::LOG_WIDTH));
    debug_assert!(layer_range.end <= domain_context.log_domain_size());
    debug_assert!(layer <= layer_range.start);
    debug_assert!(log_base_len > P::LOG_WIDTH);

    let log_n = log_d + layer;
    debug_assert!(layer_range.end <= log_n);

    if layer >= layer_range.end {
        return;
    }

    if log_d <= log_base_len {
        forward_breadth_first(domain_context, data, log_d, layer, block, layer_range);
        return;
    }

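    // Apply this block's butterfly for the current layer (pairing the two halves of `data`),
    // unless this layer is skipped by `layer_range`.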
    let block_size_half = 1 << (log_d - 1 - P::LOG_WIDTH);
    if layer >= layer_range.start {
        let twiddle = domain_context.twiddle(layer, block);
        let packed_twiddle = P::broadcast(twiddle);
        let (block0, block1) = data.split_at_mut(block_size_half);
        for (u, v) in iter::zip(block0, block1) {
            *u += *v * packed_twiddle;
            *v += *u;
        }

        layer_range.start += 1;
    }

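    // Recurse into the two halves, which are blocks `2 * block` and `2 * block + 1` of the next
    // layer.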
    forward_depth_first(
        domain_context,
        &mut data[..block_size_half],
        log_d - 1,
        layer + 1,
        block << 1,
        layer_range.clone(),
        log_base_len,
    );
    forward_depth_first(
        domain_context,
        &mut data[block_size_half..],
        log_d - 1,
        layer + 1,
        (block << 1) + 1,
        layer_range,
        log_base_len,
    );
}

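/// Iterative (breadth-first) forward transform of a single block, layer by layer.
///
/// The parameters have the same meaning as in [`forward_depth_first`], with `base_layer` and
/// `base_block` giving the position of `data` within the full transform. This routine is the base
/// case of the depth-first recursion and also backs [`NeighborsLastBreadthFirst`].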
fn forward_breadth_first<P: PackedField>(
    domain_context: &impl DomainContext<Field = P::Scalar>,
    data: &mut [P],
    log_d: usize,
    base_layer: usize,
    base_block: usize,
    layer_range: Range<usize>,
) {
    debug_assert!(P::LOG_WIDTH < log_d);
    debug_assert_eq!(data.len(), 1 << (log_d - P::LOG_WIDTH));
    debug_assert!(layer_range.end <= domain_context.log_domain_size());
    debug_assert!(base_layer <= layer_range.start);

    let log_n = log_d + base_layer;
    debug_assert!(layer_range.end <= log_n);

    let packed_cutoff = (log_n - P::LOG_WIDTH).clamp(layer_range.start, layer_range.end);

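    // Layers whose butterfly half-blocks cover at least one whole packed element: broadcast one
    // twiddle per block and combine whole packed elements.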
    for layer in layer_range.start..packed_cutoff {
        let log_block_size = log_n - P::LOG_WIDTH - layer;
        let log_half_block_size = log_block_size - 1;

        let log_blocks = layer - base_layer;
        let layer_twiddles = domain_context
            .iter_twiddles(layer, 0)
            .skip(base_block << log_blocks)
            .take(1 << log_blocks);
        let blocks = data.chunks_exact_mut(1 << log_block_size);
        for (block, twiddle) in iter::zip(blocks, layer_twiddles) {
            let packed_twiddle = P::broadcast(twiddle);
            let (block0, block1) = block.split_at_mut(1 << log_half_block_size);
            for (u, v) in iter::zip(block0, block1) {
                *u += *v * packed_twiddle;
                *v += *u;
            }
        }
    }

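    // Layers whose butterfly blocks fit inside a single packed element: process packed elements
    // in pairs, with `packed_twiddle_offset` holding, per lane, the twiddle offset of the block
    // that lane belongs to.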
    for layer in packed_cutoff..layer_range.end {
        let log_block_size = log_n - layer;
        let log_half_block_size = log_block_size - 1;
        let log_blocks_per_packed = P::LOG_WIDTH - log_block_size;
        let log_half_blocks_per_packed = log_blocks_per_packed + 1;

        let mut packed_twiddle_offset = P::zero();
        for block in 0..1 << log_blocks_per_packed {
            let twiddle0 = domain_context.twiddle(layer, block);
            let twiddle1 = domain_context.twiddle(layer, (1 << log_blocks_per_packed) | block);

            let block_start = block << log_block_size;
            for j in 0..1 << log_half_block_size {
                packed_twiddle_offset.set(block_start | j, twiddle0);
                packed_twiddle_offset.set(block_start | j | (1 << log_half_block_size), twiddle1);
            }
        }

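        // One base twiddle per pair of packed elements; the per-lane twiddle used below is this
        // base plus `packed_twiddle_offset`.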
        let log_packed_pairs = packed_cutoff - base_layer - 1;
        let layer_twiddles = domain_context
            .iter_twiddles(layer, log_half_blocks_per_packed)
            .skip(base_block << log_packed_pairs)
            .take(1 << log_packed_pairs);

        let (data_pairs, rest) = data.as_chunks_mut::<2>();
        debug_assert!(
            rest.is_empty(),
            "data length is a power of two; \
            data length is greater than 1 (checked at the beginning of the function)"
        );
        debug_assert_eq!(data_pairs.len(), 1 << log_packed_pairs);

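        // `interleave` gathers every block's `u` half into one packed element and every `v` half
        // into the other, the butterfly is applied lane-wise, and a second `interleave` restores
        // the original layout.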
        for ([packed0, packed1], first_twiddle) in iter::zip(data_pairs, layer_twiddles) {
            let packed_twiddle = P::broadcast(first_twiddle) + packed_twiddle_offset;

            let (mut u, mut v) = (*packed0).interleave(*packed1, log_half_block_size);
            u += v * packed_twiddle;
            v += u;
            (*packed0, *packed1) = u.interleave(v, log_half_block_size);
        }
    }
}

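/// Computes a single early NTT layer with the work shared across `2^log_num_shares` rayon tasks.
///
/// `data` is split into `2^(log_num_shares + 1)` equally sized chunks; each task takes the pair
/// of chunks whose elements are paired by this layer's butterflies and applies the butterfly
/// element-wise to that pair.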
fn forward_shared_layer<P: PackedField>(
    domain_context: &(impl DomainContext<Field = P::Scalar> + Sync),
    data: &mut [P],
    log_d: usize,
    layer: usize,
    log_num_shares: usize,
) {
    debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
    debug_assert!(1 << (log_num_shares + 1) <= data.len());
    debug_assert!(layer < domain_context.log_domain_size());

    let log_num_chunks = log_num_shares + 1;
    let log_d_chunk = log_d - log_num_chunks;
    let data_ptr = data.as_mut_ptr();
    let shift = log_num_shares - layer;
    let tasks: Vec<_> = (0..1 << log_num_shares)
        .map(|k| {
            let (chunk0, chunk1) = with_middle_bit(k, shift);
            let block = chunk0 >> (log_num_chunks - layer);
            assert!(P::LOG_WIDTH <= log_d_chunk);
            let log_chunk_len = log_d_chunk - P::LOG_WIDTH;
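            // SAFETY: `with_middle_bit` maps the `2^log_num_shares` values of `k` to disjoint
            // pairs of chunk indices covering `0..2^log_num_chunks`, and `data` consists of
            // exactly `2^log_num_chunks` chunks of `2^log_chunk_len` packed elements, so the two
            // mutable slices below are in bounds and never alias each other or the slices of
            // another task.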
            let chunk0 = unsafe {
                from_raw_parts_mut(data_ptr.add(chunk0 << log_chunk_len), 1 << log_chunk_len)
            };
            let chunk1 = unsafe {
                from_raw_parts_mut(data_ptr.add(chunk1 << log_chunk_len), 1 << log_chunk_len)
            };
            let twiddle = domain_context.twiddle(layer, block);
            let twiddle = P::broadcast(twiddle);
            (chunk0, chunk1, twiddle)
        })
        .collect();

    tasks.into_par_iter().for_each(|(chunk0, chunk1, twiddle)| {
        for i in 0..chunk0.len() {
            let mut u = chunk0[i];
            let mut v = chunk1[i];
            u += v * twiddle;
            v += u;
            chunk0[i] = u;
            chunk1[i] = v;
        }
    });
}

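/// Inserts an extra bit at position `shift` of `k` and returns the two results: one with the new
/// bit cleared and one with it set. Bits of `k` below `shift` stay in place; bits at `shift` and
/// above are shifted up by one.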
fn with_middle_bit(k: usize, shift: usize) -> (usize, usize) {
    assert!(shift >= 1);

    let ms = k >> (shift - 1);
    let ls = k & ((1 << shift) - 1);

    let k0 = ls | ((ms & !1) << shift);
    let k1 = ls | ((ms | 1) << shift);

    (k0, k1)
}

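/// [`AdditiveNTT`] in neighbors-last order that processes the layers breadth-first (one full pass
/// over the data per layer) on a single thread.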
#[derive(Debug)]
pub struct NeighborsLastBreadthFirst<DC> {
    pub domain_context: DC,
}

impl<F, DC> AdditiveNTT for NeighborsLastBreadthFirst<DC>
where
    F: BinaryField,
    DC: DomainContext<Field = F>,
{
    type Field = F;

    fn forward_transform<P: PackedField<Scalar = F>>(
        &self,
        mut data: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = data.log_len();
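        // Data that fits into a single packed element is handled by the reference implementation.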
        if log_d <= P::LOG_WIDTH {
            let fallback_ntt = NeighborsLastReference {
                domain_context: &self.domain_context,
            };
            return fallback_ntt.forward_transform(data, skip_early, skip_late);
        }

        input_check(&self.domain_context, log_d, skip_early, skip_late);

        forward_breadth_first(
            self.domain_context(),
            data.as_mut(),
            log_d,
            0,
            0,
            skip_early..(log_d - skip_late),
        );
    }

    fn inverse_transform<P: PackedField<Scalar = F>>(
        &self,
        _data: FieldSliceMut<P>,
        _skip_early: usize,
        _skip_late: usize,
    ) {
        todo!()
    }

    fn domain_context(&self) -> &impl DomainContext<Field = F> {
        &self.domain_context
    }
}

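/// Single-threaded [`AdditiveNTT`] in neighbors-last order that recurses depth-first into the two
/// halves of the data and switches to the breadth-first kernel once a sub-block has at most
/// `2^log_base_len` elements.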
#[derive(Debug)]
pub struct NeighborsLastSingleThread<DC> {
    pub domain_context: DC,
    pub log_base_len: usize,
}

impl<DC> NeighborsLastSingleThread<DC> {
    pub fn new(domain_context: DC) -> Self {
        Self {
            domain_context,
            log_base_len: DEFAULT_LOG_BASE_LEN,
        }
    }
}

impl<DC: DomainContext> AdditiveNTT for NeighborsLastSingleThread<DC> {
    type Field = DC::Field;

    fn forward_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut data: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = data.log_len();
        if log_d <= P::LOG_WIDTH {
            let fallback_ntt = NeighborsLastReference {
                domain_context: &self.domain_context,
            };
            return fallback_ntt.forward_transform(data, skip_early, skip_late);
        }

        input_check(&self.domain_context, log_d, skip_early, skip_late);

        forward_depth_first(
            &self.domain_context,
            data.as_mut(),
            log_d,
            0,
            0,
            skip_early..(log_d - skip_late),
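            // Clamp so that `forward_depth_first`'s `log_base_len > P::LOG_WIDTH` precondition
            // always holds.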
            self.log_base_len.max(P::LOG_WIDTH + 1),
        );
    }

    fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        _data_orig: FieldSliceMut<P>,
        _skip_early: usize,
        _skip_late: usize,
    ) {
        unimplemented!()
    }

    fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
        &self.domain_context
    }
}

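/// Multithreaded [`AdditiveNTT`] in neighbors-last order.
///
/// The early layers, whose butterflies span more than one chunk of the data, are computed by
/// `forward_shared_layer` with roughly `2^log_num_shares` parallel tasks; the remaining layers
/// are computed independently on disjoint chunks, each chunk running the same depth-first routine
/// as [`NeighborsLastSingleThread`].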
#[derive(Debug)]
pub struct NeighborsLastMultiThread<DC> {
    pub domain_context: DC,
    pub log_base_len: usize,
    pub log_num_shares: usize,
}

impl<DC> NeighborsLastMultiThread<DC> {
    pub fn new(domain_context: DC, log_num_shares: usize) -> Self {
        Self {
            domain_context,
            log_base_len: DEFAULT_LOG_BASE_LEN,
            log_num_shares,
        }
    }
}

impl<DC: DomainContext + Sync> AdditiveNTT for NeighborsLastMultiThread<DC> {
    type Field = DC::Field;

    fn forward_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut data: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = data.log_len();
        if log_d <= P::LOG_WIDTH {
            let fallback_ntt = NeighborsLastReference {
                domain_context: &self.domain_context,
            };
            return fallback_ntt.forward_transform(data, skip_early, skip_late);
        }

        input_check(&self.domain_context, log_d, skip_early, skip_late);

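        // The first layers (up to `actual_log_num_shares`) have butterflies that span more than
        // one chunk and are computed with all tasks cooperating on the shared data; the remaining
        // layers stay within a chunk and are computed independently, one chunk per rayon task.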
        let maximum_log_num_shares = log_d - P::LOG_WIDTH - 1;
        let actual_log_num_shares = min(self.log_num_shares, maximum_log_num_shares);
        let first_independent_layer = actual_log_num_shares;

        let last_layer = log_d - skip_late;
        let shared_layers = skip_early..min(first_independent_layer, last_layer);
        let independent_layers = max(first_independent_layer, skip_early)..last_layer;

        for layer in shared_layers {
            forward_shared_layer(
                &self.domain_context,
                data.as_mut(),
                log_d,
                layer,
                actual_log_num_shares,
            );
        }

        let layer = min(independent_layers.start, maximum_log_num_shares);
        let log_d_chunk = log_d - layer;
        data.as_mut()
            .par_chunks_exact_mut(1 << (log_d_chunk - P::LOG_WIDTH))
            .enumerate()
            .for_each(|(block, chunk)| {
                forward_depth_first(
                    &self.domain_context,
                    chunk,
                    log_d_chunk,
                    layer,
                    block,
                    independent_layers.clone(),
                    self.log_base_len.max(P::LOG_WIDTH + 1),
                );
            });
    }

    fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        _data_orig: FieldSliceMut<P>,
        _skip_early: usize,
        _skip_late: usize,
    ) {
        unimplemented!()
    }

    fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
        &self.domain_context
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_with_middle_bit() {
        assert_eq!(with_middle_bit(0b000, 1), (0b0000, 0b0010));
        assert_eq!(with_middle_bit(0b000, 2), (0b0000, 0b0100));
        assert_eq!(with_middle_bit(0b000, 3), (0b0000, 0b1000));

        assert_eq!(with_middle_bit(0b111, 1), (0b1101, 0b1111));
        assert_eq!(with_middle_bit(0b111, 2), (0b1011, 0b1111));
        assert_eq!(with_middle_bit(0b111, 3), (0b0111, 0b1111));

        assert_eq!(with_middle_bit(0b1110110, 2), (0b11101010, 0b11101110));
    }
}