// binius_ntt/single_threaded.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{cmp, marker::PhantomData};

use binius_field::{BinaryField, PackedField, TowerField};
use binius_math::BinarySubspace;

use super::{additive_ntt::AdditiveNTT, error::Error, twiddle::TwiddleAccess};
use crate::twiddle::{expand_subspace_evals, OnTheFlyTwiddleAccess, PrecomputedTwiddleAccess};

/// Implementation of `AdditiveNTT` that performs the computation single-threaded.
#[derive(Debug)]
pub struct SingleThreadedNTT<F: BinaryField, TA: TwiddleAccess<F> = OnTheFlyTwiddleAccess<F>> {
	// TODO: Figure out how to make this private, it should not be `pub(super)`.
	// Twiddle accessors, one per NTT round; the vector length equals the log domain size
	// (see `log_domain_size` in the `AdditiveNTT` impl).
	pub(super) s_evals: Vec<TA>,
	// Zero-sized marker tying the otherwise-unused scalar type parameter `F` to the struct.
	_marker: PhantomData<F>,
}

19impl<F: BinaryField> SingleThreadedNTT<F> {
20	/// Default constructor constructs an NTT over the canonical subspace for the field using
21	/// on-the-fly computed twiddle factors.
22	pub fn new(log_domain_size: usize) -> Result<Self, Error> {
23		let subspace = BinarySubspace::with_dim(log_domain_size)?;
24		let twiddle_access = OnTheFlyTwiddleAccess::generate(&subspace)?;
25		Ok(Self::with_twiddle_access(twiddle_access))
26	}
27
28	/// Constructs an NTT over an isomorphic subspace for the given domain field using on-the-fly
29	/// computed twiddle factors.
30	pub fn with_domain_field<FDomain>(log_domain_size: usize) -> Result<Self, Error>
31	where
32		FDomain: BinaryField,
33		F: From<FDomain>,
34	{
35		let subspace = BinarySubspace::<FDomain>::with_dim(log_domain_size)?.isomorphic();
36		let twiddle_access = OnTheFlyTwiddleAccess::generate(&subspace)?;
37		Ok(Self::with_twiddle_access(twiddle_access))
38	}
39
40	pub fn precompute_twiddles(&self) -> SingleThreadedNTT<F, PrecomputedTwiddleAccess<F>> {
41		SingleThreadedNTT::with_twiddle_access(expand_subspace_evals(&self.s_evals))
42	}
43}
44
impl<F: TowerField> SingleThreadedNTT<F> {
	/// A specialization of [`with_domain_field`](Self::with_domain_field) to the canonical tower
	/// field, i.e. the NTT subspace is the image of the canonical subspace of `F::Canonical`.
	pub fn with_canonical_field(log_domain_size: usize) -> Result<Self, Error> {
		Self::with_domain_field::<F::Canonical>(log_domain_size)
	}
}

52impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
53	const fn with_twiddle_access(twiddle_access: Vec<TA>) -> Self {
54		Self {
55			s_evals: twiddle_access,
56			_marker: PhantomData,
57		}
58	}
59}
60
61impl<F: BinaryField, TA: TwiddleAccess<F>> SingleThreadedNTT<F, TA> {
62	pub fn twiddles(&self) -> &[TA] {
63		&self.s_evals
64	}
65}
66
impl<F, TA> AdditiveNTT<F> for SingleThreadedNTT<F, TA>
where
	F: BinaryField,
	TA: TwiddleAccess<F>,
{
	// One twiddle accessor is stored per NTT round, so the accessor count is the log domain size.
	fn log_domain_size(&self) -> usize {
		self.s_evals.len()
	}

	// Returns the linear subspace associated with round `i`. The affine shift must be zero
	// because the accessors are built from linear (unshifted) subspaces.
	fn subspace(&self, i: usize) -> BinarySubspace<F> {
		let (subspace, shift) = self.s_evals[i].affine_subspace();
		debug_assert_eq!(shift, F::ZERO, "s_evals subspaces must be linear by construction");
		subspace
	}

	// Evaluation of the round-`i` subspace polynomial at basis index `j`.
	fn get_subspace_eval(&self, i: usize, j: usize) -> F {
		self.s_evals[i].get(j)
	}

	// Delegates to the free `forward_transform` function below.
	fn forward_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		coset: u32,
		log_batch_size: usize,
		log_n: usize,
	) -> Result<(), Error> {
		forward_transform(self.log_domain_size(), &self.s_evals, data, coset, log_batch_size, log_n)
	}

	// Delegates to the free `inverse_transform` function below.
	fn inverse_transform<P: PackedField<Scalar = F>>(
		&self,
		data: &mut [P],
		coset: u32,
		log_batch_size: usize,
		log_n: usize,
	) -> Result<(), Error> {
		inverse_transform(self.log_domain_size(), &self.s_evals, data, coset, log_batch_size, log_n)
	}
}

