use std::ptr;

use binius_field::{BinaryField, PackedField};
use binius_utils::rayon::prelude::*;
use getset::{CopyGetters, Getters};

use super::{
    FieldBuffer, FieldSlice, FieldSliceMut, binary_subspace::BinarySubspace,
    error::Error as MathError, ntt::AdditiveNTT,
};
use crate::{
    bit_reverse::{bit_reverse_indices, bit_reverse_packed},
    ntt::DomainContext,
};

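/// A Reed-Solomon code over a binary field, encoded with an additive NTT.
///
/// The code has dimension `2^log_dimension` and block length
/// `2^(log_dimension + log_inv_rate)`. Its evaluation domain is a [`BinarySubspace`] of the
/// field, so codewords are produced by an additive NTT over that subspace.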
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct ReedSolomonCode<F> {
    #[get = "pub"]
    subspace: BinarySubspace<F>,
    log_dimension: usize,
    #[get_copy = "pub"]
    log_inv_rate: usize,
}

impl<F: BinaryField> ReedSolomonCode<F> {
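    /// Constructs a code over a binary subspace of dimension `log_dimension + log_inv_rate`
    /// created with [`BinarySubspace::with_dim`].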
    pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
        let subspace = BinarySubspace::with_dim(log_dimension + log_inv_rate)?;
        Self::with_subspace(subspace, log_dimension, log_inv_rate)
    }

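    /// Constructs a code whose evaluation domain is taken from the domain context of the
    /// given NTT.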
    pub fn with_ntt_subspace(
        ntt: &impl AdditiveNTT<Field = F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Result<Self, Error> {
        Self::with_domain_context_subspace(ntt.domain_context(), log_dimension, log_inv_rate)
    }

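    /// Constructs a code whose evaluation domain is the subspace of dimension
    /// `log_dimension + log_inv_rate` taken from the given [`DomainContext`].
    ///
    /// Returns [`Error::SubspaceDimensionMismatch`] if the domain context is too small.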
    pub fn with_domain_context_subspace(
        domain_context: &impl DomainContext<Field = F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Result<Self, Error> {
        let subspace_dim = log_dimension + log_inv_rate;
        if subspace_dim > domain_context.log_domain_size() {
            return Err(Error::SubspaceDimensionMismatch);
        }
        let subspace = domain_context.subspace(subspace_dim);
        Self::with_subspace(subspace, log_dimension, log_inv_rate)
    }

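    /// Constructs a code with an explicitly provided evaluation domain.
    ///
    /// The subspace dimension must equal `log_dimension + log_inv_rate`.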
    pub fn with_subspace(
        subspace: BinarySubspace<F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Result<Self, Error> {
        if subspace.dim() != log_dimension + log_inv_rate {
            return Err(Error::SubspaceDimensionMismatch);
        }
        Ok(Self {
            subspace,
            log_dimension,
            log_inv_rate,
        })
    }

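    /// The dimension of the code, i.e. the number of message symbols.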
    pub const fn dim(&self) -> usize {
        1 << self.dim_bits()
    }

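    /// Base-2 logarithm of the code dimension.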
    pub const fn log_dim(&self) -> usize {
        self.log_dimension
    }

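    /// Base-2 logarithm of the codeword length.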
    pub const fn log_len(&self) -> usize {
        self.log_dimension + self.log_inv_rate
    }

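    /// The codeword length in field elements.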
    #[allow(clippy::len_without_is_empty)]
    pub const fn len(&self) -> usize {
        1 << (self.log_dimension + self.log_inv_rate)
    }

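    /// Base-2 logarithm of the dimension, used internally by [`Self::dim`].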
    const fn dim_bits(&self) -> usize {
        self.log_dimension
    }

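    /// The reciprocal of the code rate, i.e. `self.len() / self.dim()`.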
    pub const fn inv_rate(&self) -> usize {
        1 << self.log_inv_rate
    }

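    /// Encodes a batch of `2^log_batch_size` messages with a single NTT invocation.
    ///
    /// The input `data` must contain `2^(log_dim + log_batch_size)` scalars. The messages are
    /// reordered into bit-reversed order, repeated `2^log_inv_rate` times to fill a buffer of
    /// `2^(log_len + log_batch_size)` scalars, and then extended to codewords by the forward
    /// transform of `ntt`.
    ///
    /// Returns [`Error::EncoderSubspaceMismatch`] if the NTT's evaluation domain does not match
    /// the subspace of this code.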
    pub fn encode_batch<P, NTT>(
        &self,
        ntt: &NTT,
        data: FieldSlice<P>,
        log_batch_size: usize,
    ) -> Result<FieldBuffer<P>, Error>
    where
        P: PackedField<Scalar = F>,
        NTT: AdditiveNTT<Field = F> + Sync,
    {
        if ntt.subspace(self.log_len()) != self.subspace {
            return Err(Error::EncoderSubspaceMismatch);
        }

        assert_eq!(data.log_len(), self.log_dim() + log_batch_size);

        let _scope = tracing::trace_span!(
            "Reed-Solomon encode",
            log_len = self.log_len(),
            log_batch_size = log_batch_size,
            symbol_bits = F::N_BITS,
        )
        .entered();

        let log_output_len = self.log_dim() + log_batch_size + self.log_inv_rate;

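        // Lay out the bit-reversed message repeated `2^log_inv_rate` times; messages smaller
        // than a single packed element are handled by cycling their scalars.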
        let output_data = if data.log_len() < P::LOG_WIDTH {
            let mut scalars = data.iter_scalars().collect::<Vec<_>>();
            bit_reverse_indices(&mut scalars);
            let elem_0 = P::from_scalars(scalars.into_iter().cycle());
            vec![elem_0; 1 << log_output_len.saturating_sub(P::LOG_WIDTH)]
        } else {
            let mut output_data = Vec::with_capacity(1 << (log_output_len - P::LOG_WIDTH));

            output_data.extend_from_slice(data.as_ref());

            bit_reverse_packed(
                FieldSliceMut::from_slice(data.log_len(), output_data.as_mut_slice())
                    .expect("output_data.len() == data.as_ref().len()"),
            );

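            // Fill the spare capacity with copies of the bit-reversed message, one
            // message-sized chunk at a time, in parallel.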
            let log_msg_len_packed = data.log_len() - P::LOG_WIDTH;
            output_data
                .spare_capacity_mut()
                .par_chunks_exact_mut(1 << log_msg_len_packed)
                .enumerate()
                .for_each(|(i, output_chunk)| unsafe {
                    let dst_ptr = output_chunk.as_mut_ptr();

                    let src_ptr = dst_ptr.sub((i + 1) << log_msg_len_packed);
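                    // SAFETY: chunk `i` of the spare capacity starts
                    // `(i + 1) << log_msg_len_packed` elements past the start of the vector, so
                    // `src_ptr` points at the initialized message prefix of `output_data` and
                    // the source and destination regions cannot overlap.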
                    ptr::copy_nonoverlapping(src_ptr, dst_ptr, 1 << log_msg_len_packed);
                });

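            // SAFETY: every element of the spare capacity was initialized by the copies above.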
            unsafe {
                output_data.set_len(1 << (log_output_len - P::LOG_WIDTH));
            }

            output_data
        };
        let mut output = FieldBuffer::new(log_output_len, output_data.into_boxed_slice())
            .expect("preconditions satisfied");

        ntt.forward_transform(output.to_mut(), self.log_inv_rate, log_batch_size);
        Ok(output)
    }
}

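/// Errors that can occur when constructing a [`ReedSolomonCode`] or encoding with it.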
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("the evaluation domain of the code does not match the subspace of the NTT encoder")]
    EncoderSubspaceMismatch,
    #[error("the dimension of the evaluation domain of the code does not match the parameters")]
    SubspaceDimensionMismatch,
    #[error("math error: {0}")]
    Math(#[from] MathError),
}

#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField, PackedBinaryGhash1x128b, PackedBinaryGhash4x128b, PackedField,
    };
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        FieldBuffer,
        bit_reverse::reverse_bits,
        ntt::{NeighborsLastReference, domain_context::GenericPreExpanded},
        test_utils::random_field_buffer,
    };

    fn test_encode_batch_helper<P: PackedField>(
        log_dim: usize,
        log_inv_rate: usize,
        log_batch_size: usize,
    ) where
        P::Scalar: BinaryField,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

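        // Build a reference NTT over the same subspace as the code.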
        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

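        // Encode a random batch of messages.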
        let message = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);

        let encoded_buffer = rs_code
            .encode_batch(&ntt, message.to_ref(), log_batch_size)
            .expect("encode_batch failed");

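        // Reference encoding: write the message scalars into a zero-padded, codeword-length
        // buffer in bit-reversed order, then apply the forward NTT directly.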
        let mut reference_buffer = FieldBuffer::zeros(rs_code.log_len() + log_batch_size);
        for (i, val) in message.iter_scalars().enumerate() {
            let bits = (rs_code.log_dim() + log_batch_size) as u32;
            reference_buffer.set(reverse_bits(i, bits), val);
        }

        ntt.forward_transform(reference_buffer.to_mut(), 0, log_batch_size);

        assert_eq!(
            encoded_buffer.as_ref(),
            reference_buffer.as_ref(),
            "encode_batch result differs from reference NTT implementation"
        );
    }

    #[test]
    fn test_encode_batch_above_packing_width() {
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

        test_encode_batch_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(8, 3, 2);
    }

    #[test]
    fn test_encode_batch_below_packing_width() {
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
    }
}