binius_math/reed_solomon.rs

// Copyright 2023-2025 Irreducible Inc.

//! [Reed–Solomon] codes over binary fields.
//!
//! See [`ReedSolomonCode`] for details.
//!
//! [Reed–Solomon]: <https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>

use std::ptr;

use binius_field::{BinaryField, PackedField};
use binius_utils::rayon::prelude::*;
use getset::{CopyGetters, Getters};

use super::{
	FieldBuffer, FieldSlice, FieldSliceMut, binary_subspace::BinarySubspace,
	error::Error as MathError, ntt::AdditiveNTT,
};
use crate::{
	bit_reverse::{bit_reverse_indices, bit_reverse_packed},
	ntt::DomainContext,
};

/// [Reed–Solomon] codes over binary fields.
///
/// The Reed–Solomon code admits an efficient encoding algorithm over binary fields due to [LCH14].
/// The additive NTT encoding algorithm interprets the message as the coefficients of a polynomial
/// in a non-standard, novel polynomial basis; the codeword is the vector of evaluations of that
/// polynomial over a linear subspace of the field. See the `binius_math` crate documentation for
/// more details.
///
/// [Reed–Solomon]: <https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>
/// [LCH14]: <https://arxiv.org/abs/1404.3458>
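///
/// ## Example
///
/// A minimal construction sketch (marked `ignore` since it leaves the concrete binary field
/// generic):
///
/// ```ignore
/// fn make_code<F: binius_field::BinaryField>() -> ReedSolomonCode<F> {
///     // A code with 2^8 message symbols at rate 1/4, i.e. 2^10 codeword symbols.
///     let code = ReedSolomonCode::<F>::new(8, 2).expect("valid parameters");
///     assert_eq!(code.dim(), 1 << 8);
///     assert_eq!(code.len(), 1 << 10);
///     assert_eq!(code.inv_rate(), 4);
///     code
/// }
/// ```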
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct ReedSolomonCode<F> {
	#[get = "pub"]
	subspace: BinarySubspace<F>,
	log_dimension: usize,
	#[get_copy = "pub"]
	log_inv_rate: usize,
}

impl<F: BinaryField> ReedSolomonCode<F> {
	/// Constructs a code over the subspace of dimension `log_dimension + log_inv_rate`
	/// returned by [`BinarySubspace::with_dim`].
	pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
		let subspace = BinarySubspace::with_dim(log_dimension + log_inv_rate)?;
		Self::with_subspace(subspace, log_dimension, log_inv_rate)
	}

	/// Constructs a code whose evaluation domain is taken from the given NTT's domain context.
	pub fn with_ntt_subspace(
		ntt: &impl AdditiveNTT<Field = F>,
		log_dimension: usize,
		log_inv_rate: usize,
	) -> Result<Self, Error> {
		Self::with_domain_context_subspace(ntt.domain_context(), log_dimension, log_inv_rate)
	}

	/// Constructs a code whose evaluation domain is taken from the given domain context.
	pub fn with_domain_context_subspace(
		domain_context: &impl DomainContext<Field = F>,
		log_dimension: usize,
		log_inv_rate: usize,
	) -> Result<Self, Error> {
		let subspace_dim = log_dimension + log_inv_rate;
		if subspace_dim > domain_context.log_domain_size() {
			return Err(Error::SubspaceDimensionMismatch);
		}
		let subspace = domain_context.subspace(subspace_dim);
		Self::with_subspace(subspace, log_dimension, log_inv_rate)
	}

	/// Constructs a code over the given subspace, which must have dimension
	/// `log_dimension + log_inv_rate`.
	pub fn with_subspace(
		subspace: BinarySubspace<F>,
		log_dimension: usize,
		log_inv_rate: usize,
	) -> Result<Self, Error> {
		if subspace.dim() != log_dimension + log_inv_rate {
			return Err(Error::SubspaceDimensionMismatch);
		}
		Ok(Self {
			subspace,
			log_dimension,
			log_inv_rate,
		})
	}

	/// The dimension of the code, i.e. the message length in field elements.
	pub const fn dim(&self) -> usize {
		1 << self.dim_bits()
	}

	/// The base-2 logarithm of the dimension.
	pub const fn log_dim(&self) -> usize {
		self.log_dimension
	}

	/// The base-2 logarithm of the block length.
	pub const fn log_len(&self) -> usize {
		self.log_dimension + self.log_inv_rate
	}

	/// The block length, i.e. the codeword length in field elements.
	#[allow(clippy::len_without_is_empty)]
	pub const fn len(&self) -> usize {
		1 << (self.log_dimension + self.log_inv_rate)
	}

	/// The base-2 log of the dimension.
	const fn dim_bits(&self) -> usize {
		self.log_dimension
	}

	/// The reciprocal of the rate, i.e. `self.len() / self.dim()`.
	pub const fn inv_rate(&self) -> usize {
		1 << self.log_inv_rate
	}

	/// Encodes a message with an interleaved Reed–Solomon code.
	///
	/// This function interprets the message as a batch of `2^log_batch_size` independent vectors
	/// and encodes them with an interleaved Reed–Solomon code.
	///
	/// ## Preconditions
	///
	/// * `data.log_len()` equals `log_dim() + log_batch_size` (checked with an assertion).
	///
	/// ## Returns
	///
	/// On success, a buffer of `2^(log_len() + log_batch_size)` field elements containing the
	/// interleaved codeword.
	///
	/// ## Errors
	///
	/// * [`Error::EncoderSubspaceMismatch`] if the NTT subspace doesn't match the code's subspace.
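	///
	/// ## Example
	///
	/// A usage sketch following this module's tests; `F`, `message`, `log_dim`, `log_inv_rate`,
	/// and `log_batch_size` are assumed to be in scope, and the `GenericPreExpanded` domain
	/// context and `NeighborsLastReference` NTT are the ones used by the tests below:
	///
	/// ```ignore
	/// let code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)?;
	/// let domain_context = GenericPreExpanded::<F>::generate_from_subspace(code.subspace());
	/// let ntt = NeighborsLastReference { domain_context: &domain_context };
	///
	/// // `message` holds 2^(log_dim + log_batch_size) field elements.
	/// let codeword = code.encode_batch(&ntt, message.to_ref(), log_batch_size)?;
	/// assert_eq!(codeword.log_len(), code.log_len() + log_batch_size);
	/// ```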
	pub fn encode_batch<P, NTT>(
		&self,
		ntt: &NTT,
		data: FieldSlice<P>,
		log_batch_size: usize,
	) -> Result<FieldBuffer<P>, Error>
	where
		P: PackedField<Scalar = F>,
		NTT: AdditiveNTT<Field = F> + Sync,
	{
		if ntt.subspace(self.log_len()) != self.subspace {
			return Err(Error::EncoderSubspaceMismatch);
		}

		assert_eq!(data.log_len(), self.log_dim() + log_batch_size); // precondition

		let _scope = tracing::trace_span!(
			"Reed-Solomon encode",
			log_len = self.log_len(),
			log_batch_size = log_batch_size,
			symbol_bits = F::N_BITS,
		)
		.entered();

		// Repeat the message to fill the entire buffer.
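		// Filling the output with 2^log_inv_rate copies of the bit-reversed message matches
		// running the first log_inv_rate rounds of the NTT on a zero-padded, bit-reversed
		// input, so those rounds are skipped in the `forward_transform` call below (the
		// reference implementation in this module's tests checks this equivalence).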
		let log_output_len = self.log_dim() + log_batch_size + self.log_inv_rate;
		let output_data = if data.log_len() < P::LOG_WIDTH {
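			// The message is shorter than one packed element: build a single packed element
			// holding the bit-reversed message repeated cyclically, then replicate it across
			// the output.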
			let mut scalars = data.iter_scalars().collect::<Vec<_>>();
			bit_reverse_indices(&mut scalars);
			let elem_0 = P::from_scalars(scalars.into_iter().cycle());
			vec![elem_0; 1 << log_output_len.saturating_sub(P::LOG_WIDTH)]
		} else {
			let mut output_data = Vec::with_capacity(1 << (log_output_len - P::LOG_WIDTH));

			output_data.extend_from_slice(data.as_ref());

			// Bit-reverse permute the message.
			bit_reverse_packed(
				FieldSliceMut::from_slice(data.log_len(), output_data.as_mut_slice())
					.expect("output_data.len() == data.as_ref().len()"),
			);

			let log_msg_len_packed = data.log_len() - P::LOG_WIDTH;
			output_data
				.spare_capacity_mut()
				.par_chunks_exact_mut(1 << log_msg_len_packed)
				.enumerate()
				.for_each(|(i, output_chunk)| unsafe {
					let dst_ptr = output_chunk.as_mut_ptr();

					// TODO(https://github.com/rust-lang/rust/issues/81944):
					// Improve unsafe code with Vec::split_at_spare_mut when stable

					// Safety:
					// - log_output_len - P::LOG_WIDTH == log_msg_len_packed + self.log_inv_rate
					// - i + 1 is in the range 1..1 << self.log_inv_rate
					// - dst_ptr is disjoint from src_ptr and within the Vec capacity
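					// Every chunk is filled from the message at the start of the buffer:
					// chunk i of the spare capacity begins (i + 1) << log_msg_len_packed
					// elements past offset 0, which is the distance subtracted here.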
					let src_ptr = dst_ptr.sub((i + 1) << log_msg_len_packed);
					ptr::copy_nonoverlapping(src_ptr, dst_ptr, 1 << log_msg_len_packed);
				});

			unsafe {
				// Safety: the vec's spare capacity is fully initialized above.
				output_data.set_len(1 << (log_output_len - P::LOG_WIDTH));
			}

			output_data
		};
		let mut output = FieldBuffer::new(log_output_len, output_data.into_boxed_slice())
			.expect("preconditions satisfied");

		ntt.forward_transform(output.to_mut(), self.log_inv_rate, log_batch_size);
		Ok(output)
	}
}

#[derive(Debug, thiserror::Error)]
pub enum Error {
	#[error("the evaluation domain of the code does not match the subspace of the NTT encoder")]
	EncoderSubspaceMismatch,
	#[error("the dimension of the evaluation domain of the code does not match the parameters")]
	SubspaceDimensionMismatch,
	#[error("math error: {0}")]
	Math(#[from] MathError),
}

#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField, PackedBinaryGhash1x128b, PackedBinaryGhash4x128b, PackedField,
	};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::{
		FieldBuffer,
		bit_reverse::reverse_bits,
		ntt::{NeighborsLastReference, domain_context::GenericPreExpanded},
		test_utils::random_field_buffer,
	};

	fn test_encode_batch_helper<P: PackedField>(
		log_dim: usize,
		log_inv_rate: usize,
		log_batch_size: usize,
	) where
		P::Scalar: BinaryField,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate)
			.expect("Failed to create Reed-Solomon code");

		// Create NTT with matching subspace
		let subspace = rs_code.subspace().clone();
		let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
		let ntt = NeighborsLastReference {
			domain_context: &domain_context,
		};

		// Generate random message buffer
		let message = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);

		// Test the encode_batch interface
		let encoded_buffer = rs_code
			.encode_batch(&ntt, message.to_ref(), log_batch_size)
			.expect("encode_batch failed");

		// Reference implementation: apply the NTT with zero-padded coefficients to the
		// bit-reversal permuted message.
		let mut reference_buffer = FieldBuffer::zeros(rs_code.log_len() + log_batch_size);
		for (i, val) in message.iter_scalars().enumerate() {
			let bits = (rs_code.log_dim() + log_batch_size) as u32;
			reference_buffer.set(reverse_bits(i, bits), val);
		}

		// Perform large NTT with zero-padded coefficients.
		ntt.forward_transform(reference_buffer.to_mut(), 0, log_batch_size);

		// Compare results
		assert_eq!(
			encoded_buffer.as_ref(),
			reference_buffer.as_ref(),
			"encode_batch result differs from reference NTT implementation"
		);
	}

	#[test]
	fn test_encode_batch_above_packing_width() {
		// Test with PackedBinaryGhash1x128b
		test_encode_batch_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
		test_encode_batch_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
		test_encode_batch_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

		// Test with PackedBinaryGhash4x128b
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(8, 3, 2);
	}

	#[test]
	fn test_encode_batch_below_packing_width() {
		// Test where message length is less than the packing width and codeword length is greater.
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
	}
}