/// Forward additive NTT over `data`, interpreted as `2^log_batch_size` interleaved batches of
/// polynomials with `2^log_n` coefficients each; transformed in place.
///
/// `coset` selects which coset of the domain subspace the evaluation is taken over, and
/// `s_evals` supplies the per-round twiddle factors. Inputs are validated by
/// [`check_batch_transform_inputs_and_params`].
pub fn forward_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F>],
	data: &mut [P],
	coset: u32,
	log_batch_size: usize,
	log_n: usize,
) -> Result<(), Error> {
	match data.len() {
		// An empty slice is trivially transformed.
		0 => return Ok(()),
		1 => {
			return match P::LOG_WIDTH {
				// Packing width 1: a single scalar, nothing to butterfly.
				0 => Ok(()),
				_ => {
					// Special case when there is only one packed element: since we cannot
					// interleave with another packed element, the code below will panic when there
					// is only one.
					//
					// Handle the case of one packed element by batch transforming the original
					// data with dummy data and extracting the transformed result.
					let mut buffer = [data[0], P::zero()];

					forward_transform(
						log_domain_size,
						s_evals,
						&mut buffer,
						coset,
						log_batch_size,
						log_n,
					)?;

					data[0] = buffer[0];

					Ok(())
				}
			};
		}
		_ => {}
	};

	let log_b = log_batch_size;

	let log_w = P::LOG_WIDTH;

	check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_batch_size, log_n)?;

	// Cutoff is the stage of the NTT where the butterfly units are contained within
	// packed base field elements.
	let cutoff = log_w.saturating_sub(log_b);

	// Rounds at or above the cutoff (processed high to low): each butterfly half spans whole
	// packed elements, so one broadcast twiddle applies to each packed pair.
	for i in (cutoff..log_n).rev() {
		let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

		for j in 0..1 << (log_n - 1 - i) {
			let twiddle = P::broadcast(coset_twiddle.get(j));
			for k in 0..1 << (i + log_b - log_w) {
				// idx0/idx1 index the two packed halves of butterfly block j.
				let idx0 = j << (i + log_b - log_w + 1) | k;
				let idx1 = idx0 | 1 << (i + log_b - log_w);
				// Butterfly: u += v * twiddle, then v += u (using the updated u).
				data[idx0] += data[idx1] * twiddle;
				data[idx1] += data[idx0];
			}
		}
	}

	// Rounds below the cutoff: butterfly blocks are narrower than a packed element, so lanes of
	// adjacent packed elements are interleaved to line up the butterfly halves.
	for i in (0..cmp::min(cutoff, log_n)).rev() {
		let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

		// A block is a block of butterfly units that all have the same twiddle factor. Since we
		// are below the cutoff round, the block length is less than the packing width, and
		// therefore each packed multiplication is with a non-uniform twiddle. Since the subspace
		// polynomials are linear, we can calculate an additive factor that can be added to the
		// packed twiddles for all packed butterfly units.
		let log_block_len = i + log_b;
		let block_twiddle = calculate_twiddle::<P>(
			&s_evals[i].coset(log_domain_size - 1 - cutoff, 0),
			log_block_len,
		);

		for j in 0..data.len() / 2 {
			let twiddle = P::broadcast(coset_twiddle.get(j << (cutoff - i))) + block_twiddle;
			let (mut u, mut v) = data[j << 1].interleave(data[j << 1 | 1], log_block_len);
			u += v * twiddle;
			v += u;
			(data[j << 1], data[j << 1 | 1]) = u.interleave(v, log_block_len);
		}
	}

	Ok(())
}

