binius_math/ntt/neighbors_last.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	cmp::{max, min},
	ops::Range,
	slice::from_raw_parts_mut,
};

use binius_field::PackedField;
use binius_utils::rayon::{
	iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
	slice::ParallelSliceMut,
};

use super::{AdditiveNTT, DomainContext};
use crate::{FieldSlice, FieldSliceMut};

const DEFAULT_LOG_BASE_LEN: usize = 8;

/// Reference implementation of [`AdditiveNTT`].
///
/// This is slow. Do not use in production.
pub struct NeighborsLastReference<DC> {
	pub domain_context: DC,
}

impl<DC: DomainContext> AdditiveNTT for NeighborsLastReference<DC> {
	type Field = DC::Field;

	fn forward_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut data: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		let log_d = data.log_len();

		for layer in skip_early..(log_d - skip_late) {
			let num_blocks = 1 << layer;
			let block_size_half = 1 << (log_d - layer - 1);
			for block in 0..num_blocks {
				let twiddle = self.domain_context.twiddle(layer, block);
				let block_start = block << (log_d - layer);
				for idx0 in block_start..(block_start + block_size_half) {
					let idx1 = block_size_half | idx0;
					// perform butterfly
					let mut u = data.get(idx0).unwrap();
					let mut v = data.get(idx1).unwrap();
					u += v * twiddle;
					v += u;
					data.set(idx0, u).unwrap();
					data.set(idx1, v).unwrap();
				}
			}
		}
	}

	fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut data: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		let log_d = data.log_len();

		for layer in (skip_early..(log_d - skip_late)).rev() {
			let num_blocks = 1 << layer;
			let block_size_half = 1 << (log_d - layer - 1);
			for block in 0..num_blocks {
				let twiddle = self.domain_context.twiddle(layer, block);
				let block_start = block << (log_d - layer);
				for idx0 in block_start..(block_start + block_size_half) {
					let idx1 = block_size_half | idx0;
					// perform butterfly
					let mut u = data.get(idx0).unwrap();
					let mut v = data.get(idx1).unwrap();
					v += u;
					u += v * twiddle;
					data.set(idx0, u).unwrap();
					data.set(idx1, v).unwrap();
				}
			}
		}
	}

	fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
		&self.domain_context
	}
}
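
// Illustrative sketch (not part of the NTT API): shows why the inverse butterfly above undoes the
// forward butterfly in characteristic 2, where addition is its own inverse. We use GF(2) (bool
// with XOR as addition and AND as multiplication) as the smallest example; the real code works
// over larger binary fields, but the cancellation argument is the same.
#[test]
fn butterfly_roundtrip_over_gf2() {
	for (u0, v0, t) in [
		(false, false, true),
		(true, false, true),
		(true, true, false),
		(false, true, true),
	] {
		// forward butterfly: u += v * twiddle; v += u;
		let u1 = u0 ^ (v0 & t);
		let v1 = v0 ^ u1;
		// inverse butterfly: v += u; u += v * twiddle;
		let v2 = v1 ^ u1;
		let u2 = u1 ^ (v2 & t);
		assert_eq!((u2, v2), (u0, v0));
	}
}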

/// Runs a **part** of an NTT butterfly network, in depth-first order.
///
/// Concretely, it processes a specific memory block in the butterfly network, which is given by
/// `layer` and `block`. For this memory block, it processes the layers given by `layer_range`.
///
/// For example, suppose `layer=2` and `block=2`.
/// That means we are in an NTT butterfly network in layer 2 (the third layer) and block 2 (the
/// third block in this layer; there are four blocks in total in this layer). `data` contains the
/// data of this block, so it's only a chunk of the total data used in the NTT. Now suppose
/// `layer_range=2..5`. Then we will process the following butterfly blocks:
/// - `layer=2` `block=2`
/// - `layer=3` `block=4`
/// - `layer=3` `block=5`
/// - `layer=4` `block=8`
/// - `layer=4` `block=9`
/// - `layer=4` `block=10`
/// - `layer=4` `block=11`
///
/// (Just in a different order: the list above is in breadth-first order, while this function
/// processes the blocks in depth-first order.)
///
/// The argument `log_base_len` determines for which `log_d` we call the breadth-first
/// implementation as a base case.
///
/// ## Preconditions
///
/// - `2^(log_d) == data.len() * packing_width`
/// - `data.len() >= 2`
/// - `domain_context` holds all the twiddles up to `layer_range.end` (exclusive)
/// - `layer <= layer_range.start`
fn forward_depth_first<P: PackedField>(
	domain_context: &impl DomainContext<Field = P::Scalar>,
	data: &mut [P],
	log_d: usize,
	layer: usize,
	block: usize,
	mut layer_range: Range<usize>,
	log_base_len: usize,
) {
	// check preconditions
	debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
	debug_assert!(data.len() >= 2);
	debug_assert!(layer_range.end <= domain_context.log_domain_size());
	debug_assert!(layer <= layer_range.start);

	if layer >= layer_range.end {
		return;
	}

	// If the problem size is small, we just run the breadth-first implementation (to get rid of
	// the stack overhead). We also need to do that if the number of scalars is just two times our
	// packing width, i.e. if we only have two packed elements in our data slice.
	if log_d <= log_base_len || data.len() <= 2 {
		forward_breadth_first(domain_context, data, log_d, layer, block, layer_range);
		return;
	}

	let block_size_half = 1 << (log_d - 1 - P::LOG_WIDTH);
	if layer >= layer_range.start {
		// process only one layer of this block
		let twiddle = domain_context.twiddle(layer, block);
		let twiddle = P::broadcast(twiddle);
		for idx0 in 0..block_size_half {
			let idx1 = block_size_half | idx0;
			// perform butterfly
			let mut u = data[idx0];
			let mut v = data[idx1];
			u += v * twiddle;
			v += u;
			data[idx0] = u;
			data[idx1] = v;
		}

		layer_range.start += 1;
	}

	// then recurse
	forward_depth_first(
		domain_context,
		&mut data[..block_size_half],
		log_d - 1,
		layer + 1,
		block << 1,
		layer_range.clone(),
		log_base_len,
	);
	forward_depth_first(
		domain_context,
		&mut data[block_size_half..],
		log_d - 1,
		layer + 1,
		(block << 1) + 1,
		layer_range,
		log_base_len,
	);
}
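
// Illustrative sketch of the recursion order only (no field arithmetic): mirrors how
// `forward_depth_first` walks through `(layer, block)` pairs, ignoring the breadth-first base
// case and the packing details. The test reproduces the example from the doc comment above
// (`layer=2`, `block=2`, `layer_range=2..5`) and checks that exactly the listed blocks are
// visited, in depth-first order.
#[cfg(test)]
fn depth_first_visit_order(
	layer: usize,
	block: usize,
	layer_range: Range<usize>,
) -> Vec<(usize, usize)> {
	if layer >= layer_range.end {
		return Vec::new();
	}
	let mut layer_range = layer_range;
	let mut visited = Vec::new();
	if layer >= layer_range.start {
		visited.push((layer, block));
		layer_range.start += 1;
	}
	visited.extend(depth_first_visit_order(layer + 1, block << 1, layer_range.clone()));
	visited.extend(depth_first_visit_order(layer + 1, (block << 1) + 1, layer_range));
	visited
}

#[test]
fn depth_first_visit_order_matches_doc_example() {
	assert_eq!(
		depth_first_visit_order(2, 2, 2..5),
		vec![(2, 2), (3, 4), (4, 8), (4, 9), (3, 5), (4, 10), (4, 11)]
	);
}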

/// Same as [`forward_depth_first`], but runs in breadth-first order.
///
/// ## Preconditions
///
/// - `2^(log_d) == data.len() * packing_width`
/// - `data.len() >= 2`
/// - `domain_context` holds all the twiddles up to `layer_range.end` (exclusive)
/// - `layer <= layer_range.start`
fn forward_breadth_first<P: PackedField>(
	domain_context: &impl DomainContext<Field = P::Scalar>,
	data: &mut [P],
	log_d: usize,
	layer: usize,
	block: usize,
	layer_range: Range<usize>,
) {
	// check preconditions
	debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
	debug_assert!(data.len() >= 2);
	debug_assert!(layer_range.end <= domain_context.log_domain_size());
	debug_assert!(layer <= layer_range.start);

	let num_packed_layers_left = log_d - P::LOG_WIDTH;
	let first_interleaved_layer = layer + num_packed_layers_left;
	let first_interleaved_layer = max(first_interleaved_layer, layer_range.start);
	let packed_layers = layer_range.start..min(first_interleaved_layer, layer_range.end);
	let interleaved_layers = first_interleaved_layer..layer_range.end;

	// Run the layers where block_size_half >= packing_width.
	// In these layers, we can always process whole packed elements; we don't need to interleave
	// them or read/write "within" packed elements.
	for l in packed_layers {
		// there are 2^l total blocks in layer l, but we only want to process those which are
		// contained in `block`
		let l_rel = l - layer;
		let block_offset = block << l_rel;
		// each such block holds 2^(log_d - l_rel) scalars of `data`, but the butterflies pair them
		// up, so we iterate over 2^(log_d - l_rel - 1) scalars (grouped into packed elements)
		let block_size_half = 1 << (log_d - l_rel - 1 - P::LOG_WIDTH);
		for b in 0..1 << l_rel {
			// all butterflies within a block share the same twiddle
			let twiddle = domain_context.twiddle(l, block_offset | b);
			let twiddle = P::broadcast(twiddle);
			let block_start = b << (log_d - l_rel - P::LOG_WIDTH);
			for idx0 in block_start..(block_start + block_size_half) {
				let idx1 = block_size_half | idx0;
				// perform butterfly
				let mut u = data[idx0];
				let mut v = data[idx1];
				u += v * twiddle;
				v += u;
				data[idx0] = u;
				data[idx1] = v;
			}
		}
	}

	// Run the layers where block_size_half < packing_width.
	// That is, we would need to work "within" packed field elements.
	// We solve this problem by interleaving the packed elements with each other.
	for l in interleaved_layers {
		let l_rel = l - layer;
		let block_offset = block << l_rel;
		let block_size_half = 1 << (log_d - l_rel - 1);
		let log_block_size_half = log_d - l_rel - 1;

		// calculate packed_twiddle_offset
		let mut packed_twiddle_offset = P::zero();
		let log_blocks_per_packed = P::LOG_WIDTH - log_block_size_half;
		for i in 0..1 << log_blocks_per_packed {
			let block_start = i << log_block_size_half;
			let twiddle_offset_i_index = rotate_right(i, log_blocks_per_packed);
			let twiddle_offset_i = domain_context.twiddle(l, twiddle_offset_i_index);
			for j in block_start..(block_start + block_size_half) {
				packed_twiddle_offset.set(j, twiddle_offset_i);
			}
		}

		for b in (0..data.len()).step_by(2) {
			let first_twiddle =
				domain_context.twiddle(l, block_offset | (b << (log_blocks_per_packed - 1)));
			let twiddle = P::broadcast(first_twiddle) + packed_twiddle_offset;
			let (mut u, mut v) = data[b].interleave(data[b | 1], log_block_size_half);
			u += v * twiddle;
			v += u;
			(data[b], data[b | 1]) = u.interleave(v, log_block_size_half);
		}
	}
}
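
// Illustrative sketch of how `forward_breadth_first` splits its layer range into "packed" layers
// (each butterfly operand is a whole packed element or more) and "interleaved" layers (butterfly
// operands sit inside a packed element). The numbers here are hypothetical; the arithmetic simply
// mirrors the range computation at the top of the function above.
#[test]
fn breadth_first_layer_split_example() {
	let (log_d, log_width, layer) = (8usize, 2usize, 0usize);
	let layer_range = 0..8;
	let num_packed_layers_left = log_d - log_width;
	let first_interleaved_layer = max(layer + num_packed_layers_left, layer_range.start);
	let packed_layers = layer_range.start..min(first_interleaved_layer, layer_range.end);
	let interleaved_layers = first_interleaved_layer..layer_range.end;
	// layers 0..6 operate on whole packed elements, layers 6 and 7 need interleaving
	assert_eq!(packed_layers, 0..6);
	assert_eq!(interleaved_layers, 6..8);
}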

/// Process a layer of the NTT butterfly network in parallel by splitting the work up into
/// `2^log_num_shares` many shares. This will also split up single *blocks* into multiple shares.
///
/// (The latter is the whole purpose of this function. If the number of shares is small enough (and
/// the number of blocks is big enough) that we don't need to split up blocks, we could just run
/// [`forward_depth_first`] on disjoint chunks.)
///
/// ## Preconditions
///
/// - `2^(log_d) == data.len() * packing_width`
/// - **Important:** `2^log_num_shares * 2 <= data.len()` (every share is working with whole packed
///   elements, so every share needs at least 2 packed elements)
/// - `layer < log_num_shares`
/// - `domain_context` holds the twiddles of `layer`
fn forward_shared_layer<P: PackedField>(
	domain_context: &(impl DomainContext<Field = P::Scalar> + Sync),
	data: &mut [P],
	log_d: usize,
	layer: usize,
	log_num_shares: usize,
) {
	// check preconditions
	debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
	debug_assert!(1 << (log_num_shares + 1) <= data.len());
	debug_assert!(layer < domain_context.log_domain_size());

	let log_num_chunks = log_num_shares + 1;
	let log_d_chunk = log_d - log_num_chunks;
	let data_ptr = data.as_mut_ptr();
	let shift = log_num_shares - layer;
	let tasks: Vec<_> = (0..1 << log_num_shares)
		.map(|k| {
			let (chunk0, chunk1) = with_middle_bit(k, shift);
			let block = chunk0 >> (log_num_chunks - layer);
			assert!(P::LOG_WIDTH <= log_d_chunk);
			let log_chunk_len = log_d_chunk - P::LOG_WIDTH;
			// SAFETY: `with_middle_bit` is injective over `k` (it inserts one bit), so the chunk
			// indices across all tasks are pairwise distinct and the raw sub-slices never alias.
			let chunk0 = unsafe {
				from_raw_parts_mut(data_ptr.add(chunk0 << log_chunk_len), 1 << log_chunk_len)
			};
			let chunk1 = unsafe {
				from_raw_parts_mut(data_ptr.add(chunk1 << log_chunk_len), 1 << log_chunk_len)
			};
			let twiddle = domain_context.twiddle(layer, block);
			let twiddle = P::broadcast(twiddle);
			(chunk0, chunk1, twiddle)
		})
		.collect();

	tasks.into_par_iter().for_each(|(chunk0, chunk1, twiddle)| {
		for i in 0..chunk0.len() {
			let mut u = chunk0[i];
			let mut v = chunk1[i];
			u += v * twiddle;
			v += u;
			chunk0[i] = u;
			chunk1[i] = v;
		}
	});
}

/// Interprets `value` as a bit-sequence of `bits` many bits, and rotates it one position to the
/// right. For example: `rotate_right(0b0001011, 7) -> 0b1000101`
#[inline]
fn rotate_right(value: usize, bits: usize) -> usize {
	(value >> 1) | ((value & 1) << (bits - 1))
}

/// Inserts a bit into `k`. Returns both the version with `0` inserted and the version with `1`
/// inserted.
///
/// The low `shift` bits of `k` are preserved, then `0` or `1` is inserted as the next bit, and the
/// remaining (higher) bits of `k` follow above it.
///
/// ## Preconditions
///
/// - `shift` must be strictly greater than 0
fn with_middle_bit(k: usize, shift: usize) -> (usize, usize) {
	assert!(shift >= 1);

	// most significant and least significant bits, overlapping in one bit
	let ms = k >> (shift - 1);
	let ls = k & ((1 << shift) - 1);

	let k0 = ls | ((ms & !1) << shift);
	let k1 = ls | ((ms | 1) << shift);

	(k0, k1)
}
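
// Illustrative sketch: shows how `forward_shared_layer` uses `with_middle_bit` to pair up chunks
// within one layer. With `log_num_shares = 2` there are 8 chunks; for `layer = 1` we get
// `shift = log_num_shares - layer = 1`, and the four shares receive the chunk pairs below. The
// two chunks of each pair are butterfly partners inside the same layer-1 block (chunks 0..4 form
// block 0, chunks 4..8 form block 1).
#[test]
fn with_middle_bit_pairs_chunks_of_a_shared_layer() {
	let (log_num_shares, layer) = (2, 1);
	let shift = log_num_shares - layer;
	let pairs: Vec<_> = (0..1 << log_num_shares)
		.map(|k| with_middle_bit(k, shift))
		.collect();
	assert_eq!(pairs, vec![(0, 2), (1, 3), (4, 6), (5, 7)]);
}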

/// Checks the preconditions of the `AdditiveNTT` transforms and returns `log_d`, the base-2 log
/// of the total number of scalars in the input.
fn input_check<P: PackedField>(
	domain_context: &impl DomainContext<Field = P::Scalar>,
	data: FieldSlice<P>,
	skip_early: usize,
	skip_late: usize,
) -> usize {
	let log_d = data.log_len();

	// we can't "double-skip" layers
	assert!(skip_early + skip_late <= log_d);

	// we need enough twiddles in `domain_context`
	assert!(log_d <= domain_context.log_domain_size() + skip_late);

	log_d
}

/// A single-threaded implementation of [`AdditiveNTT`].
///
/// The code is only tuned to be fast for _large_ data inputs.
/// For small inputs, it can be comparatively slow!
///
/// The implementation is depth-first, but calls a breadth-first implementation as a base case.
///
/// Note that "neighbors last" refers to the memory layout for the NTT: In the _last_ layer of this
/// NTT algorithm, neighboring elements speak to each other. In the classic FFT that's usually the
/// case for "decimation in frequency".
#[derive(Debug)]
pub struct NeighborsLastSingleThread<DC> {
	/// The domain context from which the twiddles are pulled.
	pub domain_context: DC,
	/// Determines when to switch from depth-first to the breadth-first base case.
	pub log_base_len: usize,
}

impl<DC> NeighborsLastSingleThread<DC> {
	/// Convenience constructor which sets `log_base_len` to a reasonable default.
	pub fn new(domain_context: DC) -> Self {
		Self {
			domain_context,
			log_base_len: DEFAULT_LOG_BASE_LEN,
		}
	}
}

impl<DC: DomainContext> AdditiveNTT for NeighborsLastSingleThread<DC> {
	type Field = DC::Field;

	fn forward_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut data_orig: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		// total number of scalars
		let log_d = input_check(&self.domain_context, data_orig.to_ref(), skip_early, skip_late);

		let data = data_orig.as_mut();

		// if there is only a single packed element, we don't want to bother with potential
		// interleaving issues later on, so we just call the (slow) reference NTT
		if data.len() == 1 {
			let reference_ntt = NeighborsLastReference {
				domain_context: &self.domain_context,
			};
			reference_ntt.forward_transform(data_orig, skip_early, skip_late);
			return;
		}

		forward_depth_first(
			&self.domain_context,
			data,
			log_d,
			0,
			0,
			skip_early..(log_d - skip_late),
			self.log_base_len,
		);
	}

	fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut _data_orig: FieldSliceMut<P>,
		_skip_early: usize,
		_skip_late: usize,
	) {
		unimplemented!()
	}

	fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
		&self.domain_context
	}
}

/// A multi-threaded implementation of [`AdditiveNTT`].
///
/// The code is only tuned to be fast for _large_ data inputs.
/// For small inputs, it can be comparatively slow!
///
/// The implementation is depth-first, but calls a breadth-first implementation as a base case.
///
/// Note that "neighbors last" refers to the memory layout for the NTT: In the _last_ layer of this
/// NTT algorithm, neighboring elements speak to each other. In the classic FFT that's usually the
/// case for "decimation in frequency".
#[derive(Debug)]
pub struct NeighborsLastMultiThread<DC> {
	/// The domain context from which the twiddles are pulled.
	pub domain_context: DC,
	/// Determines when to switch from depth-first to the breadth-first base case.
	pub log_base_len: usize,
	/// The base-2 logarithm of the number of equal-sized shares that the problem should be split
	/// into. Each share needs to do the same amount of work. If you have equally powered cores
	/// available, this should be the base-2 logarithm of the number of cores.
	pub log_num_shares: usize,
}

impl<DC> NeighborsLastMultiThread<DC> {
	/// Convenience constructor which sets `log_base_len` to a reasonable default.
	pub fn new(domain_context: DC, log_num_shares: usize) -> Self {
		Self {
			domain_context,
			log_base_len: DEFAULT_LOG_BASE_LEN,
			log_num_shares,
		}
	}
}

impl<DC: DomainContext + Sync> AdditiveNTT for NeighborsLastMultiThread<DC> {
	type Field = DC::Field;

	fn forward_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut data_orig: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		// total number of scalars
		let log_d = input_check(&self.domain_context, data_orig.to_ref(), skip_early, skip_late);

		let data = data_orig.as_mut();

		// if there is only a single packed element, we don't want to bother with potential
		// interleaving issues later on, so we just call the (slow) reference NTT
		if data.len() == 1 {
			let reference_ntt = NeighborsLastReference {
				domain_context: &self.domain_context,
			};
			reference_ntt.forward_transform(data_orig, skip_early, skip_late);
			return;
		}

		// Decide on `actual_log_num_shares`, which also determines how many shared rounds we do.
		// By default this would just be `self.log_num_shares`, but we will potentially decrease it
		// in order to make sure that `2^log_num_shares * 2 <= data.len()`. This serves two
		// purposes:
		// - when we do the shared rounds, each thread should have at least 2 packed elements to
		//   work with, see the precondition of [`forward_shared_layer`]
		// - when we do the independent rounds, again each share should have `chunk.len() >= 2`
		//   because this is required by [`forward_depth_first`]
		let maximum_log_num_shares = log_d - P::LOG_WIDTH - 1;
		let actual_log_num_shares = min(self.log_num_shares, maximum_log_num_shares);
		let first_independent_layer = actual_log_num_shares;

		let last_layer = log_d - skip_late;
		let shared_layers = skip_early..min(first_independent_layer, last_layer);
		let independent_layers = max(first_independent_layer, skip_early)..last_layer;

		for layer in shared_layers {
			forward_shared_layer(&self.domain_context, data, log_d, layer, actual_log_num_shares);
		}

		// One might think that we could just call `forward_depth_first` with
		// `layer=independent_layers.start`. However, this would mean that the chunk size (that we
		// split into using `par_chunks_exact_mut`) could be just one packed element, or even less
		// than one packed element.
		let layer = min(independent_layers.start, maximum_log_num_shares);
		let log_d_chunk = log_d - layer;
		data.par_chunks_exact_mut(1 << (log_d_chunk - P::LOG_WIDTH))
			.enumerate()
			.for_each(|(block, chunk)| {
				forward_depth_first(
					&self.domain_context,
					chunk,
					log_d_chunk,
					layer,
					block,
					independent_layers.clone(),
					self.log_base_len,
				);
			});
	}

	fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut _data_orig: FieldSliceMut<P>,
		_skip_early: usize,
		_skip_late: usize,
	) {
		unimplemented!()
	}

	fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
		&self.domain_context
	}
}
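
// Illustrative sketch of the layer split in `NeighborsLastMultiThread::forward_transform`: given
// the total log-size, the packing width, and the requested number of shares, which layers run as
// "shared" layers (one layer split across all shares) and which run as "independent" per-chunk
// layers. The concrete numbers are hypothetical; the arithmetic mirrors the code above.
#[test]
fn multithread_layer_split_example() {
	let (log_d, log_width, log_num_shares) = (20usize, 4usize, 3usize);
	let (skip_early, skip_late) = (0usize, 0usize);
	let maximum_log_num_shares = log_d - log_width - 1;
	let actual_log_num_shares = min(log_num_shares, maximum_log_num_shares);
	let first_independent_layer = actual_log_num_shares;
	let last_layer = log_d - skip_late;
	let shared_layers = skip_early..min(first_independent_layer, last_layer);
	let independent_layers = max(first_independent_layer, skip_early)..last_layer;
	// the first 3 layers are processed as shared layers, the remaining 17 independently
	assert_eq!(shared_layers, 0..3);
	assert_eq!(independent_layers, 3..20);
}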

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_rotate_right() {
		assert_eq!(rotate_right(0b1001011, 7), 0b1100101);
		assert_eq!(rotate_right(0b0001011, 7), 0b1000101);
		assert_eq!(rotate_right(0b0001010, 7), 0b0000101);
		assert_eq!(rotate_right(0b1, 1), 0b1);
		assert_eq!(rotate_right(0b0, 1), 0b0);
	}

	#[test]
	fn test_with_middle_bit() {
		assert_eq!(with_middle_bit(0b000, 1), (0b0000, 0b0010));
		assert_eq!(with_middle_bit(0b000, 2), (0b0000, 0b0100));
		assert_eq!(with_middle_bit(0b000, 3), (0b0000, 0b1000));

		assert_eq!(with_middle_bit(0b111, 1), (0b1101, 0b1111));
		assert_eq!(with_middle_bit(0b111, 2), (0b1011, 0b1111));
		assert_eq!(with_middle_bit(0b111, 3), (0b0111, 0b1111));

		assert_eq!(with_middle_bit(0b1110110, 2), (0b11101010, 0b11101110));
	}
596}