binius_math/ntt/neighbors_last.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{
	cmp::{max, min},
	iter,
	ops::Range,
	slice::from_raw_parts_mut,
};

use binius_field::{BinaryField, PackedField};
use binius_utils::rayon::{
	iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
	slice::ParallelSliceMut,
};

use super::{
	AdditiveNTT, DomainContext,
	reference::{NeighborsLastReference, input_check},
};
use crate::field_buffer::FieldSliceMut;

// This value is chosen assuming 128-bit field elements.
//
// Empirically it performs well and is small enough for the buffer to fit comfortably in L1 cache.
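//
// (2^10 elements of 128 bits each occupy 16 KiB, i.e. about half of a typical 32 KiB L1 data
// cache.)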
const DEFAULT_LOG_BASE_LEN: usize = 10;

/// Runs a **part** of an NTT butterfly network, in depth-first order.
///
/// Concretely, it processes a specific memory block in the butterfly network, which is given by
/// `layer` and `block`. For this memory block, it processes the layers given by `layer_range`.
///
/// For example, suppose `layer=2` and `block=2`.
/// That means we are at layer 2 (the third layer) of the butterfly network, in block 2 (the third
/// of the four blocks in this layer). `data` contains the
/// data of this block, so it's only a chunk of the total data used in the NTT. Now suppose
/// `layer_range=2..5`. Then we will process the following butterfly blocks:
/// - `layer=2` `block=2`
/// - `layer=3` `block=4`
/// - `layer=3` `block=5`
/// - `layer=4` `block=8`
/// - `layer=4` `block=9`
/// - `layer=4` `block=10`
/// - `layer=4` `block=11`
///
/// (Just in a different order: the list above is in breadth-first order, while the function visits
/// these blocks in depth-first order, as spelled out below.)
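///
/// For illustration (ignoring the breadth-first base case), the depth-first visiting order for the
/// example above is `(layer=2, block=2)`, `(3, 4)`, `(4, 8)`, `(4, 9)`, `(3, 5)`, `(4, 10)`,
/// `(4, 11)`: each block is fully recursed into before its sibling is visited.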
///
/// The argument `log_base_len` determines for which `log_d` we call the breadth-first
/// implementation as a base case.
///
/// ## Preconditions
///
/// - `2^(log_d) == data.len() * packing_width`
/// - `data.len() >= 2`
/// - `domain_context` holds all the twiddles up to `layer_range.end` (exclusive)
/// - `layer <= layer_range.start`
/// - `log_base_len > P::LOG_WIDTH`
fn forward_depth_first<P: PackedField>(
	domain_context: &impl DomainContext<Field = P::Scalar>,
	data: &mut [P],
	log_d: usize,
	layer: usize,
	block: usize,
	mut layer_range: Range<usize>,
	log_base_len: usize,
) {
	// check preconditions
	debug_assert!(P::LOG_WIDTH < log_d);
	debug_assert_eq!(data.len(), 1 << (log_d - P::LOG_WIDTH));
	debug_assert!(layer_range.end <= domain_context.log_domain_size());
	debug_assert!(layer <= layer_range.start);
	debug_assert!(log_base_len > P::LOG_WIDTH);

	let log_n = log_d + layer;
	debug_assert!(layer_range.end <= log_n);

	if layer >= layer_range.end {
		return;
	}

	// If the problem size is small, we just do breadth-first (to avoid the recursion overhead).
	if log_d <= log_base_len {
		forward_breadth_first(domain_context, data, log_d, layer, block, layer_range);
		return;
	}

	let block_size_half = 1 << (log_d - 1 - P::LOG_WIDTH);
	if layer >= layer_range.start {
		// process only one layer of this block
		let twiddle = domain_context.twiddle(layer, block);
		let packed_twiddle = P::broadcast(twiddle);
		let (block0, block1) = data.split_at_mut(block_size_half);
		for (u, v) in iter::zip(block0, block1) {
			// perform butterfly
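			// (Addition is XOR in a binary field, so this computes u' = u + t*v and
			// v' = u' + v = u + (t + 1)*v, where t is this block's twiddle.)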
			*u += *v * packed_twiddle;
			*v += *u;
		}

		layer_range.start += 1;
	}

	// then recurse
	forward_depth_first(
		domain_context,
		&mut data[..block_size_half],
		log_d - 1,
		layer + 1,
		block << 1,
		layer_range.clone(),
		log_base_len,
	);
	forward_depth_first(
		domain_context,
		&mut data[block_size_half..],
		log_d - 1,
		layer + 1,
		(block << 1) + 1,
		layer_range,
		log_base_len,
	);
}

/// Same as [`forward_depth_first`], but runs in breadth-first order.
///
/// ## Preconditions
///
/// - `P::LOG_WIDTH < log_d`
/// - `2^(log_d) == data.len() * packing_width`
/// - `data.len() >= 2`
/// - `domain_context` holds all the twiddles up to `layer_range.end` (exclusive)
/// - `base_layer <= layer_range.start`
fn forward_breadth_first<P: PackedField>(
	domain_context: &impl DomainContext<Field = P::Scalar>,
	data: &mut [P],
	log_d: usize,
	base_layer: usize,
	base_block: usize,
	layer_range: Range<usize>,
) {
	// check preconditions
	debug_assert!(P::LOG_WIDTH < log_d);
	debug_assert_eq!(data.len(), 1 << (log_d - P::LOG_WIDTH));
	debug_assert!(layer_range.end <= domain_context.log_domain_size());
	debug_assert!(base_layer <= layer_range.start);

	let log_n = log_d + base_layer;
	debug_assert!(layer_range.end <= log_n);

	let packed_cutoff = (log_n - P::LOG_WIDTH).clamp(layer_range.start, layer_range.end);

	// In these rounds, layer < log_n - P::LOG_WIDTH. All butterflies are between values in
	// separate packed elements, and all butterflies within a block share the same twiddle factor.
	for layer in layer_range.start..packed_cutoff {
		// log_block_size is log2 the number of packed elements forming one block.
		let log_block_size = log_n - P::LOG_WIDTH - layer;
		let log_half_block_size = log_block_size - 1;

		// log2 the number of blocks to process in this layer
		let log_blocks = layer - base_layer;
		let layer_twiddles = domain_context
			.iter_twiddles(layer, 0)
			.skip(base_block << log_blocks)
			.take(1 << log_blocks);
		let blocks = data.chunks_exact_mut(1 << log_block_size);
		for (block, twiddle) in iter::zip(blocks, layer_twiddles) {
			let packed_twiddle = P::broadcast(twiddle);
			let (block0, block1) = block.split_at_mut(1 << log_half_block_size);
			for (u, v) in iter::zip(block0, block1) {
				// perform butterfly
				*u += *v * packed_twiddle;
				*v += *u;
			}
		}
	}

	// In these rounds, layer >= log_n - P::LOG_WIDTH. The butterflies operate on elements within
	// packed field elements. We solve this problem by interleaving the packed elements with each
	// other.
	for layer in packed_cutoff..layer_range.end {
		// log_block_size is log2 the number of single elements forming one block.
		let log_block_size = log_n - layer;
		let log_half_block_size = log_block_size - 1;
		let log_blocks_per_packed = P::LOG_WIDTH - log_block_size;
		let log_half_blocks_per_packed = log_blocks_per_packed + 1;

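		// The twiddles within one layer are F2-linear (additive) in the block index. That is what
		// allows building each lane's twiddle below as `first_twiddle + packed_twiddle_offset`:
		// the broadcast base twiddle covers the high bits of the block index and the per-lane
		// offset covers the low bits.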
		// calculate packed_twiddle_offset
		let mut packed_twiddle_offset = P::zero();
		for block in 0..1 << log_blocks_per_packed {
			let twiddle0 = domain_context.twiddle(layer, block);
			let twiddle1 = domain_context.twiddle(layer, (1 << log_blocks_per_packed) | block);

			let block_start = block << log_block_size;
			for j in 0..1 << log_half_block_size {
				packed_twiddle_offset.set(block_start | j, twiddle0);
				packed_twiddle_offset.set(block_start | j | (1 << log_half_block_size), twiddle1);
			}
		}

		// log2 the number of packed element pairs to process in this layer
		let log_packed_pairs = packed_cutoff - base_layer - 1;
		let layer_twiddles = domain_context
			.iter_twiddles(layer, log_half_blocks_per_packed)
			.skip(base_block << log_packed_pairs)
			.take(1 << log_packed_pairs);

		let (data_pairs, rest) = data.as_chunks_mut::<2>();
		debug_assert!(
			rest.is_empty(),
			"data_packed length is a power of two; \
				data_packed length is greater than 1 (checked at beginning of method)"
		);
		debug_assert_eq!(data_pairs.len(), 1 << log_packed_pairs);

		for ([packed0, packed1], first_twiddle) in iter::zip(data_pairs, layer_twiddles) {
			let packed_twiddle = P::broadcast(first_twiddle) + packed_twiddle_offset;

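			// Interleaving at half-block granularity gathers the lower half of every block into
			// `u` and the upper half into `v`, so a single packed butterfly covers all blocks in
			// this pair at once; the second interleave below restores the original layout.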
			let (mut u, mut v) = (*packed0).interleave(*packed1, log_half_block_size);
			u += v * packed_twiddle;
			v += u;
			(*packed0, *packed1) = u.interleave(v, log_half_block_size);
		}
	}
}

/// Process a layer of the NTT butterfly network in parallel by splitting the work up into
/// `2^log_num_shares` many shares. This will also split up single *blocks* into multiple shares.
///
/// (The latter is the whole purpose of this function. If the number of shares is small enough (and
/// the number of blocks is big enough) so that we don't need to split up blocks, we could just run
/// [`forward_depth_first`] on disjoint chunks.)
///
/// ## Preconditions
///
/// - `2^(log_d) == data.len() * packing_width`
/// - **Important:** `2^log_num_shares * 2 <= data.len()` (every share is working with whole packed
///   elements, so every share needs at least 2 packed elements)
/// - `domain_context` holds the twiddles of `layer`
/// - `layer < log_num_shares` (otherwise the `shift` computed below would underflow)
fn forward_shared_layer<P: PackedField>(
	domain_context: &(impl DomainContext<Field = P::Scalar> + Sync),
	data: &mut [P],
	log_d: usize,
	layer: usize,
	log_num_shares: usize,
) {
	// check preconditions
	debug_assert_eq!(data.len() * (1 << P::LOG_WIDTH), 1 << log_d);
	debug_assert!(1 << (log_num_shares + 1) <= data.len());
	debug_assert!(layer < domain_context.log_domain_size());

	let log_num_chunks = log_num_shares + 1;
	let log_d_chunk = log_d - log_num_chunks;
	let data_ptr = data.as_mut_ptr();
	let shift = log_num_shares - layer;
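	// In this layer, the two halves of a butterfly block are 2^(log_d - 1 - layer) scalars apart.
	// Measured in chunk indices that distance is exactly 2^shift, so the two chunks of a pair
	// differ only in bit `shift`, which is the bit that `with_middle_bit` inserts.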
	let tasks: Vec<_> = (0..1 << log_num_shares)
		.map(|k| {
			let (chunk0, chunk1) = with_middle_bit(k, shift);
			let block = chunk0 >> (log_num_chunks - layer);
			assert!(P::LOG_WIDTH <= log_d_chunk);
			let log_chunk_len = log_d_chunk - P::LOG_WIDTH;
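			// SAFETY: `chunk0 != chunk1` for every `k`, and distinct `k` yield distinct chunk
			// indices, so the slices created below are pairwise disjoint views into `data` and
			// stay in bounds (all chunk indices are below 2^log_num_chunks).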
			let chunk0 = unsafe {
				from_raw_parts_mut(data_ptr.add(chunk0 << log_chunk_len), 1 << log_chunk_len)
			};
			let chunk1 = unsafe {
				from_raw_parts_mut(data_ptr.add(chunk1 << log_chunk_len), 1 << log_chunk_len)
			};
			let twiddle = domain_context.twiddle(layer, block);
			let twiddle = P::broadcast(twiddle);
			(chunk0, chunk1, twiddle)
		})
		.collect();

	tasks.into_par_iter().for_each(|(chunk0, chunk1, twiddle)| {
		for i in 0..chunk0.len() {
			let mut u = chunk0[i];
			let mut v = chunk1[i];
			u += v * twiddle;
			v += u;
			chunk0[i] = u;
			chunk1[i] = v;
		}
	});
}

/// Inserts a bit into `k`. Returns both the version with `0` inserted and `1` inserted.
///
/// The low `shift` bits of `k` are kept in place, then `0` or `1` is inserted at bit position
/// `shift`, and the remaining high bits of `k` are shifted up by one.
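///
/// For example, `with_middle_bit(0b101, 2)` keeps the low bits `01`, inserts the new bit at
/// position 2, and shifts the remaining bit `1` up by one, returning `(0b1001, 0b1101)`.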
///
/// ## Preconditions
///
/// - `shift` must be strictly greater than 0
fn with_middle_bit(k: usize, shift: usize) -> (usize, usize) {
	assert!(shift >= 1);

	// most significant and least significant bits, overlapping in one bit
	let ms = k >> (shift - 1);
	let ls = k & ((1 << shift) - 1);

	let k0 = ls | ((ms & !1) << shift);
	let k1 = ls | ((ms | 1) << shift);

	(k0, k1)
}

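/// An implementation of [`AdditiveNTT`] that runs the whole transform breadth-first on a single
/// thread.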
#[derive(Debug)]
pub struct NeighborsLastBreadthFirst<DC> {
	/// The domain context from which the twiddles are pulled.
	pub domain_context: DC,
}

impl<F, DC> AdditiveNTT for NeighborsLastBreadthFirst<DC>
where
	F: BinaryField,
	DC: DomainContext<Field = F>,
{
	type Field = F;

	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		mut data: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		let log_d = data.log_len();
		if log_d <= P::LOG_WIDTH {
			let fallback_ntt = NeighborsLastReference {
				domain_context: &self.domain_context,
			};
			return fallback_ntt.forward_transform(data, skip_early, skip_late);
		}

		input_check(&self.domain_context, log_d, skip_early, skip_late);

		forward_breadth_first(
			self.domain_context(),
			data.as_mut(),
			log_d,
			0,
			0,
			skip_early..(log_d - skip_late),
		);
	}

	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		_data: FieldSliceMut<P>,
		_skip_early: usize,
		_skip_late: usize,
	) {
		todo!()
	}

	fn domain_context(&self) -> &impl DomainContext<Field = F> {
		&self.domain_context
	}
}

/// A single-threaded implementation of [`AdditiveNTT`].
///
/// The implementation is only tuned to be fast for _large_ data inputs.
/// For small inputs, it can be comparatively slow!
///
/// The implementation is depth-first, but calls a breadth-first implementation as a base case.
///
/// Note that "neighbors last" refers to the memory layout for the NTT: in the _last_ layer of this
/// NTT algorithm, the butterflies pair neighboring elements. In the classic FFT that's usually the
/// case for "decimation in frequency".
#[derive(Debug)]
pub struct NeighborsLastSingleThread<DC> {
	/// The domain context from which the twiddles are pulled.
	pub domain_context: DC,
	/// Determines when to switch from depth-first to the breadth-first base case.
	pub log_base_len: usize,
}

impl<DC> NeighborsLastSingleThread<DC> {
	/// Convenience constructor which sets `log_base_len` to a reasonable default.
	pub fn new(domain_context: DC) -> Self {
		Self {
			domain_context,
			log_base_len: DEFAULT_LOG_BASE_LEN,
		}
	}
}

impl<DC: DomainContext> AdditiveNTT for NeighborsLastSingleThread<DC> {
	type Field = DC::Field;

	fn forward_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut data: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		let log_d = data.log_len();
		if log_d <= P::LOG_WIDTH {
			let fallback_ntt = NeighborsLastReference {
				domain_context: &self.domain_context,
			};
			return fallback_ntt.forward_transform(data, skip_early, skip_late);
		}

		input_check(&self.domain_context, log_d, skip_early, skip_late);

		forward_depth_first(
			&self.domain_context,
			data.as_mut(),
			log_d,
			0,
			0,
			skip_early..(log_d - skip_late),
			// Ensures that log_base_len satisfies precondition
			self.log_base_len.max(P::LOG_WIDTH + 1),
		);
	}

	fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		_data_orig: FieldSliceMut<P>,
		_skip_early: usize,
		_skip_late: usize,
	) {
		unimplemented!()
	}

	fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
		&self.domain_context
	}
}

/// A multi-threaded implementation of [`AdditiveNTT`].
///
/// The implementation is only tuned to be fast for _large_ data inputs.
/// For small inputs, it can be comparatively slow!
///
/// The implementation is depth-first, but calls a breadth-first implementation as a base case.
///
/// Note that "neighbors last" refers to the memory layout for the NTT: in the _last_ layer of this
/// NTT algorithm, the butterflies pair neighboring elements. In the classic FFT that's usually the
/// case for "decimation in frequency".
#[derive(Debug)]
pub struct NeighborsLastMultiThread<DC> {
	/// The domain context from which the twiddles are pulled.
	pub domain_context: DC,
	/// Determines when to switch from depth-first to the breadth-first base case.
	pub log_base_len: usize,
	/// The base-2 logarithm of the number of equal-sized shares that the problem should be split
	/// into. Each share does the same amount of work. If you have equally powerful cores
	/// available, this should be the base-2 logarithm of the number of cores.
	pub log_num_shares: usize,
}

impl<DC> NeighborsLastMultiThread<DC> {
	/// Convenience constructor which sets `log_base_len` to a reasonable default.
	pub fn new(domain_context: DC, log_num_shares: usize) -> Self {
		Self {
			domain_context,
			log_base_len: DEFAULT_LOG_BASE_LEN,
			log_num_shares,
		}
	}
}

impl<DC: DomainContext + Sync> AdditiveNTT for NeighborsLastMultiThread<DC> {
	type Field = DC::Field;

	fn forward_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		mut data: FieldSliceMut<P>,
		skip_early: usize,
		skip_late: usize,
	) {
		let log_d = data.log_len();
		if log_d <= P::LOG_WIDTH {
			let fallback_ntt = NeighborsLastReference {
				domain_context: &self.domain_context,
			};
			return fallback_ntt.forward_transform(data, skip_early, skip_late);
		}

		input_check(&self.domain_context, log_d, skip_early, skip_late);

		// Decide on `actual_log_num_shares`, which also determines how many shared rounds we do.
		// By default this would just be `self.log_num_shares`, but we will potentially decrease it
		// in order to make sure that `2^log_num_shares * 2 <= data.len()`. This serves two
		// purposes:
		// - when we do the shared rounds, each thread should have at least 2 packed elements to
		//   work with, see the precondition of [`forward_shared_layer`]
		// - when we do the independent rounds, again each share should have `chunk.len() >= 2`
		//   because this is required by [`forward_depth_first`]
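		//
		// As a purely illustrative example: with log_d = 20, P::LOG_WIDTH = 2, log_num_shares = 3
		// and no skipped layers, layers 0..3 run as shared layers and the remaining layers 3..20
		// run independently on 2^3 disjoint chunks of 2^15 packed elements each.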
		let maximum_log_num_shares = log_d - P::LOG_WIDTH - 1;
		let actual_log_num_shares = min(self.log_num_shares, maximum_log_num_shares);
		let first_independent_layer = actual_log_num_shares;

		let last_layer = log_d - skip_late;
		let shared_layers = skip_early..min(first_independent_layer, last_layer);
		let independent_layers = max(first_independent_layer, skip_early)..last_layer;

		for layer in shared_layers {
			forward_shared_layer(
				&self.domain_context,
				data.as_mut(),
				log_d,
				layer,
				actual_log_num_shares,
			);
		}

		// One might think that we could just call `forward_depth_first` with
		// `layer=independent_layers.start`. However, this would mean that the chunk size (that we
		// split into using `par_chunks_exact_mut`) could be just one packed element, or even less
		// than one packed element.
		let layer = min(independent_layers.start, maximum_log_num_shares);
		let log_d_chunk = log_d - layer;
		data.as_mut()
			.par_chunks_exact_mut(1 << (log_d_chunk - P::LOG_WIDTH))
			.enumerate()
			.for_each(|(block, chunk)| {
				forward_depth_first(
					&self.domain_context,
					chunk,
					log_d_chunk,
					layer,
					block,
					independent_layers.clone(),
					self.log_base_len,
				);
			});
	}

	fn inverse_transform<P: PackedField<Scalar = Self::Field>>(
		&self,
		_data_orig: FieldSliceMut<P>,
		_skip_early: usize,
		_skip_late: usize,
	) {
		unimplemented!()
	}

	fn domain_context(&self) -> &impl DomainContext<Field = DC::Field> {
		&self.domain_context
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_with_middle_bit() {
		assert_eq!(with_middle_bit(0b000, 1), (0b0000, 0b0010));
		assert_eq!(with_middle_bit(0b000, 2), (0b0000, 0b0100));
		assert_eq!(with_middle_bit(0b000, 3), (0b0000, 0b1000));

		assert_eq!(with_middle_bit(0b111, 1), (0b1101, 0b1111));
		assert_eq!(with_middle_bit(0b111, 2), (0b1011, 0b1111));
		assert_eq!(with_middle_bit(0b111, 3), (0b0111, 0b1111));

		assert_eq!(with_middle_bit(0b1110110, 2), (0b11101010, 0b11101110));
	}
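
	/// A minimal sketch of the recursion in [`forward_depth_first`], with the field arithmetic
	/// stripped out, used here only to illustrate the `(layer, block)` visiting order described in
	/// its doc comment.
	fn depth_first_blocks(
		layer: usize,
		block: usize,
		mut layer_range: std::ops::Range<usize>,
		out: &mut Vec<(usize, usize)>,
	) {
		if layer >= layer_range.end {
			return;
		}
		if layer >= layer_range.start {
			out.push((layer, block));
			layer_range.start += 1;
		}
		depth_first_blocks(layer + 1, block << 1, layer_range.clone(), out);
		depth_first_blocks(layer + 1, (block << 1) + 1, layer_range, out);
	}

	#[test]
	fn test_depth_first_block_order() {
		// Mirrors the worked example in the doc comment of `forward_depth_first`.
		let mut visited = Vec::new();
		depth_first_blocks(2, 2, 2..5, &mut visited);
		assert_eq!(
			visited,
			vec![(2, 2), (3, 4), (4, 8), (4, 9), (3, 5), (4, 10), (4, 11)]
		);
	}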
}