/// Inverse additive NTT over `data`, interpreted as `2^log_batch_size` interleaved batches of
/// polynomials with `2^log_n` coefficients each; transformed in place.
///
/// Runs the rounds of [`forward_transform`] in the opposite order with the two butterfly
/// updates swapped, so for matching parameters it undoes that transform. Inputs are validated
/// by [`check_batch_transform_inputs_and_params`].
pub fn inverse_transform<F: BinaryField, P: PackedField<Scalar = F>>(
	log_domain_size: usize,
	s_evals: &[impl TwiddleAccess<F>],
	data: &mut [P],
	coset: u32,
	log_batch_size: usize,
	log_n: usize,
) -> Result<(), Error> {
	match data.len() {
		// An empty slice is trivially transformed.
		0 => return Ok(()),
		1 => {
			return match P::LOG_WIDTH {
				// Packing width 1: a single scalar, nothing to butterfly.
				0 => Ok(()),
				_ => {
					// Special case when there is only one packed element: since we cannot
					// interleave with another packed element, the code below will panic when there
					// is only one.
					//
					// Handle the case of one packed element by batch transforming the original
					// data with dummy data and extracting the transformed result.
					let mut buffer = [data[0], P::zero()];

					inverse_transform(
						log_domain_size,
						s_evals,
						&mut buffer,
						coset,
						log_batch_size,
						log_n,
					)?;

					data[0] = buffer[0];
					Ok(())
				}
			};
		}
		_ => {}
	};

	let log_w = P::LOG_WIDTH;

	let log_b = log_batch_size;

	check_batch_transform_inputs_and_params(log_domain_size, data, coset, log_batch_size, log_n)?;

	// Cutoff is the stage of the NTT where the butterfly units are contained within
	// packed base field elements.
	let cutoff = log_w.saturating_sub(log_b);

	// Rounds below the cutoff (processed low to high — the reverse of the forward transform):
	// butterfly blocks are narrower than a packed element, handled via lane interleaving.
	#[allow(clippy::needless_range_loop)]
	for i in 0..cmp::min(cutoff, log_n) {
		let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

		// A block is a block of butterfly units that all have the same twiddle factor. Since we
		// are below the cutoff round, the block length is less than the packing width, and
		// therefore each packed multiplication is with a non-uniform twiddle. Since the subspace
		// polynomials are linear, we can calculate an additive factor that can be added to the
		// packed twiddles for all packed butterfly units.
		let log_block_len = i + log_b;
		let block_twiddle = calculate_twiddle::<P>(
			&s_evals[i].coset(log_domain_size - 1 - cutoff, 0),
			log_block_len,
		);

		for j in 0..data.len() / 2 {
			let twiddle = P::broadcast(coset_twiddle.get(j << (cutoff - i))) + block_twiddle;
			let (mut u, mut v) = data[j << 1].interleave(data[j << 1 | 1], log_block_len);
			// Inverse butterfly: v += u, then u += v * twiddle (reverse of the forward order).
			v += u;
			u += v * twiddle;
			(data[j << 1], data[j << 1 | 1]) = u.interleave(v, log_block_len);
		}
	}

	// Rounds at or above the cutoff: each butterfly half spans whole packed elements, so one
	// broadcast twiddle applies per packed pair.
	#[allow(clippy::needless_range_loop)]
	for i in cutoff..log_n {
		let coset_twiddle = s_evals[i].coset(log_domain_size - log_n, coset as usize);

		for j in 0..1 << (log_n - 1 - i) {
			let twiddle = P::broadcast(coset_twiddle.get(j));
			for k in 0..1 << (i + log_b - log_w) {
				// idx0/idx1 index the two packed halves of butterfly block j.
				let idx0 = j << (i + log_b - log_w + 1) | k;
				let idx1 = idx0 | 1 << (i + log_b - log_w);
				// Inverse butterfly: v += u, then u += v * twiddle.
				data[idx1] += data[idx0];
				data[idx0] += data[idx1] * twiddle;
			}
		}
	}

	Ok(())
}

