binius_math/ntt/
reference.rs1use binius_field::PackedField;
4
5use crate::{
6 field_buffer::FieldSliceMut,
7 ntt::{AdditiveNTT, DomainContext},
8};
9
/// Reference implementation of the additive NTT, generic over the domain context `DC`.
///
/// The transform walks the data one butterfly at a time through checked element accessors.
/// In the final layer the two elements of every butterfly are adjacent in memory
/// (the half-block size is 1 there).
pub struct NeighborsLastReference<DC> {
    /// Source of the twiddle factors and of the supported domain size.
    pub domain_context: DC,
}
16
17impl<DC: DomainContext> AdditiveNTT for NeighborsLastReference<DC> {
18 type Field = DC::Field;
19
20 fn forward_transform<P: PackedField<Scalar = Self::Field>>(
21 &self,
22 mut data: FieldSliceMut<P>,
23 skip_early: usize,
24 skip_late: usize,
25 ) {
26 let log_d = data.log_len();
27 input_check(&self.domain_context, log_d, skip_early, skip_late);
28
29 for layer in skip_early..(log_d - skip_late) {
30 let num_blocks = 1 << layer;
31 let block_size_half = 1 << (log_d - layer - 1);
32 for block in 0..num_blocks {
33 let twiddle = self.domain_context.twiddle(layer, block);
34 let block_start = block << (log_d - layer);
35 for idx0 in block_start..(block_start + block_size_half) {
36 let idx1 = block_size_half | idx0;
37 let mut u = data.get_checked(idx0).unwrap();
39 let mut v = data.get_checked(idx1).unwrap();
40 u += v * twiddle;
41 v += u;
42 data.set_checked(idx0, u).unwrap();
43 data.set_checked(idx1, v).unwrap();
44 }
45 }
46 }
47 }
48
49 fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
50 &self,
51 mut data: FieldSliceMut<P>,
52 skip_early: usize,
53 skip_late: usize,
54 ) {
55 let log_d = data.log_len();
56 input_check(&self.domain_context, log_d, skip_early, skip_late);
57
58 for layer in (skip_early..(log_d - skip_late)).rev() {
59 let num_blocks = 1 << layer;
60 let block_size_half = 1 << (log_d - layer - 1);
61 for block in 0..num_blocks {
62 let twiddle = self.domain_context.twiddle(layer, block);
63 let block_start = block << (log_d - layer);
64 for idx0 in block_start..(block_start + block_size_half) {
65 let idx1 = block_size_half | idx0;
66 let mut u = data.get_checked(idx0).unwrap();
68 let mut v = data.get_checked(idx1).unwrap();
69 v += u;
70 u += v * twiddle;
71 data.set_checked(idx0, u).unwrap();
72 data.set_checked(idx1, v).unwrap();
73 }
74 }
75 }
76 }
77
78 fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
79 &self.domain_context
80 }
81}
82
83pub fn input_check(
90 domain_context: &impl DomainContext,
91 log_d: usize,
92 skip_early: usize,
93 skip_late: usize,
94) {
95 assert!(skip_early + skip_late <= log_d);
97
98 assert!(log_d - skip_late <= domain_context.log_domain_size());
100}