binius_core/reed_solomon/reed_solomon.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3//! [Reed–Solomon] codes over binary fields.
4//!
5//! See [`ReedSolomonCode`] for details.
6
7use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField};
8use binius_math::BinarySubspace;
9use binius_ntt::{AdditiveNTT, NTTShape, SingleThreadedNTT};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12
13use super::error::Error;
14
15/// [Reed–Solomon] codes over binary fields.
16///
17/// The Reed–Solomon code admits an efficient encoding algorithm over binary fields due to [LCH14].
18/// The additive NTT encoding algorithm encodes messages interpreted as the coefficients of a
19/// polynomial in a non-standard, novel polynomial basis and the codewords are the polynomial
20/// evaluations over a linear subspace of the field. See the [binius_ntt] crate for more details.
21///
22/// [Reed–Solomon]: <https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>
23/// [LCH14]: <https://arxiv.org/abs/1404.3458>
#[derive(Debug, Getters, CopyGetters)]
pub struct ReedSolomonCode<F: BinaryField> {
	// The evaluation domain of the code: an F_2-linear subspace of `F` of dimension
	// `log_dimension + log_inv_rate`. Codewords are polynomial evaluations over this subspace.
	#[get = "pub"]
	subspace: BinarySubspace<F>,
	// Base-2 logarithm of the code dimension (number of message symbols).
	log_dimension: usize,
	// Base-2 logarithm of the inverse rate, ie. log2(len / dim).
	#[get_copy = "pub"]
	log_inv_rate: usize,
}
32
33impl<F: BinaryField> ReedSolomonCode<F> {
34	pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
35		let ntt = SingleThreadedNTT::new(log_dimension + log_inv_rate)?;
36		Self::with_ntt_subspace(&ntt, log_dimension, log_inv_rate)
37	}
38
39	pub fn with_ntt_subspace(
40		ntt: &impl AdditiveNTT<F>,
41		log_dimension: usize,
42		log_inv_rate: usize,
43	) -> Result<Self, Error> {
44		if log_dimension + log_inv_rate > ntt.log_domain_size() {
45			return Err(Error::SubspaceDimensionMismatch);
46		}
47		let subspace_idx = ntt.log_domain_size() - (log_dimension + log_inv_rate);
48		Self::with_subspace(ntt.subspace(subspace_idx), log_dimension, log_inv_rate)
49	}
50
51	pub fn with_subspace(
52		subspace: BinarySubspace<F>,
53		log_dimension: usize,
54		log_inv_rate: usize,
55	) -> Result<Self, Error> {
56		if subspace.dim() != log_dimension + log_inv_rate {
57			return Err(Error::SubspaceDimensionMismatch);
58		}
59		Ok(Self {
60			subspace,
61			log_dimension,
62			log_inv_rate,
63		})
64	}
65
66	/// The dimension.
67	pub const fn dim(&self) -> usize {
68		1 << self.dim_bits()
69	}
70
	/// The base-2 log of the dimension.
	pub const fn log_dim(&self) -> usize {
		self.log_dimension
	}
74
	/// The base-2 log of the block length.
	pub const fn log_len(&self) -> usize {
		self.log_dimension + self.log_inv_rate
	}
78
79	/// The block length.
80	#[allow(clippy::len_without_is_empty)]
81	pub const fn len(&self) -> usize {
82		1 << (self.log_dimension + self.log_inv_rate)
83	}
84
	/// The base-2 log of the dimension.
	// NOTE(review): redundant with the public `log_dim`; candidate for consolidation.
	const fn dim_bits(&self) -> usize {
		self.log_dimension
	}
89
	/// The reciprocal of the rate, ie. `self.len() / self.dim()`.
	///
	/// Equals `2^log_inv_rate`.
	pub const fn inv_rate(&self) -> usize {
		1 << self.log_inv_rate
	}
94
95	/// Encode a batch of interleaved messages in-place in a provided buffer.
96	///
97	/// The message symbols are interleaved in the buffer, which improves the cache-efficiency of
98	/// the encoding procedure. The interleaved codeword is stored in the buffer when the method
99	/// completes.
100	///
101	/// ## Throws
102	///
103	/// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements.
104	fn encode_batch_inplace<P: PackedField<Scalar = F>, NTT: AdditiveNTT<F> + Sync>(
105		&self,
106		ntt: &NTT,
107		code: &mut [P],
108		log_batch_size: usize,
109	) -> Result<(), Error> {
110		if ntt.subspace(ntt.log_domain_size() - self.log_len()) != self.subspace {
111			bail!(Error::EncoderSubspaceMismatch);
112		}
113		let expected_buffer_len =
114			1 << (self.log_len() + log_batch_size).saturating_sub(P::LOG_WIDTH);
115		if code.len() != expected_buffer_len {
116			bail!(Error::IncorrectBufferLength {
117				expected: expected_buffer_len,
118				actual: code.len(),
119			});
120		}
121
122		let _scope = tracing::trace_span!(
123			"Reed–Solomon encode",
124			log_len = self.log_len(),
125			log_batch_size = log_batch_size,
126			symbol_bits = F::N_BITS,
127		)
128		.entered();
129
130		// Repeat the message to fill the entire buffer.
131
132		// First, if the message is less than the packing width, we need to repeat it to fill one
133		// packed element.
134		if self.dim() + log_batch_size < P::LOG_WIDTH {
135			let repeated_values = code[0]
136				.into_iter()
137				.take(1 << (self.log_dim() + log_batch_size))
138				.cycle();
139			code[0] = P::from_scalars(repeated_values);
140		}
141
142		// Repeat the packed message to fill the entire buffer.
143		let mut chunks =
144			code.chunks_mut(1 << (self.log_dim() + log_batch_size).saturating_sub(P::LOG_WIDTH));
145		let first_chunk = chunks.next().expect("code is not empty; checked above");
146		for chunk in chunks {
147			chunk.copy_from_slice(first_chunk);
148		}
149
150		let shape = NTTShape {
151			log_x: log_batch_size,
152			log_y: self.log_len(),
153			..Default::default()
154		};
155		ntt.forward_transform(code, shape, 0, 0, self.log_inv_rate)?;
156		Ok(())
157	}
158
159	/// Encode a batch of interleaved messages of extension field elements in-place in a provided
160	/// buffer.
161	///
162	/// A linear code can be naturally extended to a code over extension fields by encoding each
163	/// dimension of the extension as a vector-space separately.
164	///
165	/// ## Preconditions
166	///
167	/// * `PE::Scalar::DEGREE` must be a power of two.
168	///
169	/// ## Throws
170	///
171	/// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements.
172	pub fn encode_ext_batch_inplace<PE: PackedExtension<F>, NTT: AdditiveNTT<F> + Sync>(
173		&self,
174		ntt: &NTT,
175		code: &mut [PE],
176		log_batch_size: usize,
177	) -> Result<(), Error> {
178		self.encode_batch_inplace(
179			ntt,
180			PE::cast_bases_mut(code),
181			log_batch_size + PE::Scalar::LOG_DEGREE,
182		)
183	}
184}