/// Validates the shared preconditions of the batch forward/inverse transforms.
///
/// Checks that `data` has power-of-two length, that `log_n` is consistent with the per-batch
/// scalar count `data.len() * PB::WIDTH >> log_batch_size` (relaxed when `data.len() <= 2` to
/// permit the single-packed-element fallback), and that the domain is large enough for `log_n`,
/// the coset index, and the fallback's two-element buffer.
pub fn check_batch_transform_inputs_and_params<PB: PackedField>(
	log_domain_size: usize,
	data: &[PB],
	coset: u32,
	log_batch_size: usize,
	log_n: usize,
) -> Result<(), Error> {
	if !data.len().is_power_of_two() {
		return Err(Error::PowerOfTwoLengthRequired);
	}
	// NOTE(review): PB::WIDTH is presumably always a power of two; this guards against
	// nonconforming PackedField implementations — confirm whether the trait guarantees it.
	if !PB::WIDTH.is_power_of_two() {
		return Err(Error::PackingWidthMustDivideDimension);
	}

	// Number of scalars per batch element.
	let full_sized_n = (data.len() * PB::WIDTH) >> log_batch_size;

	// Verify that our log_n exactly matches the data length, except when we are NTT-ing one packed field
	if (1 << log_n != full_sized_n && data.len() > 2) || (1 << log_n > full_sized_n) {
		return Err(Error::BatchTooLarge);
	}

	// Number of bits needed to represent the coset index.
	let coset_bits = 32 - coset.leading_zeros() as usize;

	// The domain size should be at least large enough to represent the given coset;
	// on the lower end, there is a fallback for data.len() == 1 which reduces to
	// a forward/inverse NTT on the [PB; 2], which demands log_domain_size of
	// at least max(PB::LOG_WIDTH + 1 - log_batch_size, 0) — this is what the
	// `saturating_sub` below computes.
	// Not enforcing this bound makes some twiddle values unavailable.
	let log_required_domain_size =
		(log_n + coset_bits).max((PB::LOG_WIDTH + 1).saturating_sub(log_batch_size));
	if log_required_domain_size > log_domain_size {
		return Err(Error::DomainTooSmall {
			log_required_domain_size,
		});
	}

	Ok(())
}

/// Computes a packed vector of twiddle corrections for rounds below the packing cutoff, where
/// each butterfly block of `2^log_block_len` lanes fits inside a single packed element.
///
/// For each of the `2^(P::LOG_WIDTH - log_block_len - 1)` blocks, the twiddle pair fetched from
/// `s_evals` is splatted across the block's two half-blocks of lanes. The callers add the result
/// to a broadcast coset twiddle.
///
/// Requires `log_block_len < P::LOG_WIDTH`; otherwise the subtraction below underflows.
#[inline]
fn calculate_twiddle<P>(s_evals: &impl TwiddleAccess<P::Scalar>, log_block_len: usize) -> P
where
	P: PackedField<Scalar: BinaryField>,
{
	// Each block holds two half-blocks sharing one twiddle pair, hence the extra -1.
	let log_blocks_count = P::LOG_WIDTH - log_block_len - 1;

	let mut twiddle = P::default();
	for k in 0..1 << log_blocks_count {
		let (subblock_twiddle_0, subblock_twiddle_1) = s_evals.get_pair(log_blocks_count, k);
		// Lane offsets of the two half-blocks of block k.
		let idx0 = k << (log_block_len + 1);
		let idx1 = idx0 | 1 << log_block_len;

		// Splat each element of the twiddle pair across its half-block of lanes.
		for l in 0..1 << log_block_len {
			twiddle.set(idx0 | l, subblock_twiddle_0);
			twiddle.set(idx1 | l, subblock_twiddle_1);
		}
	}
	twiddle
}

