// Copyright 2024-2025 Irreducible Inc.

use std::{cmp, marker::PhantomData};

use binius_field::{BinaryField, PackedField, TowerField};
use binius_math::BinarySubspace;

use super::{
	additive_ntt::{AdditiveNTT, NTTShape},
	error::Error,
	twiddle::TwiddleAccess,
};
use crate::twiddle::{expand_subspace_evals, OnTheFlyTwiddleAccess, PrecomputedTwiddleAccess};

/// Implementation of `AdditiveNTT` that performs the computation single-threaded.
#[derive(Debug)]
pub struct SingleThreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F>> {
	// TODO: Figure out how to make this private, it should not be `pub(super)`.
	pub(super) s_evals: Vec<TA>,
	_marker: PhantomData<F>,
}

impl<F: BinaryField> SingleThreadedNTT<F> {
	/// Default constructor: builds an NTT over the canonical subspace for the field
	/// using on-the-fly computed twiddle factors.
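	///
	/// # Example
	///
	/// A minimal usage sketch; the `binius_ntt` re-export paths are assumed here:
	///
	/// ```ignore
	/// use binius_field::BinaryField16b;
	/// use binius_ntt::{AdditiveNTT, SingleThreadedNTT};
	///
	/// let ntt = SingleThreadedNTT::<BinaryField16b>::new(10).expect("10-bit domain fits in a 16-bit field");
	/// assert_eq!(ntt.log_domain_size(), 10);
	/// ```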
	pub fn new(log_domain_size: usize) -> Result<Self, Error> {
		let subspace = BinarySubspace::with_dim(log_domain_size)?;
		Self::with_subspace(&subspace)
	}

	/// Constructs an NTT over an isomorphic subspace for the given domain field using on-the-fly
	/// computed twiddle factors.
	pub fn with_domain_field<FDomain>(log_domain_size: usize) -> Result<Self, Error>
	where
		FDomain: BinaryField,
		F: From<FDomain>,
	{
		let subspace = BinarySubspace::<FDomain>::with_dim(log_domain_size)?.isomorphic();
		Self::with_subspace(&subspace)
	}

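	/// Constructs an NTT over the given subspace using on-the-fly computed twiddle
	/// factors.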
	pub fn with_subspace(subspace: &BinarySubspace<F>) -> Result<Self, Error> {
		let twiddle_access = OnTheFlyTwiddleAccess::generate(subspace)?;
		Ok(Self::with_twiddle_access(twiddle_access))
	}

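	/// Returns an equivalent NTT that uses precomputed twiddle factors, trading memory
	/// for faster twiddle lookups.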
	pub fn precompute_twiddles(&self) -> SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>> {
		SingleThreadedNTT::with_twiddle_access(expand_subspace_evals(&self.s_evals))
	}
}

impl<F: TowerField> SingleThreadedNTT<F> {
	/// A specialization of [`with_domain_field`](Self::with_domain_field) to the canonical tower field.
	pub fn with_canonical_field(log_domain_size: usize) -> Result<Self, Error> {
		Self::with_domain_field::<F::Canonical>(log_domain_size)
	}
}

impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
	const fn with_twiddle_access(twiddle_access: Vec<TA>) -> Self {
		Self {
			s_evals: twiddle_access,
			_marker: PhantomData,
		}
	}
}

impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
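	/// Returns the twiddle factor accessors, one per NTT layer.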
	pub fn twiddles(&self) -> &[TA] {
		&self.s_evals
	}
}

impl<F, TA> AdditiveNTT<F> for SingleThreadedNTT<F, TA>
where
	F: BinaryField,
	TA: TwiddleAccess<F>,
{
	fn log_domain_size(&self) -> usize {
		self.s_evals.len()
	}

	fn subspace(&self, i: usize) -> BinarySubspace<F> {
		let (subspace, shift) = self.s_evals[i].affine_subspace();
		debug_assert_eq!(shift, F::ZERO, "s_evals subspaces must be linear by construction");
		subspace
	}

	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
		self.s_evals[i].get(j)
	}

	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: u32,
		skip_rounds: usize,
	) -> Result<(), Error> {
		forward_transform(self.log_domain_size(), &self.s_evals, data, shape, coset, skip_rounds)
	}

	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		shape: NTTShape,
		coset: u32,
		skip_rounds: usize,
	) -> Result<(), Error> {
		inverse_transform(self.log_domain_size(), &self.s_evals, data, shape, coset, skip_rounds)
	}
}

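/// Performs a batched forward additive NTT on `data` in place.
///
/// This is the free-function form backing the `AdditiveNTT::forward_transform`
/// implementation above; `s_evals` supplies the twiddle factors for each layer.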
pub fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F>],
	data: &mut [P],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
) -> Result<(), Error> {
	check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::LOG_WIDTH {
				0 => Ok(()),
				_ => {
					// Special case when there is only one packed element: since it cannot be
					// interleaved with another packed element, the code below would panic.
					//
					// Handle this case by batch transforming the original data together with
					// dummy data and extracting the transformed result.
					let mut buffer = [data[0], P::zero()];
					forward_transform(
						log_domain_size,
						s_evals,
						&mut buffer,
						shape,
						coset,
						skip_rounds,
					)?;
					data[0] = buffer[0];
					Ok(())
				}
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;
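	// The data is interpreted as a 3-dimensional tensor: the X axis (2^log_x scalars)
	// is innermost, the Y axis (2^log_y, the axis being transformed) is in the middle,
	// and the Z axis (2^log_z) is outermost.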

	let log_w = P::LOG_WIDTH;

	// The cutoff is the NTT round below which each butterfly unit lies entirely within
	// a single packed field element.
	let cutoff = log_w.saturating_sub(log_x);

	// i indexes the layer of the NTT network, also the binary subspace.
	for i in (cutoff..(log_y - skip_rounds)).rev() {
		let s_evals_i = &s_evals[i];
		let coset_offset = (coset as usize) << (log_y - 1 - i);

		// j indexes the outer Z tensor axis.
		for j in 0..1 << log_z {
			// k indexes the block within the layer. Each block performs butterfly operations with
			// the same twiddle factor.
			for k in 0..1 << (log_y - 1 - i) {
				let twiddle = s_evals_i.get(coset_offset | k);
				for l in 0..1 << (i + log_x - log_w) {
					let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
					let idx1 = idx0 | 1 << (log_x + i - log_w);
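					// Additive NTT butterfly over characteristic-2 fields:
					// (u, v) ↦ (u + t·v, u + (t + 1)·v), where t is the twiddle. The
					// second update below reuses the freshly written data[idx0].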
					data[idx0] += data[idx1] * twiddle;
					data[idx1] += data[idx0];
				}
			}
		}
	}

	for i in (0..cmp::min(cutoff, log_y - skip_rounds)).rev() {
		let s_evals_i = &s_evals[i];
		let coset_offset = (coset as usize) << (log_y - 1 - i);

		// A block is a group of butterfly units that all share the same twiddle factor. Since we
		// are below the cutoff round, the block length is less than the packing width, and
		// therefore each packed multiplication is with a non-uniform twiddle. Since the subspace
		// polynomials are linear, we can calculate an additive offset to apply to the packed
		// twiddles of all packed butterfly units.
		let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

		let log_block_len = i + log_x;
		let log_packed_count = (log_y - 1).saturating_sub(cutoff);
		for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
			for k in 0..1 << log_packed_count {
				let twiddle =
					P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
				let index = k << 1 | j << (log_packed_count + 1);
				let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
				u += v * twiddle;
				v += u;
				(data[index], data[index | 1]) = u.interleave(v, log_block_len);
			}
		}
	}

	Ok(())
}

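/// Performs a batched inverse additive NTT on `data` in place, applying the layers of
/// [`forward_transform`] in reverse order with the butterfly updates reversed.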
pub fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F>],
	data: &mut [P],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
) -> Result<(), Error> {
	check_batch_transform_inputs_and_params(log_domain_size, data, shape, coset, skip_rounds)?;

	match data.len() {
		0 => return Ok(()),
		1 => {
			return match P::LOG_WIDTH {
				0 => Ok(()),
				_ => {
					// Special case when there is only one packed element: since it cannot be
					// interleaved with another packed element, the code below would panic.
					//
					// Handle this case by batch transforming the original data together with
					// dummy data and extracting the transformed result.
					let mut buffer = [data[0], P::zero()];
					inverse_transform(
						log_domain_size,
						s_evals,
						&mut buffer,
						shape,
						coset,
						skip_rounds,
					)?;
					data[0] = buffer[0];
					Ok(())
				}
			};
		}
		_ => {}
	};

	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	let log_w = P::LOG_WIDTH;

	// The cutoff is the NTT round below which each butterfly unit lies entirely within
	// a single packed field element.
	let cutoff = log_w.saturating_sub(log_x);

	#[allow(clippy::needless_range_loop)]
	for i in 0..cutoff.min(log_y - skip_rounds) {
		let s_evals_i = &s_evals[i];
		let coset_offset = (coset as usize) << (log_y - 1 - i);

		// A block is a group of butterfly units that all share the same twiddle factor. Since we
		// are below the cutoff round, the block length is less than the packing width, and
		// therefore each packed multiplication is with a non-uniform twiddle. Since the subspace
		// polynomials are linear, we can calculate an additive offset to apply to the packed
		// twiddles of all packed butterfly units.
		let block_twiddle = calculate_packed_additive_twiddle::<P>(s_evals_i, shape, i);

		let log_block_len = i + log_x;
		let log_packed_count = (log_y - 1).saturating_sub(cutoff);
		for j in 0..1 << (log_x + log_y + log_z).saturating_sub(log_w + log_packed_count + 1) {
			for k in 0..1 << log_packed_count {
				let twiddle =
					P::broadcast(s_evals_i.get(coset_offset | k << (cutoff - i))) + block_twiddle;
				let index = k << 1 | j << (log_packed_count + 1);
				let (mut u, mut v) = data[index].interleave(data[index | 1], log_block_len);
				v += u;
				u += v * twiddle;
				(data[index], data[index | 1]) = u.interleave(v, log_block_len);
			}
		}
	}

	// i indexes the layer of the NTT network, also the binary subspace.
	#[allow(clippy::needless_range_loop)]
	for i in cutoff..(log_y - skip_rounds) {
		let s_evals_i = &s_evals[i];
		let coset_offset = (coset as usize) << (log_y - 1 - i);

		// j indexes the outer Z tensor axis.
		for j in 0..1 << log_z {
			// k indexes the block within the layer. Each block performs butterfly operations with
			// the same twiddle factor.
			for k in 0..1 << (log_y - 1 - i) {
				let twiddle = s_evals_i.get(coset_offset | k);
				for l in 0..1 << (i + log_x - log_w) {
					let idx0 = j << (log_x + log_y - log_w) | k << (log_x + i + 1 - log_w) | l;
					let idx1 = idx0 | 1 << (log_x + i - log_w);
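					// Inverse butterfly: the two XOR updates of the forward butterfly,
					// applied in the opposite order, undo it exactly.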
					data[idx1] += data[idx0];
					data[idx0] += data[idx1] * twiddle;
				}
			}
		}
	}

	Ok(())
}

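/// Validates that the data length, shape, coset, and `skip_rounds` parameters are
/// mutually consistent and fit within the given domain size.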
pub fn check_batch_transform_inputs_and_params<PB: PackedField>(
	log_domain_size: usize,
	data: &[PB],
	shape: NTTShape,
	coset: u32,
	skip_rounds: usize,
) -> Result<(), Error> {
	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;

	if !data.len().is_power_of_two() {
		return Err(Error::PowerOfTwoLengthRequired);
	}
	if skip_rounds > log_y {
		return Err(Error::SkipRoundsTooLarge);
	}

	let full_sized_y = (data.len() * PB::WIDTH) >> (log_x + log_z);

	// Verify that log_y exactly matches the scalar data length, except when the data is at most
	// two packed elements (the single-packed-element case is handled in the transforms above by
	// padding with a dummy element); even then, 2^log_y may not exceed the available scalars.
	if (1 << log_y != full_sized_y && data.len() > 2) || (1 << log_y > full_sized_y) {
		return Err(Error::BatchTooLarge);
	}

	let coset_bits = 32 - coset.leading_zeros() as usize;

	// The domain size should be at least large enough to represent the given coset.
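	// For example, coset = 3 occupies coset_bits = 2, so log_y = 8 would require
	// log_domain_size >= 10.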
	let log_required_domain_size = log_y + coset_bits;
	if log_required_domain_size > log_domain_size {
		return Err(Error::DomainTooSmall {
			log_required_domain_size,
		});
	}

	Ok(())
}

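/// Computes a packed element whose scalar slots hold the per-slot twiddle offsets for
/// an NTT round below the packing cutoff.
///
/// Because the subspace polynomials are linear, the full packed twiddle at a call site
/// is the broadcast of the block's base twiddle plus this offset vector.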
#[inline]
fn calculate_packed_additive_twiddle<P>(
	s_evals: &impl TwiddleAccess<P::Scalar>,
	shape: NTTShape,
	ntt_round: usize,
) -> P
where
	P: PackedField<Scalar: BinaryField>,
{
	let NTTShape {
		log_x,
		log_y,
		log_z,
	} = shape;
	debug_assert!(log_y > 0);

	let log_block_len = ntt_round + log_x;
	debug_assert!(log_block_len < P::LOG_WIDTH);

	let packed_log_len = (log_x + log_y + log_z).min(P::LOG_WIDTH);
	let log_blocks_count = packed_log_len.saturating_sub(log_block_len + 1);

	let packed_log_z = packed_log_len.saturating_sub(log_x + log_y);
	let packed_log_y = packed_log_len - packed_log_z - log_x;

	let twiddle_stride = P::LOG_WIDTH
		.saturating_sub(log_x)
		.min(log_blocks_count - packed_log_z);

	let mut twiddle = P::default();
	for i in 0..1 << (log_blocks_count - twiddle_stride) {
		for j in 0..1 << twiddle_stride {
			let (subblock_twiddle_0, subblock_twiddle_1) = if packed_log_y == log_y {
				let same_twiddle = s_evals.get(j);
				(same_twiddle, same_twiddle)
			} else {
				s_evals.get_pair(twiddle_stride, j)
			};
			let idx0 = j << (log_block_len + 1) | i << (log_block_len + twiddle_stride + 1);
			let idx1 = idx0 | 1 << log_block_len;

			for k in 0..1 << log_block_len {
				twiddle.set(idx0 | k, subblock_twiddle_0);
				twiddle.set(idx1 | k, subblock_twiddle_1);
			}
		}
	}
	twiddle
}

#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{
		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedFieldIndexable,
	};
	use binius_math::Error as MathError;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	#[test]
	fn test_additive_ntt_fails_with_field_too_small() {
		assert_matches!(
			SingleThreadedNTT::<BinaryField8b>::new(10),
			Err(Error::MathError(MathError::DomainSizeTooLarge))
		);
	}

	#[test]
	fn test_subspace_size_agrees_with_domain_size() {
		let ntt = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to create NTT");
		assert_eq!(ntt.subspace(0).dim(), 10);
		assert_eq!(ntt.subspace(9).dim(), 1);
	}

	#[test]
	fn one_packed_field_forward() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to create NTT");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

		let shape = NTTShape {
			log_x: 0,
			log_y: 3,
			log_z: 0,
		};
		s.forward_transform(&mut packed, shape, 3, 0).unwrap();
		s.forward_transform(unpacked, shape, 3, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn one_packed_field_inverse() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to create NTT");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

		let shape = NTTShape {
			log_x: 0,
			log_y: 3,
			log_z: 0,
		};
		s.inverse_transform(&mut packed, shape, 3, 0).unwrap();
		s.inverse_transform(unpacked, shape, 3, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn smaller_embedded_batch_forward() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to create NTT");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

		let shape = NTTShape {
			log_x: 0,
			log_y: 2,
			log_z: 0,
		};
		forward_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 0).unwrap();
		s.forward_transform(unpacked, shape, 3, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn smaller_embedded_batch_inverse() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to create NTT");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::from_entropy())];

		let mut packed_copy = packed;

		let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

		let shape = NTTShape {
			log_x: 0,
			log_y: 2,
			log_z: 0,
		};
		inverse_transform(s.log_domain_size(), &s.s_evals, &mut packed, shape, 3, 0).unwrap();
		s.inverse_transform(unpacked, shape, 3, 0).unwrap();

		for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	// TODO: Write test that compares polynomial evaluation via additive NTT with naive Lagrange
	// polynomial interpolation. A randomized test should suffice for larger NTT sizes.
}