binius_ntt/single_threaded.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{cmp, marker::PhantomData};

use binius_field::{BinaryField, PackedField, TowerField};
use binius_math::BinarySubspace;
use binius_utils::bail;

use super::{
	additive_ntt::{AdditiveNTT, NTTShape},
	error::Error,
	twiddle::TwiddleAccess,
};
use crate::twiddle::{OnTheFlyTwiddleAccess, PrecomputedTwiddleAccess, expand_subspace_evals};

/// Implementation of `AdditiveNTT` that performs the computation single-threaded.
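///
/// A minimal usage sketch (illustrative only; `BinaryField16b` is assumed to be available from
/// `binius_field`):
///
/// ```ignore
/// use binius_field::BinaryField16b;
///
/// // A domain of dimension 10 supports transforms with log_y + coset_bits <= 10.
/// let ntt = SingleThreadedNTT::<BinaryField16b>::new(10)?;
/// assert_eq!(ntt.log_domain_size(), 10);
/// ```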
#[derive(Debug)]
pub struct SingleThreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F>> {
	// TODO: Figure out how to make this private, it should not be `pub(super)`.
	pub(super) s_evals: Vec<TA>,
	_marker: PhantomData<F>,
}

impl<F: BinaryField> SingleThreadedNTT<F> {
	/// The default constructor builds an NTT over the canonical subspace for the field, using
	/// on-the-fly computed twiddle factors.
	pub fn new(log_domain_size: usize) -> Result<Self, Error> {
		let subspace = BinarySubspace::with_dim(log_domain_size)?;
		Self::with_subspace(&subspace)
	}

	/// Constructs an NTT over an isomorphic subspace for the given domain field using on-the-fly
	/// computed twiddle factors.
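	///
	/// For example (illustrative; assumes the `From` embedding exists for the chosen pair of
	/// fields): `SingleThreadedNTT::<BinaryField128b>::with_domain_field::<BinaryField16b>(10)`
	/// builds the twiddle domain in the 16-bit field and maps it isomorphically into the
	/// 128-bit field.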
	pub fn with_domain_field<FDomain>(log_domain_size: usize) -> Result<Self, Error>
	where
		FDomain: BinaryField,
		F: From<FDomain>,
	{
		let subspace = BinarySubspace::<FDomain>::with_dim(log_domain_size)?.isomorphic();
		Self::with_subspace(&subspace)
	}

	pub fn with_subspace(subspace: &BinarySubspace<F>) -> Result<Self, Error> {
		let twiddle_access = OnTheFlyTwiddleAccess::generate(subspace)?;
		Ok(Self::with_twiddle_access(twiddle_access))
	}

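	/// Returns an equivalent NTT that uses precomputed twiddle factors, trading extra memory
	/// for cheaper twiddle lookups during the transforms.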
	pub fn precompute_twiddles(&self) -> SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>> {
		SingleThreadedNTT::with_twiddle_access(expand_subspace_evals(&self.s_evals))
	}
}

impl<F: TowerField> SingleThreadedNTT<F> {
	/// A specialization of [`with_domain_field`](Self::with_domain_field) to the canonical tower
	/// field.
	pub fn with_canonical_field(log_domain_size: usize) -> Result<Self, Error> {
		Self::with_domain_field::<F::Canonical>(log_domain_size)
	}
}

impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
	const fn with_twiddle_access(twiddle_access: Vec<TA>) -> Self {
		Self {
			s_evals: twiddle_access,
			_marker: PhantomData,
		}
	}
}

impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
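	/// Returns the twiddle factor accessors, one per layer of the NTT.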
	pub fn twiddles(&self) -> &[TA] {
		&self.s_evals
	}
}

impl<F, TA> AdditiveNTT<F> for SingleThreadedNTT<F, TA>
where
	F: BinaryField,
	TA: TwiddleAccess<F>,
{
	fn log_domain_size(&self) -> usize {
		self.s_evals.len()
	}

	fn subspace(&self, i: usize) -> BinarySubspace<F> {
		let (subspace, shift) = self.s_evals[i].affine_subspace();
		debug_assert_eq!(shift, F::ZERO, "s_evals subspaces must be linear by construction");
		subspace
	}

	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
		self.s_evals[i].get(j)
	}

	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: usize,
		coset_bits: usize,
		skip_rounds: usize,
	) -> Result<(), Error> {
		forward_transform(
			self.log_domain_size(),
			&self.s_evals,
			data,
			shape,
			coset,
			coset_bits,
			skip_rounds,
		)
	}

	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: usize,
		coset_bits: usize,
		skip_rounds: usize,
	) -> Result<(), Error> {
		inverse_transform(
			self.log_domain_size(),
			&self.s_evals,
			data,
			shape,
			coset,
			coset_bits,
			skip_rounds,
		)
	}
}

pub fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F>],
	data: &mut [P],
	shape: NTTShape,
	coset: usize,
	coset_bits: usize,
	skip_rounds: usize,
) -> Result<(), Error> {
	check_batch_transform_inputs_and_params(
		log_domain_size,
		data,
		shape,
		coset,
		coset_bits,
		skip_rounds,
	)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::LOG_WIDTH {
				0 => Ok(()),
				_ => {
					// Special case when there is only one packed element: since we cannot
					// interleave with another packed element, the code below will panic when there
					// is only one.
					//
					// Handle the case of one packed element by batch transforming the original
					// data with dummy data and extracting the transformed result.
					let mut buffer = [data[0], P::zero()];
					forward_transform(
						log_domain_size,
						s_evals,
						&mut buffer,
						shape,
						coset,
						coset_bits,
						skip_rounds,
					)?;
					data[0] = buffer[0];
					Ok(())
				}
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Cutoff is the layer of the NTT below which each butterfly unit is contained within a
	// single packed field element.
	let cutoff = log_w.saturating_sub(log_x);
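
	// For example (hypothetical parameters): with P::LOG_WIDTH = 4 (16 scalars per packed
	// element) and log_x = 1, cutoff = 3, so layers 0, 1, and 2 operate on blocks narrower
	// than a packed element and are handled by the interleaving loop further below.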

	// Choose the twiddle factors so that NTTs on differently sized domains, with the same
	// coset_bits, share the twiddles of the initial layers.
	let s_evals = &s_evals[log_domain_size - (log_y + coset_bits)..];

	// i indexes the layer of the NTT network, also the binary subspace.
	for i in (cutoff..(log_y - skip_rounds)).rev() {
		let s_evals_i = &s_evals[i];
		let coset_offset = coset << (log_y - 1 - i);

		// j indexes the outer Z tensor axis.
		for j in 0..1 << log_z {
			// k indexes the block within the layer. Each block performs butterfly operations with
			// the same twiddle factor.
			for k in 0..1 << (log_y - 1 - i) {
				let twiddle = s_evals_i.get(coset_offset | k);
				for l in 0..1 << (i + log_x - log_w) {
					let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
					let idx1 = idx0 | 1 << (log_x + i - log_w);
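					// Butterfly in characteristic 2: (u, v) -> (u + t*v, u + (t + 1)*v), where
					// t is the twiddle; the second addition reuses the already-updated idx0.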
					data[idx0] += data[idx1] * twiddle;
					data[idx1] += data[idx0];
				}
			}
		}
	}

	for i in (0..cmp::min(cutoff, log_y - skip_rounds)).rev() {
		let s_evals_i = &s_evals[i];
		let coset_offset = coset << (log_y - 1 - i);

		// A block is a group of butterfly units that all share the same twiddle factor. Since we
		// are below the cutoff round, the block length is less than the packing width, and
		// therefore each packed multiplication is with a non-uniform twiddle. Since the subspace
		// polynomials are linear, we can calculate an additive factor that can be added to the
		// packed twiddles for all packed butterfly units.
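		// Concretely: since the subspace evaluation map is F2-linear, the twiddle of each
		// sub-block differs from the base twiddle t by a fixed offset; block_twiddle packs
		// those offsets so broadcast(t) + block_twiddle yields all per-lane twiddles at once.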
		let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

		let log_block_len = i + log_x;
		let log_packed_count = (log_y - 1).saturating_sub(cutoff);
		for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
			for k in 0..1 << log_packed_count {
				let twiddle =
					P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
				let index = k << 1 | j << (log_packed_count + 1);
				let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
				u += v * twiddle;
				v += u;
				(data[index], data[index | 1]) = u.interleave(v, log_block_len);
			}
		}
	}

	Ok(())
}

pub fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F>],
	data: &mut [P],
	shape: NTTShape,
	coset: usize,
	coset_bits: usize,
	skip_rounds: usize,
) -> Result<(), Error> {
	check_batch_transform_inputs_and_params(
		log_domain_size,
		data,
		shape,
		coset,
		coset_bits,
		skip_rounds,
	)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::LOG_WIDTH {
				0 => Ok(()),
				_ => {
					// Special case when there is only one packed element: since we cannot
					// interleave with another packed element, the code below will panic when there
					// is only one.
					//
					// Handle the case of one packed element by batch transforming the original
					// data with dummy data and extracting the transformed result.
					let mut buffer = [data[0], P::zero()];
					inverse_transform(
						log_domain_size,
						s_evals,
						&mut buffer,
						shape,
						coset,
						coset_bits,
						skip_rounds,
					)?;
					data[0] = buffer[0];
					Ok(())
				}
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// Cutoff is the layer of the NTT below which each butterfly unit is contained within a
	// single packed field element.
	let cutoff = log_w.saturating_sub(log_x);

	// Choose the twiddle factors so that NTTs on differently sized domains, with the same
	// coset_bits, share the twiddles of the final layers.
	let s_evals = &s_evals[log_domain_size - (log_y + coset_bits)..];

	#[allow(clippy::needless_range_loop)]
	for i in 0..cutoff.min(log_y - skip_rounds) {
		let s_evals_i = &s_evals[i];
		let coset_offset = coset << (log_y - 1 - i);

		// A block is a group of butterfly units that all share the same twiddle factor. Since we
		// are below the cutoff round, the block length is less than the packing width, and
		// therefore each packed multiplication is with a non-uniform twiddle. Since the subspace
		// polynomials are linear, we can calculate an additive factor that can be added to the
		// packed twiddles for all packed butterfly units.
		let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

		let log_block_len = i + log_x;
		let log_packed_count = (log_y - 1).saturating_sub(cutoff);
		for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
			for k in 0..1 << log_packed_count {
				let twiddle =
					P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
				let index = k << 1 | j << (log_packed_count + 1);
				let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
				v += u;
				u += v * twiddle;
				(data[index], data[index | 1]) = u.interleave(v, log_block_len);
			}
		}
	}

	// i indexes the layer of the NTT network, also the binary subspace.
	#[allow(clippy::needless_range_loop)]
	for i in cutoff..(log_y - skip_rounds) {
		let s_evals_i = &s_evals[i];
		let coset_offset = coset << (log_y - 1 - i);

		// j indexes the outer Z tensor axis.
		for j in 0..1 << log_z {
			// k indexes the block within the layer. Each block performs butterfly operations with
			// the same twiddle factor.
			for k in 0..1 << (log_y - 1 - i) {
				let twiddle = s_evals_i.get(coset_offset | k);
				for l in 0..1 << (i + log_x - log_w) {
					let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
					let idx1 = idx0 | 1 << (log_x + i - log_w);
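					// Inverse butterfly: v = u' + v' recovers v, then u = u' + t*v recovers u,
					// undoing the forward butterfly (u', v') = (u + t*v, u + (t + 1)*v).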
					data[idx1] += data[idx0];
					data[idx0] += data[idx1] * twiddle;
				}
			}
		}
	}

	Ok(())
}

pub fn check_batch_transform_inputs_and_params<PB: PackedField>(
	log_domain_size: usize,
	data: &[PB],
	shape: NTTShape,
	coset: usize,
	coset_bits: usize,
	skip_rounds: usize,
) -> Result<(), Error> {
	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	if !data.len().is_power_of_two() {
		bail!(Error::PowerOfTwoLengthRequired);
	}
	if skip_rounds > log_y {
		bail!(Error::SkipRoundsTooLarge);
	}

	let full_sized_y = (data.len() * PB::WIDTH) >> (log_x + log_z);
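	// E.g. (hypothetical sizes): 4 packed elements of width 8 with log_x = 1 and log_z = 0
	// give full_sized_y = 32 >> 1 = 16, so the check below requires log_y = 4.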

	// Verify that log_y exactly matches the data length, except when transforming only one or
	// two packed field elements (the single-element case is padded to a pair by the callers).
	if (1 << log_y != full_sized_y && data.len() > 2) || (1 << log_y > full_sized_y) {
		bail!(Error::BatchTooLarge);
	}

	if coset >= (1 << coset_bits) {
		bail!(Error::CosetIndexOutOfBounds { coset, coset_bits });
	}

	// The domain size should be at least large enough to represent the given coset.
	let log_required_domain_size = log_y + coset_bits;
	if log_required_domain_size > log_domain_size {
		bail!(Error::DomainTooSmall {
			log_required_domain_size
		});
	}

	Ok(())
}

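/// Computes the packed vector of per-lane additive twiddle offsets used by the NTT rounds whose
/// butterfly blocks are narrower than a packed element, exploiting the F2-linearity of the
/// subspace polynomial evaluations.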
#[inline]
fn calculate_packed_additive_twiddle<P>(
	s_evals: &impl TwiddleAccess<P::Scalar>,
	shape: NTTShape,
	ntt_round: usize,
) -> P
where
	P: PackedField<Scalar: BinaryField>,
{
	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;
	debug_assert!(log_y > 0);

	let log_block_len = ntt_round + log_x;
	debug_assert!(log_block_len < P::LOG_WIDTH);

	let packed_log_len = (log_x + log_y + log_z).min(P::LOG_WIDTH);
	let log_blocks_count = packed_log_len.saturating_sub(log_block_len + 1);

	let packed_log_z = packed_log_len.saturating_sub(log_x + log_y);
	let packed_log_y = packed_log_len - packed_log_z - log_x;

	let twiddle_stride = P::LOG_WIDTH
		.saturating_sub(log_x)
		.min(log_blocks_count - packed_log_z);

	let mut twiddle = P::default();
	for i in 0..1 << (log_blocks_count - twiddle_stride) {
		for j in 0..1 << twiddle_stride {
			let (subblock_twiddle_0, subblock_twiddle_1) = if packed_log_y == log_y {
				let same_twiddle = s_evals.get(j);
				(same_twiddle, same_twiddle)
			} else {
				s_evals.get_pair(twiddle_stride, j)
			};
			let idx0 = j << (log_block_len + 1) | i << (log_block_len + twiddle_stride + 1);
			let idx1 = idx0 | 1 << log_block_len;

			for k in 0..1 << log_block_len {
				twiddle.set(idx0 | k, subblock_twiddle_0);
				twiddle.set(idx1 | k, subblock_twiddle_1);
			}
		}
	}
	twiddle
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use assert_matches::assert_matches;
	use binius_field::{
		BinaryField8b, BinaryField16b, Field, PackedBinaryField8x16b, PackedFieldIndexable,
	};
	use binius_math::Error as MathError;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	#[test]
	fn test_additive_ntt_fails_with_field_too_small() {
		assert_matches!(
			SingleThreadedNTT::<BinaryField8b>::new(10),
			Err(Error::MathError(MathError::DomainSizeTooLarge))
		);
	}

	#[test]
	fn test_subspace_size_agrees_with_domain_size() {
		let ntt = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
		assert_eq!(ntt.subspace(0).dim(), 10);
		assert_eq!(ntt.subspace(9).dim(), 1);
	}

	/// The additive NTT has a useful property that the NTT output of an internally zero-padded
	/// input has the structure of repeating codeword symbols.
	///
	/// More precisely, let $m \in K^{2^\ell}$ be a message and $c = \text{NTT}_{\ell}(m)$ be its
	/// NTT evaluation on the domain $S^{(\ell)}$. Define $m' \in K^{2^{\ell+\nu}}$ to be a
	/// sequence with $m'_{i \cdot 2^\nu} = m_i$ and $m'_j = 0$ when $j \ne 0 \mod 2^\nu$. The
	/// property is that $c' = \text{NTT}_{\ell+\nu}(m')$ will have the structure
	/// $c'_j = c_{\lfloor j / 2^\nu \rfloor}$.
	///
	/// For example, with $\nu = 2$, $m' = (m_0, 0, 0, 0, m_1, 0, 0, 0, \ldots, m_{2^\ell-1}, 0,
	/// 0, 0)$ and $c' = (c_0, c_0, c_0, c_0, c_1, c_1, c_1, c_1, \ldots, c_{2^\ell-1},
	/// c_{2^\ell-1}, c_{2^\ell-1}, c_{2^\ell-1})$.
	#[test]
	fn test_repetition_property() {
		let log_len = 8;
		let ntt = SingleThreadedNTT::<BinaryField16b>::new(log_len + 2).unwrap();

		let mut rng = StdRng::seed_from_u64(0);
		let msg = repeat_with(|| <BinaryField16b as Field>::random(&mut rng))
			.take(1 << log_len)
			.collect::<Vec<_>>();

		let mut msg_padded = vec![BinaryField16b::ZERO; 1 << (log_len + 2)];
		for i in 0..1 << log_len {
			msg_padded[i << 2] = msg[i];
		}

		let mut out = msg;
		ntt.forward_transform(
			&mut out,
			NTTShape {
				log_y: log_len,
				..Default::default()
			},
			0,
			0,
			0,
		)
		.unwrap();
		let mut out_rep = msg_padded;
		ntt.forward_transform(
			&mut out_rep,
			NTTShape {
				log_y: log_len + 2,
				..Default::default()
			},
			0,
			0,
			0,
		)
		.unwrap();
		for i in 0..1 << (log_len + 2) {
			assert_eq!(out_rep[i], out[i >> 2]);
		}
	}

	#[test]
	fn one_packed_field_forward() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

		let shape = NTTShape {
			log_x: 0,
			log_y: 3,
			log_z: 0,
		};
		s.forward_transform(&mut packed, shape, 3, 2, 0).unwrap();
		s.forward_transform(unpacked, shape, 3, 2, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn one_packed_field_inverse() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

		let shape = NTTShape {
			log_x: 0,
			log_y: 3,
			log_z: 0,
		};
		s.inverse_transform(&mut packed, shape, 3, 2, 0).unwrap();
		s.inverse_transform(unpacked, shape, 3, 2, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn smaller_embedded_batch_forward() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

		let shape = NTTShape {
			log_x: 0,
			log_y: 2,
			log_z: 0,
		};
		forward_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 2, 0).unwrap();
		s.forward_transform(unpacked, shape, 3, 2, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn smaller_embedded_batch_inverse() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("msg");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

		let shape = NTTShape {
			log_x: 0,
			log_y: 2,
			log_z: 0,
		};
		inverse_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 2, 0).unwrap();
		s.inverse_transform(unpacked, shape, 3, 2, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	// TODO: Write test that compares polynomial evaluation via additive NTT with naive Lagrange
	// polynomial interpolation. A randomized test should suffice for larger NTT sizes.
}