#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{
		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedFieldIndexable,
	};
	use binius_math::Error as MathError;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	#[test]
	fn test_additive_ntt_fails_with_field_too_small() {
		// BinaryField8b cannot contain a dimension-10 subspace.
		assert_matches!(
			SingleThreadedNTT::<BinaryField8b>::new(10),
			Err(Error::MathError(MathError::DomainSizeTooLarge))
		);
	}

	#[test]
	fn test_subspace_size_agrees_with_domain_size() {
		let ntt = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to construct NTT");
		assert_eq!(ntt.subspace(0).dim(), 10);
		assert_eq!(ntt.subspace(9).dim(), 1);
	}

	// Transforming one packed element must agree with transforming its unpacked scalars.
	#[test]
	fn one_packed_field_forward() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to construct NTT");
		// Seeded RNG keeps the test deterministic and reproducible.
		let mut packed = [PackedBinaryField8x16b::random(StdRng::seed_from_u64(0))];

		let mut packed_copy = packed;

		let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

		// Propagate transform errors: a silently discarded `Err` would leave both buffers
		// untransformed and let the comparison below pass vacuously.
		s.forward_transform(&mut packed, 3, 0, 3)
			.expect("forward_transform on packed data failed");
		s.forward_transform(unpacked, 3, 0, 3)
			.expect("forward_transform on unpacked data failed");

		for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn one_packed_field_inverse() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to construct NTT");
		// Seeded RNG keeps the test deterministic and reproducible.
		let mut packed = [PackedBinaryField8x16b::random(StdRng::seed_from_u64(1))];

		let mut packed_copy = packed;

		let unpacked = PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy);

		// Propagate transform errors instead of discarding them (see forward test).
		s.inverse_transform(&mut packed, 3, 0, 3)
			.expect("inverse_transform on packed data failed");
		s.inverse_transform(unpacked, 3, 0, 3)
			.expect("inverse_transform on unpacked data failed");

		for (i, unpacked_item) in unpacked.iter().enumerate().take(8) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	// A batch smaller than the packing width must agree with transforming the scalar prefix.
	#[test]
	fn smaller_embedded_batch_forward() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to construct NTT");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::seed_from_u64(2))];

		let mut packed_copy = packed;

		let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

		forward_transform(s.log_domain_size(), &s.s_evals, &mut packed, 3, 0, 2)
			.expect("free forward_transform failed");
		s.forward_transform(unpacked, 3, 0, 2)
			.expect("forward_transform on unpacked data failed");

		for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	#[test]
	fn smaller_embedded_batch_inverse() {
		let s = SingleThreadedNTT::<BinaryField16b>::new(10).expect("failed to construct NTT");
		let mut packed = [PackedBinaryField8x16b::random(StdRng::seed_from_u64(3))];

		let mut packed_copy = packed;

		let unpacked = &mut PackedBinaryField8x16b::unpack_scalars_mut(&mut packed_copy)[0..4];

		inverse_transform(s.log_domain_size(), &s.s_evals, &mut packed, 3, 0, 2)
			.expect("free inverse_transform failed");
		s.inverse_transform(unpacked, 3, 0, 2)
			.expect("inverse_transform on unpacked data failed");

		for (i, unpacked_item) in unpacked.iter().enumerate().take(4) {
			assert_eq!(packed[0].get(i), *unpacked_item);
		}
	}

	// TODO: Write test that compares polynomial evaluation via additive NTT with naive Lagrange
	// polynomial interpolation. A randomized test should suffice for larger NTT sizes.
}