// binius_core/reed_solomon/reed_solomon.rs

// Copyright 2023-2025 Irreducible Inc.

//! [Reed–Solomon] codes over binary fields.
//!
//! The Reed–Solomon code admits an efficient encoding algorithm over binary fields due to [LCH14].
//! The additive NTT encoding algorithm encodes messages interpreted as the coefficients of a
//! polynomial in a non-standard, novel polynomial basis and the codewords are the polynomial
//! evaluations over a linear subspace of the field. See the [binius_ntt] crate for more details.
//!
//! [Reed–Solomon]: <https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>
//! [LCH14]: <https://arxiv.org/abs/1404.3458>

use std::marker::PhantomData;

use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension};
use binius_maybe_rayon::prelude::*;
use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings};
use binius_utils::bail;
use getset::CopyGetters;
use tracing::instrument;
21
/// A [Reed–Solomon] code over a binary field, encoded with the additive NTT (see module docs).
///
/// Dimension and block length are powers of two, stored as base-2 logarithms.
///
/// [Reed–Solomon]: <https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>
#[derive(Debug, CopyGetters)]
pub struct ReedSolomonCode<P>
where
	P: PackedField,
	P::Scalar: BinaryField,
{
	// Additive NTT over the full codeword domain, used by the encoding routines.
	ntt: DynamicDispatchNTT<P::Scalar>,
	// Base-2 log of the code dimension (number of message symbols).
	log_dimension: usize,
	// Base-2 log of the inverse rate, ie. log2(len / dim). Public getter via CopyGetters.
	#[getset(get_copy = "pub")]
	log_inv_rate: usize,
	// Whether encoding parallelizes over cosets; derived from the NTT threading options in `new`.
	multithreaded: bool,
	_p_marker: PhantomData<P>,
}
35
36impl<P> ReedSolomonCode<P>
37where
38	P: PackedField<Scalar: BinaryField>,
39{
40	pub fn new(
41		log_dimension: usize,
42		log_inv_rate: usize,
43		ntt_options: &NTTOptions,
44	) -> Result<Self, Error> {
45		// Since we split work between log_inv_rate threads, we need to decrease the number of threads per each NTT transformation.
46		let ntt_log_threads = ntt_options
47			.thread_settings
48			.log_threads_count()
49			.saturating_sub(log_inv_rate);
50		let ntt = DynamicDispatchNTT::new(
51			log_dimension + log_inv_rate,
52			&NTTOptions {
53				thread_settings: ThreadingSettings::ExplicitThreadsCount {
54					log_threads: ntt_log_threads,
55				},
56				precompute_twiddles: ntt_options.precompute_twiddles,
57			},
58		)?;
59
60		let multithreaded =
61			!matches!(ntt_options.thread_settings, ThreadingSettings::SingleThreaded);
62
63		Ok(Self {
64			ntt,
65			log_dimension,
66			log_inv_rate,
67			multithreaded,
68			_p_marker: PhantomData,
69		})
70	}
71
	/// Returns a reference to the additive NTT used for encoding.
	pub const fn get_ntt(&self) -> &impl AdditiveNTT<P::Scalar> {
		&self.ntt
	}
75
	/// The dimension, ie. the number of message symbols (`2^log_dim`).
	pub const fn dim(&self) -> usize {
		1 << self.dim_bits()
	}
80
	/// The base-2 log of the dimension.
	pub const fn log_dim(&self) -> usize {
		self.log_dimension
	}
84
	/// The base-2 log of the block length.
	pub const fn log_len(&self) -> usize {
		self.log_dimension + self.log_inv_rate
	}
88
	/// The block length, ie. the number of codeword symbols (`2^log_len`).
	#[allow(clippy::len_without_is_empty)]
	pub const fn len(&self) -> usize {
		1 << (self.log_dimension + self.log_inv_rate)
	}
94
	/// The base-2 log of the dimension (private alias of [`Self::log_dim`]).
	const fn dim_bits(&self) -> usize {
		self.log_dimension
	}
99
	/// The reciprocal of the rate, ie. `self.len() / self.dim()`.
	pub const fn inv_rate(&self) -> usize {
		1 << self.log_inv_rate
	}
104
105	/// Encode a batch of interleaved messages in-place in a provided buffer.
106	///
107	/// The message symbols are interleaved in the buffer, which improves the cache-efficiency of
108	/// the encoding procedure. The interleaved codeword is stored in the buffer when the method
109	/// completes.
110	///
111	/// ## Throws
112	///
113	/// * If the `code` buffer does not have capacity for `len() << log_batch_size` field
114	///   elements.
115	fn encode_batch_inplace(&self, code: &mut [P], log_batch_size: usize) -> Result<(), Error> {
116		let _scope = tracing::trace_span!(
117			"Reed–Solomon encode",
118			log_len = self.log_len(),
119			log_batch_size = log_batch_size,
120			symbol_bits = P::Scalar::N_BITS,
121		)
122		.entered();
123		if (code.len() << log_batch_size) < self.len() {
124			bail!(Error::BufferTooSmall {
125				log_code_len: self.len(),
126			});
127		}
128		if self.dim() % P::WIDTH != 0 {
129			bail!(Error::PackingWidthMustDivideDimension);
130		}
131
132		let msgs_len = (self.dim() / P::WIDTH) << log_batch_size;
133		for i in 1..(1 << self.log_inv_rate) {
134			code.copy_within(0..msgs_len, i * msgs_len);
135		}
136
137		if self.multithreaded {
138			(0..(1 << self.log_inv_rate))
139				.into_par_iter()
140				.zip(code.par_chunks_exact_mut(msgs_len))
141				.try_for_each(|(i, data)| {
142					self.ntt
143						.forward_transform(data, i, log_batch_size, self.log_dim())
144				})
145		} else {
146			(0..(1 << self.log_inv_rate))
147				.zip(code.chunks_exact_mut(msgs_len))
148				.try_for_each(|(i, data)| {
149					self.ntt
150						.forward_transform(data, i, log_batch_size, self.log_dim())
151				})
152		}
153	}
154
155	/// Encode a batch of interleaved messages of extension field elements in-place in a provided
156	/// buffer.
157	///
158	/// A linear code can be naturally extended to a code over extension fields by encoding each
159	/// dimension of the extension as a vector-space separately.
160	///
161	/// ## Preconditions
162	///
163	/// * `PE::Scalar::DEGREE` must be a power of two.
164	///
165	/// ## Throws
166	///
167	/// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements.
168	#[instrument(skip_all, level = "debug")]
169	pub fn encode_ext_batch_inplace<PE: RepackedExtension<P>>(
170		&self,
171		code: &mut [PE],
172		log_batch_size: usize,
173	) -> Result<(), Error> {
174		self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + PE::Scalar::LOG_DEGREE)
175	}
176}