use std::ptr;

use binius_field::{BinaryField, PackedField};
use binius_utils::rayon::prelude::*;
use getset::{CopyGetters, Getters};

use super::{
    FieldBuffer, FieldSlice, FieldSliceMut, binary_subspace::BinarySubspace, ntt::AdditiveNTT,
};
use crate::{
    bit_reverse::{bit_reverse_indices, bit_reverse_packed},
    ntt::DomainContext,
};

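/// A Reed-Solomon code over a binary field `F`.
///
/// The evaluation domain of the code is a [`BinarySubspace`] of dimension
/// `log_dimension + log_inv_rate`, so the code has message length `2^log_dimension` and block
/// length `2^(log_dimension + log_inv_rate)`. Encoding is performed with an additive NTT over
/// this subspace.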
#[derive(Debug, Clone, Getters, CopyGetters)]
pub struct ReedSolomonCode<F> {
    #[get = "pub"]
    subspace: BinarySubspace<F>,
    log_dimension: usize,
    #[get_copy = "pub"]
    log_inv_rate: usize,
}

impl<F: BinaryField> ReedSolomonCode<F> {
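    /// Creates a new code over a binary subspace of dimension `log_dimension + log_inv_rate`
    /// constructed with [`BinarySubspace::with_dim`].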
    pub fn new(log_dimension: usize, log_inv_rate: usize) -> Self {
        let subspace = BinarySubspace::with_dim(log_dimension + log_inv_rate);
        Self::with_subspace(subspace, log_dimension, log_inv_rate)
    }

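    /// Creates a new code using the subspace provided by the given NTT's domain context.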
    pub fn with_ntt_subspace(
        ntt: &impl AdditiveNTT<Field = F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Self {
        Self::with_domain_context_subspace(ntt.domain_context(), log_dimension, log_inv_rate)
    }

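    /// Creates a new code using the subspace provided by the given domain context.
    ///
    /// ## Preconditions
    ///
    /// * `log_dimension + log_inv_rate` must not exceed `domain_context.log_domain_size()`.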
    pub fn with_domain_context_subspace(
        domain_context: &impl DomainContext<Field = F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Self {
        let subspace_dim = log_dimension + log_inv_rate;
        assert!(
            subspace_dim <= domain_context.log_domain_size(),
            "precondition: subspace dimension must not exceed domain context log size"
        );
        let subspace = domain_context.subspace(subspace_dim);
        Self::with_subspace(subspace, log_dimension, log_inv_rate)
    }

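    /// Creates a new code over an explicitly given subspace.
    ///
    /// ## Preconditions
    ///
    /// * `subspace.dim()` must equal `log_dimension + log_inv_rate`.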
    pub fn with_subspace(
        subspace: BinarySubspace<F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Self {
        assert_eq!(
            subspace.dim(),
            log_dimension + log_inv_rate,
            "precondition: subspace dimension must equal log_dimension + log_inv_rate"
        );
        Self {
            subspace,
            log_dimension,
            log_inv_rate,
        }
    }

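    /// The dimension of the code, i.e. the number of message symbols.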
    pub const fn dim(&self) -> usize {
        1 << self.dim_bits()
    }

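    /// Base-2 logarithm of the dimension of the code.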
    pub const fn log_dim(&self) -> usize {
        self.log_dimension
    }

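    /// Base-2 logarithm of the block length of the code.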
    pub const fn log_len(&self) -> usize {
        self.log_dimension + self.log_inv_rate
    }

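    /// The block length of the code, i.e. the number of codeword symbols.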
    #[allow(clippy::len_without_is_empty)]
    pub const fn len(&self) -> usize {
        1 << (self.log_dimension + self.log_inv_rate)
    }

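    /// Base-2 logarithm of the dimension, used by [`Self::dim`].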
    const fn dim_bits(&self) -> usize {
        self.log_dimension
    }

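    /// The reciprocal of the code rate, i.e. `self.len() / self.dim()`.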
    pub const fn inv_rate(&self) -> usize {
        1 << self.log_inv_rate
    }

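    /// Encodes a batch of `2^log_batch_size` interleaved messages, returning the interleaved
    /// codewords in a new buffer.
    ///
    /// The message scalars are reordered into bit-reversed order, the reordered message is
    /// replicated `2^log_inv_rate` times to fill the output buffer, and the forward NTT is then
    /// applied over the full codeword domain.
    ///
    /// ## Preconditions
    ///
    /// * the NTT's subspace of dimension [`Self::log_len`] must equal the code's subspace;
    /// * `data.log_len()` must equal `self.log_dim() + log_batch_size`.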
    pub fn encode_batch<P, NTT>(
        &self,
        ntt: &NTT,
        data: FieldSlice<P>,
        log_batch_size: usize,
    ) -> FieldBuffer<P>
    where
        P: PackedField<Scalar = F>,
        NTT: AdditiveNTT<Field = F> + Sync,
    {
        assert_eq!(
            ntt.subspace(self.log_len()),
            self.subspace,
            "precondition: NTT subspace must match code subspace"
        );
        assert_eq!(
            data.log_len(),
            self.log_dim() + log_batch_size,
            "precondition: data.log_len() must equal log_dim() + log_batch_size"
        );

        let _scope = tracing::trace_span!(
            "Reed-Solomon encode",
            log_len = self.log_len(),
            log_batch_size = log_batch_size,
            symbol_bits = F::N_BITS,
        )
        .entered();

        let log_output_len = self.log_dim() + log_batch_size + self.log_inv_rate;
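        // Two cases: a message smaller than one packed element is bit-reversed scalar by scalar
        // and broadcast into every packed element of the output; otherwise the packed message is
        // bit-reversed in place and then tiled across the rest of the output buffer.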
        let output_data = if data.log_len() < P::LOG_WIDTH {
            let mut scalars = data.iter_scalars().collect::<Vec<_>>();
            bit_reverse_indices(&mut scalars);
            let elem_0 = P::from_scalars(scalars.into_iter().cycle());
            vec![elem_0; 1 << log_output_len.saturating_sub(P::LOG_WIDTH)]
        } else {
            let mut output_data = Vec::with_capacity(1 << (log_output_len - P::LOG_WIDTH));

            output_data.extend_from_slice(data.as_ref());

            bit_reverse_packed(FieldSliceMut::from_slice(
                data.log_len(),
                output_data.as_mut_slice(),
            ));

            let log_msg_len_packed = data.log_len() - P::LOG_WIDTH;
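            // Fill the spare capacity with copies of the bit-reversed message, one message-sized
            // chunk at a time, in parallel.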
            output_data
                .spare_capacity_mut()
                .par_chunks_exact_mut(1 << log_msg_len_packed)
                .enumerate()
                .for_each(|(i, output_chunk)| unsafe {
                    let dst_ptr = output_chunk.as_mut_ptr();

                    let src_ptr = dst_ptr.sub((i + 1) << log_msg_len_packed);
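                    // SAFETY: spare chunk `i` begins `(i + 1)` message lengths past the start of
                    // the buffer, so `src_ptr` points at the initialized first chunk; source and
                    // destination are disjoint and both within the allocation.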
                    ptr::copy_nonoverlapping(src_ptr, dst_ptr, 1 << log_msg_len_packed);
                });

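            // SAFETY: every element up to the requested output length was initialized by the
            // parallel copies above.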
            unsafe {
                output_data.set_len(1 << (log_output_len - P::LOG_WIDTH));
            }

            output_data
        };
        let mut output = FieldBuffer::new(log_output_len, output_data.into_boxed_slice());

        ntt.forward_transform(output.to_mut(), self.log_inv_rate, log_batch_size);
        output
    }
}

#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField, PackedBinaryGhash1x128b, PackedBinaryGhash4x128b, PackedField,
    };
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        FieldBuffer,
        bit_reverse::reverse_bits,
        ntt::{NeighborsLastReference, domain_context::GenericPreExpanded},
        test_utils::random_field_buffer,
    };

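    // Encodes a random message with `encode_batch` and checks it against a reference
    // computation: write the message in bit-reversed order into a zero buffer of full codeword
    // length and apply the full forward NTT directly.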
    fn test_encode_batch_helper<P: PackedField>(
        log_dim: usize,
        log_inv_rate: usize,
        log_batch_size: usize,
    ) where
        P::Scalar: BinaryField,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate);

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let message = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);

        let encoded_buffer = rs_code.encode_batch(&ntt, message.to_ref(), log_batch_size);

        let mut reference_buffer = FieldBuffer::zeros(rs_code.log_len() + log_batch_size);
        for (i, val) in message.iter_scalars().enumerate() {
            let bits = (rs_code.log_dim() + log_batch_size) as u32;
            reference_buffer.set(reverse_bits(i, bits), val);
        }

        ntt.forward_transform(reference_buffer.to_mut(), 0, log_batch_size);

        assert_eq!(
            encoded_buffer.as_ref(),
            reference_buffer.as_ref(),
            "encode_batch result differs from reference NTT implementation"
        );
    }

    #[test]
    fn test_encode_batch_above_packing_width() {
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

        test_encode_batch_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(8, 3, 2);
    }

    #[test]
    fn test_encode_batch_below_packing_width() {
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
    }
}