binius_math/reed_solomon.rs

// Copyright 2023-2025 Irreducible Inc.

//! [Reed–Solomon] codes over binary fields.
//!
//! See [`ReedSolomonCode`] for details.

use std::ptr;

use binius_field::{BinaryField, PackedField};
use binius_utils::rayon::prelude::*;
use getset::{CopyGetters, Getters};

use super::{
	FieldBuffer, FieldSlice, FieldSliceMut, binary_subspace::BinarySubspace, ntt::AdditiveNTT,
};
use crate::{
	bit_reverse::{bit_reverse_indices, bit_reverse_packed},
	ntt::DomainContext,
};

/// [Reed–Solomon] codes over binary fields.
///
/// Reed–Solomon codes over binary fields admit an efficient encoding algorithm due to [LCH14].
/// The additive NTT encoding algorithm interprets the message as the coefficients of a polynomial
/// in a non-standard, novel polynomial basis; the codeword is the vector of evaluations of that
/// polynomial over a linear subspace of the field. See the [`AdditiveNTT`] documentation for
/// details on the encoding transform.
///
/// [Reed–Solomon]: <https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>
/// [LCH14]: <https://arxiv.org/abs/1404.3458>
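///
/// ## Example
///
/// A minimal usage sketch, not compiled as a doc-test; the `NeighborsLastReference` NTT and
/// `GenericPreExpanded` domain context mirror the tests at the bottom of this module, and exact
/// item paths may differ in the public API.
///
/// ```ignore
/// // `F` is any `BinaryField`; `message` is a `FieldBuffer` with log_len() == 4.
/// let code = ReedSolomonCode::<F>::new(4, 2); // 2^4 message symbols, rate 1/4
/// let domain_context = GenericPreExpanded::<F>::generate_from_subspace(code.subspace());
/// let ntt = NeighborsLastReference { domain_context: &domain_context };
/// let codeword = code.encode_batch(&ntt, message.to_ref(), 0);
/// assert_eq!(codeword.log_len(), code.log_len());
/// ```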
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct ReedSolomonCode<F> {
	#[get = "pub"]
	subspace: BinarySubspace<F>,
	log_dimension: usize,
	#[get_copy = "pub"]
	log_inv_rate: usize,
}

impl<F: BinaryField> ReedSolomonCode<F> {
	/// Constructs a code using a default binary subspace of dimension
	/// `log_dimension + log_inv_rate`.
	pub fn new(log_dimension: usize, log_inv_rate: usize) -> Self {
		let subspace = BinarySubspace::with_dim(log_dimension + log_inv_rate);
		Self::with_subspace(subspace, log_dimension, log_inv_rate)
	}

	/// Constructs a code over the evaluation subspace taken from the given NTT's domain context.
	pub fn with_ntt_subspace(
		ntt: &impl AdditiveNTT<Field = F>,
		log_dimension: usize,
		log_inv_rate: usize,
	) -> Self {
		Self::with_domain_context_subspace(ntt.domain_context(), log_dimension, log_inv_rate)
	}

	/// Constructs a code over the evaluation subspace taken from the given domain context.
	pub fn with_domain_context_subspace(
		domain_context: &impl DomainContext<Field = F>,
		log_dimension: usize,
		log_inv_rate: usize,
	) -> Self {
		let subspace_dim = log_dimension + log_inv_rate;
		assert!(
			subspace_dim <= domain_context.log_domain_size(),
			"precondition: subspace dimension must not exceed domain context log size"
		);
		let subspace = domain_context.subspace(subspace_dim);
		Self::with_subspace(subspace, log_dimension, log_inv_rate)
	}

	/// Constructs a code over the given binary subspace of evaluation points.
	///
	/// The subspace dimension must equal `log_dimension + log_inv_rate`.
	pub fn with_subspace(
		subspace: BinarySubspace<F>,
		log_dimension: usize,
		log_inv_rate: usize,
	) -> Self {
		assert_eq!(
			subspace.dim(),
			log_dimension + log_inv_rate,
			"precondition: subspace dimension must equal log_dimension + log_inv_rate"
		);
		Self {
			subspace,
			log_dimension,
			log_inv_rate,
		}
	}

	/// The dimension, i.e. the number of message symbols.
	pub const fn dim(&self) -> usize {
		1 << self.dim_bits()
	}

	/// The base-2 log of the dimension.
	pub const fn log_dim(&self) -> usize {
		self.log_dimension
	}

	/// The base-2 log of the block length.
	pub const fn log_len(&self) -> usize {
		self.log_dimension + self.log_inv_rate
	}

	/// The block length, i.e. the number of codeword symbols.
	#[allow(clippy::len_without_is_empty)]
	pub const fn len(&self) -> usize {
		1 << (self.log_dimension + self.log_inv_rate)
	}

	/// The base-2 log of the dimension.
	const fn dim_bits(&self) -> usize {
		self.log_dimension
	}

	/// The reciprocal of the rate, i.e. `self.len() / self.dim()`.
	pub const fn inv_rate(&self) -> usize {
		1 << self.log_inv_rate
	}

	/// Encodes a message with an interleaved Reed–Solomon code.
	///
	/// This function interprets the message as a batch of `2^log_batch_size` independent vectors
	/// and encodes them with the interleaved Reed–Solomon code.
	///
	/// ## Preconditions
	///
	/// * `data.log_len()` must equal `log_dim() + log_batch_size`.
	/// * The NTT subspace must match the code's subspace.
	///
	/// ## Postconditions
	///
	/// * Every element of the returned buffer is initialized with the interleaved codeword; the
	///   buffer's `log_len()` equals `log_len() + log_batch_size`.
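	///
	/// ## Example
	///
	/// A shape-only sketch, not compiled as a doc-test; `ntt` is any [`AdditiveNTT`] whose
	/// subspace matches the code's subspace, and `batch` is a hypothetical message buffer.
	///
	/// ```ignore
	/// // Encode a batch of 2^2 messages, each of length 2^code.log_dim(), packed into one buffer.
	/// assert_eq!(batch.log_len(), code.log_dim() + 2);
	/// let codewords = code.encode_batch(&ntt, batch.to_ref(), 2);
	/// // The output holds 2^2 interleaved codewords of length 2^code.log_len() each.
	/// assert_eq!(codewords.log_len(), code.log_len() + 2);
	/// ```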
	pub fn encode_batch<P, NTT>(
		&self,
		ntt: &NTT,
		data: FieldSlice<P>,
		log_batch_size: usize,
	) -> FieldBuffer<P>
	where
		P: PackedField<Scalar = F>,
		NTT: AdditiveNTT<Field = F> + Sync,
	{
		assert_eq!(
			ntt.subspace(self.log_len()),
			self.subspace,
			"precondition: NTT subspace must match code subspace"
		);
		assert_eq!(
			data.log_len(),
			self.log_dim() + log_batch_size,
			"precondition: data.log_len() must equal log_dim() + log_batch_size"
		);

		let _scope = tracing::trace_span!(
			"Reed-Solomon encode",
			log_len = self.log_len(),
			log_batch_size = log_batch_size,
			symbol_bits = F::N_BITS,
		)
		.entered();

		// Repeat the message to fill the entire buffer.
		let log_output_len = self.log_dim() + log_batch_size + self.log_inv_rate;
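		// Both branches below produce 2^log_inv_rate copies of the bit-reverse-permuted message:
		// when the message is shorter than the packing width, the scalars are cycled into a
		// single packed element which is then repeated; otherwise the packed message is copied
		// once, permuted in place, and replicated into the vector's spare capacity in parallel.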
		let output_data = if data.log_len() < P::LOG_WIDTH {
			let mut scalars = data.iter_scalars().collect::<Vec<_>>();
			bit_reverse_indices(&mut scalars);
			let elem_0 = P::from_scalars(scalars.into_iter().cycle());
			vec![elem_0; 1 << log_output_len.saturating_sub(P::LOG_WIDTH)]
		} else {
			let mut output_data = Vec::with_capacity(1 << (log_output_len - P::LOG_WIDTH));

			output_data.extend_from_slice(data.as_ref());

			// Bit-reverse permute the message.
			bit_reverse_packed(FieldSliceMut::from_slice(
				data.log_len(),
				output_data.as_mut_slice(),
			));

			let log_msg_len_packed = data.log_len() - P::LOG_WIDTH;
			output_data
				.spare_capacity_mut()
				.par_chunks_exact_mut(1 << log_msg_len_packed)
				.enumerate()
				.for_each(|(i, output_chunk)| unsafe {
					let dst_ptr = output_chunk.as_mut_ptr();

					// TODO(https://github.com/rust-lang/rust/issues/81944):
					// Improve unsafe code with Vec::split_at_spare_mut when stable

					// Safety:
					// - the Vec capacity is 1 << (log_msg_len_packed + self.log_inv_rate) packed
					//   elements, since log_output_len - P::LOG_WIDTH == log_msg_len_packed +
					//   self.log_inv_rate
					// - i + 1 is in the range 1..1 << self.log_inv_rate, so src_ptr points at the
					//   initialized message at the start of the Vec
					// - dst_ptr is disjoint from src_ptr and within the Vec capacity
					let src_ptr = dst_ptr.sub((i + 1) << log_msg_len_packed);
					ptr::copy_nonoverlapping(src_ptr, dst_ptr, 1 << log_msg_len_packed);
				});

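			// At this point the buffer's spare capacity holds 2^log_inv_rate - 1 additional
			// copies of the bit-reverse-permuted message, so the full capacity reads
			//     [ M | M | ... | M ]  (2^log_inv_rate copies),
			// where each M is 1 << log_msg_len_packed packed elements.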
			unsafe {
				// Safety: the vec's spare capacity is fully initialized above.
				output_data.set_len(1 << (log_output_len - P::LOG_WIDTH));
			}

			output_data
		};
		let mut output = FieldBuffer::new(log_output_len, output_data.into_boxed_slice());

		ntt.forward_transform(output.to_mut(), self.log_inv_rate, log_batch_size);
		output
	}
}

#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField, PackedBinaryGhash1x128b, PackedBinaryGhash4x128b, PackedField,
	};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::{
		FieldBuffer,
		bit_reverse::reverse_bits,
		ntt::{NeighborsLastReference, domain_context::GenericPreExpanded},
		test_utils::random_field_buffer,
	};
	fn test_encode_batch_helper<P: PackedField>(
		log_dim: usize,
		log_inv_rate: usize,
		log_batch_size: usize,
	) where
		P::Scalar: BinaryField,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate);

		// Create an NTT with a matching subspace.
		let subspace = rs_code.subspace().clone();
		let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
		let ntt = NeighborsLastReference {
			domain_context: &domain_context,
		};

		// Generate a random message buffer.
		let message = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);

		// Encode using the encode_batch interface.
		let encoded_buffer = rs_code.encode_batch(&ntt, message.to_ref(), log_batch_size);

		// Reference implementation: apply a full-size NTT with zero-padded coefficients to the
		// bit-reversal permuted message.
		let mut reference_buffer = FieldBuffer::zeros(rs_code.log_len() + log_batch_size);
		for (i, val) in message.iter_scalars().enumerate() {
			let bits = (rs_code.log_dim() + log_batch_size) as u32;
			reference_buffer.set(reverse_bits(i, bits), val);
		}

		// Perform the full-size NTT with zero-padded coefficients.
		ntt.forward_transform(reference_buffer.to_mut(), 0, log_batch_size);

		// Compare results.
		assert_eq!(
			encoded_buffer.as_ref(),
			reference_buffer.as_ref(),
			"encode_batch result differs from reference NTT implementation"
		);
	}

	#[test]
	fn test_encode_batch_above_packing_width() {
		// Test with PackedBinaryGhash1x128b
		test_encode_batch_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
		test_encode_batch_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
		test_encode_batch_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

		// Test with PackedBinaryGhash4x128b
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(8, 3, 2);
	}

	#[test]
	fn test_encode_batch_below_packing_width() {
		// Test where message length is less than the packing width and codeword length is greater.
		test_encode_batch_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
	}
}