use std::{
    cmp::{max, min},
    ops::Range,
    slice::from_raw_parts_mut,
};

use binius_field::PackedField;
use binius_utils::rayon::{
    iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
    slice::ParallelSliceMut,
};

use super::{AdditiveNTT, DomainContext};
use crate::{FieldSlice, FieldSliceMut};

/// Default log2 of the sub-block scalar length at or below which the
/// depth-first recursion hands off to the breadth-first base-case kernel.
const DEFAULT_LOG_BASE_LEN: usize = 8;

/// Reference implementation of the additive NTT with the "neighbors-last"
/// butterfly order. Written for clarity rather than speed; it also serves as
/// the fallback for the optimized variants below when the data is a single
/// packed element.
pub struct NeighborsLastReference<DC> {
    pub domain_context: DC,
}

impl<DC: DomainContext> AdditiveNTT for NeighborsLastReference<DC> {
    type Field = DC::Field;

    fn forward_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut data: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = data.log_len();

        for layer in skip_early..(log_d - skip_late) {
            let num_blocks = 1 << layer;
            let block_size_half = 1 << (log_d - layer - 1);
            for block in 0..num_blocks {
                let twiddle = self.domain_context.twiddle(layer, block);
                let block_start = block << (log_d - layer);
                for idx0 in block_start..(block_start + block_size_half) {
                    let idx1 = block_size_half | idx0;
                    // Forward butterfly: (u, v) <- (u + t*v, u + (t + 1)*v).
                    let mut u = data.get(idx0).unwrap();
                    let mut v = data.get(idx1).unwrap();
                    u += v * twiddle;
                    v += u;
                    data.set(idx0, u).unwrap();
                    data.set(idx1, v).unwrap();
                }
            }
        }
    }

    fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut data: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = data.log_len();

        // Undo the layers in reverse order, applying the inverse butterfly.
        for layer in (skip_early..(log_d - skip_late)).rev() {
            let num_blocks = 1 << layer;
            let block_size_half = 1 << (log_d - layer - 1);
            for block in 0..num_blocks {
                let twiddle = self.domain_context.twiddle(layer, block);
                let block_start = block << (log_d - layer);
                for idx0 in block_start..(block_start + block_size_half) {
                    let idx1 = block_size_half | idx0;
                    // Inverse butterfly: first v <- u + v, then u <- u + t*v.
                    let mut u = data.get(idx0).unwrap();
                    let mut v = data.get(idx1).unwrap();
                    v += u;
                    u += v * twiddle;
                    data.set(idx0, u).unwrap();
                    data.set(idx1, v).unwrap();
                }
            }
        }
    }

    fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
        &self.domain_context
    }
}
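
// A quick sanity check (in comments) that the two butterflies above are
// mutual inverses over a field of characteristic 2, where subtraction and
// addition coincide. The forward butterfly computes
//
//     u1 = u + t*v
//     v1 = u1 + v = u + (t + 1)*v
//
// and applying the inverse butterfly to (u1, v1) recovers the inputs:
//
//     v' = v1 + u1 = (u + (t + 1)*v) + (u + t*v) = v
//     u' = u1 + t*v' = (u + t*v) + t*v = u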

/// Recursive, depth-first schedule for the forward transform.
///
/// Applies the butterfly of `layer` across the whole of `data` (when `layer`
/// lies in `layer_range`), then recurses into the two halves, which are the
/// blocks of the next layer. Once a sub-block is small enough
/// (`log_d <= log_base_len`), the remaining layers are handled breadth-first,
/// so each cache-sized sub-block is finished while it is still hot.
fn forward_depth_first<P: PackedField>(
    domain_context: &impl DomainContext<Field = P::Scalar>,
    data: &mut [P],
    log_d: usize,
    layer: usize,
    block: usize,
    mut layer_range: Range<usize>,
    log_base_len: usize,
) {
    debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
    debug_assert!(data.len() >= 2);
    debug_assert!(layer_range.end <= domain_context.log_domain_size());
    debug_assert!(layer <= layer_range.start);

    if layer >= layer_range.end {
        return;
    }

    if log_d <= log_base_len || data.len() <= 2 {
        forward_breadth_first(domain_context, data, log_d, layer, block, layer_range);
        return;
    }

    let block_size_half = 1 << (log_d - 1 - P::LOG_WIDTH);
    if layer >= layer_range.start {
        let twiddle = domain_context.twiddle(layer, block);
        let twiddle = P::broadcast(twiddle);
        for idx0 in 0..block_size_half {
            let idx1 = block_size_half | idx0;
            let mut u = data[idx0];
            let mut v = data[idx1];
            u += v * twiddle;
            v += u;
            data[idx0] = u;
            data[idx1] = v;
        }

        layer_range.start += 1;
    }

    // Recurse into the two halves, which are blocks `2b` and `2b + 1` of the
    // next layer.
    forward_depth_first(
        domain_context,
        &mut data[..block_size_half],
        log_d - 1,
        layer + 1,
        block << 1,
        layer_range.clone(),
        log_base_len,
    );
    forward_depth_first(
        domain_context,
        &mut data[block_size_half..],
        log_d - 1,
        layer + 1,
        (block << 1) + 1,
        layer_range,
        log_base_len,
    );
}
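
// A concrete illustration of the schedule (hypothetical parameters of ours):
// with log_d = 10, P::LOG_WIDTH = 2, and log_base_len = 8, the first call
// applies layer 0 across all 256 packed elements and recurses into two halves
// with log_d = 9; each of those applies layer 1 and recurses again; at
// log_d = 8 the cutoff is reached and `forward_breadth_first` finishes all
// remaining layers of each 256-scalar (64-packed) sub-block in one pass.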

/// Iterative, breadth-first schedule for the forward transform of a single
/// block, used as the base case of the depth-first recursion.
///
/// Layers fall into two regimes: while half-blocks span at least one whole
/// packed element, butterflies act directly on packed elements ("packed"
/// layers); once half-blocks are narrower than a packed element, pairs of
/// packed elements are interleaved so the u- and v-operands line up lane by
/// lane ("interleaved" layers).
fn forward_breadth_first<P: PackedField>(
    domain_context: &impl DomainContext<Field = P::Scalar>,
    data: &mut [P],
    log_d: usize,
    layer: usize,
    block: usize,
    layer_range: Range<usize>,
) {
    debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
    debug_assert!(data.len() >= 2);
    debug_assert!(layer_range.end <= domain_context.log_domain_size());
    debug_assert!(layer <= layer_range.start);

    let num_packed_layers_left = log_d - P::LOG_WIDTH;
    let first_interleaved_layer = layer + num_packed_layers_left;
    let first_interleaved_layer = max(first_interleaved_layer, layer_range.start);
    let packed_layers = layer_range.start..min(first_interleaved_layer, layer_range.end);
    let interleaved_layers = first_interleaved_layer..layer_range.end;

    for l in packed_layers {
        // `l` is the absolute layer; `l_rel` is relative to this block.
        let l_rel = l - layer;
        let block_offset = block << l_rel;
        let block_size_half = 1 << (log_d - l_rel - 1 - P::LOG_WIDTH);
        for b in 0..1 << l_rel {
            let twiddle = domain_context.twiddle(l, block_offset | b);
            let twiddle = P::broadcast(twiddle);
            let block_start = b << (log_d - l_rel - P::LOG_WIDTH);
            for idx0 in block_start..(block_start + block_size_half) {
                let idx1 = block_size_half | idx0;
                let mut u = data[idx0];
                let mut v = data[idx1];
                u += v * twiddle;
                v += u;
                data[idx0] = u;
                data[idx1] = v;
            }
        }
    }

    for l in interleaved_layers {
        let l_rel = l - layer;
        let block_offset = block << l_rel;
        let block_size_half = 1 << (log_d - l_rel - 1);
        let log_block_size_half = log_d - l_rel - 1;

        // Each packed element now spans several half-blocks, each needing its
        // own twiddle, so assemble the per-lane twiddle corrections once.
        let mut packed_twiddle_offset = P::zero();
        let log_blocks_per_packed = P::LOG_WIDTH - log_block_size_half;
        for i in 0..1 << log_blocks_per_packed {
            let block_start = i << log_block_size_half;
            let twiddle_offset_i_index = rotate_right(i, log_blocks_per_packed);
            let twiddle_offset_i = domain_context.twiddle(l, twiddle_offset_i_index);
            for j in block_start..(block_start + block_size_half) {
                packed_twiddle_offset.set(j, twiddle_offset_i);
            }
        }

        for b in (0..data.len()).step_by(2) {
            let first_twiddle =
                domain_context.twiddle(l, block_offset | (b << (log_blocks_per_packed - 1)));
            let twiddle = P::broadcast(first_twiddle) + packed_twiddle_offset;
            // Interleave so that the u- and v-halves line up lane by lane,
            // apply one packed butterfly, then de-interleave.
            let (mut u, mut v) = data[b].interleave(data[b | 1], log_block_size_half);
            u += v * twiddle;
            v += u;
            (data[b], data[b | 1]) = u.interleave(v, log_block_size_half);
        }
    }
}
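
// The packed twiddle assembly above (`broadcast(first_twiddle) +
// packed_twiddle_offset`) presumes that twiddles are F2-linear in the bits of
// the block index, i.e. twiddle(l, x | y) = twiddle(l, x) + twiddle(l, y)
// whenever x and y occupy disjoint bits; this is what lets one broadcast
// supply the high bits of the block index while the per-lane offsets supply
// the low bits.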

/// Applies a single early layer across the whole data set using
/// `2^log_num_shares` parallel tasks.
///
/// The data is viewed as `2^(log_num_shares + 1)` chunks. For each task `k`,
/// `with_middle_bit` yields the indices of the two chunks holding the u- and
/// v-halves of that task's butterflies; by construction the pairs partition
/// the chunks, so no two tasks touch the same chunk.
fn forward_shared_layer<P: PackedField>(
    domain_context: &(impl DomainContext<Field = P::Scalar> + Sync),
    data: &mut [P],
    log_d: usize,
    layer: usize,
    log_num_shares: usize,
) {
    debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
    debug_assert!(1 << (log_num_shares + 1) <= data.len());
    debug_assert!(layer < domain_context.log_domain_size());

    let log_num_chunks = log_num_shares + 1;
    let log_d_chunk = log_d - log_num_chunks;
    let data_ptr = data.as_mut_ptr();
    let shift = log_num_shares - layer;
    let tasks: Vec<_> = (0..1 << log_num_shares)
        .map(|k| {
            let (chunk0, chunk1) = with_middle_bit(k, shift);
            let block = chunk0 >> (log_num_chunks - layer);
            assert!(P::LOG_WIDTH <= log_d_chunk);
            let log_chunk_len = log_d_chunk - P::LOG_WIDTH;
            // SAFETY: distinct `k` yield pairwise-distinct chunk indices, so
            // the raw sub-slices handed to the parallel tasks never alias.
            let chunk0 = unsafe {
                from_raw_parts_mut(data_ptr.add(chunk0 << log_chunk_len), 1 << log_chunk_len)
            };
            let chunk1 = unsafe {
                from_raw_parts_mut(data_ptr.add(chunk1 << log_chunk_len), 1 << log_chunk_len)
            };
            let twiddle = domain_context.twiddle(layer, block);
            let twiddle = P::broadcast(twiddle);
            (chunk0, chunk1, twiddle)
        })
        .collect();

    tasks.into_par_iter().for_each(|(chunk0, chunk1, twiddle)| {
        for i in 0..chunk0.len() {
            let mut u = chunk0[i];
            let mut v = chunk1[i];
            u += v * twiddle;
            v += u;
            chunk0[i] = u;
            chunk1[i] = v;
        }
    });
}
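
// Illustration of the chunk pairing (our own worked instance): with
// log_num_shares = 2 the data is cut into 8 chunks. At layer = 0 (shift = 2)
// the four tasks get the pairs (0, 4), (1, 5), (2, 6), (3, 7), i.e. the u/v
// halves of the single layer-0 block; at layer = 1 (shift = 1) they get
// (0, 2), (1, 3), (4, 6), (5, 7), the butterflies of the two layer-1 blocks.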

/// Rotates the lowest `bits` bits of `value` right by one position.
///
/// Assumes `value < 1 << bits`.
#[inline]
fn rotate_right(value: usize, bits: usize) -> usize {
    (value >> 1) | ((value & 1) << (bits - 1))
}

/// Returns the two indices obtained from `k` by inserting a 0 bit
/// (respectively a 1 bit) at bit position `shift`.
fn with_middle_bit(k: usize, shift: usize) -> (usize, usize) {
    assert!(shift >= 1);

    let ms = k >> (shift - 1);
    let ls = k & ((1 << shift) - 1);

    let k0 = ls | ((ms & !1) << shift);
    let k1 = ls | ((ms | 1) << shift);

    (k0, k1)
}

/// Validates the transform inputs and returns log2 of the data length.
fn input_check<P: PackedField>(
    domain_context: &impl DomainContext<Field = P::Scalar>,
    data: FieldSlice<P>,
    skip_early: usize,
    skip_late: usize,
) -> usize {
    let log_d = data.log_len();

    // The skipped prefix and suffix of layers must not overlap.
    assert!(skip_early + skip_late <= log_d);

    // Every layer actually executed must have twiddles in the domain.
    assert!(log_d <= domain_context.log_domain_size() + skip_late);

    log_d
}

/// Single-threaded additive NTT with the neighbors-last butterfly order,
/// using the cache-friendly depth-first schedule.
#[derive(Debug)]
pub struct NeighborsLastSingleThread<DC> {
    pub domain_context: DC,
    /// Log2 of the sub-block scalar length at which the depth-first recursion
    /// switches to the breadth-first kernel.
    pub log_base_len: usize,
}

impl<DC> NeighborsLastSingleThread<DC> {
    pub fn new(domain_context: DC) -> Self {
        Self {
            domain_context,
            log_base_len: DEFAULT_LOG_BASE_LEN,
        }
    }
}

impl<DC: DomainContext> AdditiveNTT for NeighborsLastSingleThread<DC> {
    type Field = DC::Field;

    fn forward_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut data_orig: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = input_check(&self.domain_context, data_orig.to_ref(), skip_early, skip_late);

        let data = data_orig.as_mut();

        // The optimized path needs at least two packed elements; fall back to
        // the reference implementation for a single packed element.
        if data.len() == 1 {
            let reference_ntt = NeighborsLastReference {
                domain_context: &self.domain_context,
            };
            reference_ntt.forward_transform(data_orig, skip_early, skip_late);
            return;
        }

        forward_depth_first(
            &self.domain_context,
            data,
            log_d,
            0,
            0,
            skip_early..(log_d - skip_late),
            self.log_base_len,
        );
    }

    fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut _data_orig: FieldSliceMut<P>,
        _skip_early: usize,
        _skip_late: usize,
    ) {
        unimplemented!()
    }

    fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
        &self.domain_context
    }
}
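
// Illustrative usage sketch (hypothetical values: `dc` stands for any
// `DomainContext` implementation and `data` for a `FieldSliceMut` over the
// coefficients; neither constructor is shown in this file):
//
//     let ntt = NeighborsLastSingleThread::new(dc);
//     ntt.forward_transform(data, 0, 0); // run every layer, skipping none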

/// Multithreaded additive NTT with the neighbors-last butterfly order.
#[derive(Debug)]
pub struct NeighborsLastMultiThread<DC> {
    pub domain_context: DC,
    /// Log2 of the sub-block scalar length at which the depth-first recursion
    /// switches to the breadth-first kernel.
    pub log_base_len: usize,
    /// Log2 of the number of parallel tasks the work is split into.
    pub log_num_shares: usize,
}

impl<DC> NeighborsLastMultiThread<DC> {
    pub fn new(domain_context: DC, log_num_shares: usize) -> Self {
        Self {
            domain_context,
            log_base_len: DEFAULT_LOG_BASE_LEN,
            log_num_shares,
        }
    }
}
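
// Shape of the parallel schedule implemented below (a worked instance of
// ours): with log_num_shares = 3 and no skipped layers, layers 0..3 each span
// the whole data set and are run by `forward_shared_layer` as 8 disjoint
// chunk-pair tasks; from layer 3 on, the data splits into 8 independent
// chunks of 2^(log_d - 3) scalars, each transformed depth-first on its own
// Rayon worker.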

impl<DC: DomainContext + Sync> AdditiveNTT for NeighborsLastMultiThread<DC> {
    type Field = DC::Field;

    fn forward_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut data_orig: FieldSliceMut<P>,
        skip_early: usize,
        skip_late: usize,
    ) {
        let log_d = input_check(&self.domain_context, data_orig.to_ref(), skip_early, skip_late);

        let data = data_orig.as_mut();

        // The optimized path needs at least two packed elements; fall back to
        // the reference implementation for a single packed element.
        if data.len() == 1 {
            let reference_ntt = NeighborsLastReference {
                domain_context: &self.domain_context,
            };
            reference_ntt.forward_transform(data_orig, skip_early, skip_late);
            return;
        }

        // Cap the share count so that every independent chunk still holds at
        // least two packed elements.
        let maximum_log_num_shares = log_d - P::LOG_WIDTH - 1;
        let actual_log_num_shares = min(self.log_num_shares, maximum_log_num_shares);
        let first_independent_layer = actual_log_num_shares;

        let last_layer = log_d - skip_late;
        let shared_layers = skip_early..min(first_independent_layer, last_layer);
        let independent_layers = max(first_independent_layer, skip_early)..last_layer;

        // Early layers span the whole data set, so they are executed one at a
        // time with the work shared across tasks.
        for layer in shared_layers {
            forward_shared_layer(&self.domain_context, data, log_d, layer, actual_log_num_shares);
        }

        // From `layer` on, blocks fit inside a chunk, so each chunk is
        // transformed independently, one Rayon task per chunk.
        let layer = min(independent_layers.start, maximum_log_num_shares);
        let log_d_chunk = log_d - layer;
        data.par_chunks_exact_mut(1 << (log_d_chunk - P::LOG_WIDTH))
            .enumerate()
            .for_each(|(block, chunk)| {
                forward_depth_first(
                    &self.domain_context,
                    chunk,
                    log_d_chunk,
                    layer,
                    block,
                    independent_layers.clone(),
                    self.log_base_len,
                );
            });
    }

    fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
        &self,
        mut _data_orig: FieldSliceMut<P>,
        _skip_early: usize,
        _skip_late: usize,
    ) {
        unimplemented!()
    }

    fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
        &self.domain_context
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rotate_right() {
        assert_eq!(rotate_right(0b1001011, 7), 0b1100101);
        assert_eq!(rotate_right(0b0001011, 7), 0b1000101);
        assert_eq!(rotate_right(0b0001010, 7), 0b0000101);
        assert_eq!(rotate_right(0b1, 1), 0b1);
        assert_eq!(rotate_right(0b0, 1), 0b0);
    }
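
    /// Cross-checks `rotate_right` against a naive width-limited rotation
    /// (the `naive` helper is illustrative, defined only for this test).
    /// Assumes `value < 1 << bits`, as at every call site above.
    #[test]
    fn test_rotate_right_matches_naive() {
        fn naive(value: usize, bits: usize) -> usize {
            ((value >> 1) | (value << (bits - 1))) & ((1 << bits) - 1)
        }

        for bits in 1..8 {
            for value in 0..1 << bits {
                assert_eq!(rotate_right(value, bits), naive(value, bits));
            }
        }
    }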

    #[test]
    fn test_with_middle_bit() {
        assert_eq!(with_middle_bit(0b000, 1), (0b0000, 0b0010));
        assert_eq!(with_middle_bit(0b000, 2), (0b0000, 0b0100));
        assert_eq!(with_middle_bit(0b000, 3), (0b0000, 0b1000));

        assert_eq!(with_middle_bit(0b111, 1), (0b1101, 0b1111));
        assert_eq!(with_middle_bit(0b111, 2), (0b1011, 0b1111));
        assert_eq!(with_middle_bit(0b111, 3), (0b0111, 0b1111));

        assert_eq!(with_middle_bit(0b1110110, 2), (0b11101010, 0b11101110));
    }
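
    /// Cross-checks `with_middle_bit` against a naive reference that inserts
    /// a bit at position `shift` (the `insert_bit` helper is illustrative,
    /// defined only for this test).
    #[test]
    fn test_with_middle_bit_matches_naive_insertion() {
        fn insert_bit(k: usize, pos: usize, bit: usize) -> usize {
            let low = k & ((1 << pos) - 1);
            low | (bit << pos) | ((k >> pos) << (pos + 1))
        }

        for shift in 1..6 {
            for k in 0..64 {
                assert_eq!(
                    with_middle_bit(k, shift),
                    (insert_bit(k, shift, 0), insert_bit(k, shift, 1))
                );
            }
        }
    }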
}