// binius_core/reed_solomon/reed_solomon.rs

use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField};
use binius_math::BinarySubspace;
use binius_ntt::{AdditiveNTT, NTTShape, SingleThreadedNTT};
use binius_utils::bail;
use getset::{CopyGetters, Getters};

use super::error::Error;
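
/// A Reed–Solomon code over the binary field `F`.
///
/// The code has dimension `2^log_dimension` and block length
/// `2^(log_dimension + log_inv_rate)`. Codewords are indexed by a [`BinarySubspace`] of dimension
/// `log_dimension + log_inv_rate` and are computed with an additive NTT over that subspace.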
#[derive(Debug, Getters, CopyGetters)]
pub struct ReedSolomonCode<F: BinaryField> {
    #[get = "pub"]
    subspace: BinarySubspace<F>,
    log_dimension: usize,
    #[get_copy = "pub"]
    log_inv_rate: usize,
}

impl<F: BinaryField> ReedSolomonCode<F> {
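    /// Constructs a code of dimension `2^log_dimension` and inverse rate `2^log_inv_rate`,
    /// allocating a fresh [`SingleThreadedNTT`] to derive the evaluation subspace.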
    pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
        let ntt = SingleThreadedNTT::new(log_dimension + log_inv_rate)?;
        Self::with_ntt_subspace(&ntt, log_dimension, log_inv_rate)
    }

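    /// Constructs a code whose evaluation subspace is taken from an existing NTT.
    ///
    /// Returns [`Error::SubspaceDimensionMismatch`] if the NTT domain is smaller than the code
    /// length `2^(log_dimension + log_inv_rate)`.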
    pub fn with_ntt_subspace(
        ntt: &impl AdditiveNTT<F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Result<Self, Error> {
        if log_dimension + log_inv_rate > ntt.log_domain_size() {
            return Err(Error::SubspaceDimensionMismatch);
        }
        let subspace_idx = ntt.log_domain_size() - (log_dimension + log_inv_rate);
        Self::with_subspace(ntt.subspace(subspace_idx), log_dimension, log_inv_rate)
    }

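    /// Constructs a code over an explicitly given evaluation subspace.
    ///
    /// Returns [`Error::SubspaceDimensionMismatch`] unless the subspace dimension equals
    /// `log_dimension + log_inv_rate`.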
    pub fn with_subspace(
        subspace: BinarySubspace<F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Result<Self, Error> {
        if subspace.dim() != log_dimension + log_inv_rate {
            return Err(Error::SubspaceDimensionMismatch);
        }
        Ok(Self {
            subspace,
            log_dimension,
            log_inv_rate,
        })
    }

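    /// The dimension of the code, i.e. the number of message symbols per codeword.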
    pub const fn dim(&self) -> usize {
        1 << self.dim_bits()
    }

    /// The base-2 logarithm of the dimension.
    pub const fn log_dim(&self) -> usize {
        self.log_dimension
    }

    /// The base-2 logarithm of the block length.
    pub const fn log_len(&self) -> usize {
        self.log_dimension + self.log_inv_rate
    }

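    /// The block length of the code, i.e. the number of symbols per codeword.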
    #[allow(clippy::len_without_is_empty)]
    pub const fn len(&self) -> usize {
        1 << (self.log_dimension + self.log_inv_rate)
    }

    const fn dim_bits(&self) -> usize {
        self.log_dimension
    }

    /// The reciprocal of the code rate, i.e. `len() / dim()`.
    pub const fn inv_rate(&self) -> usize {
        1 << self.log_inv_rate
    }

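    /// Encodes a batch of `2^log_batch_size` interleaved messages in place.
    ///
    /// On input, the first `dim() << log_batch_size` scalars of `code` hold the interleaved
    /// messages; on success, `code` holds the interleaved codewords. The buffer must contain
    /// exactly `1 << (log_len() + log_batch_size).saturating_sub(P::LOG_WIDTH)` packed elements,
    /// and the NTT's evaluation domain must match the subspace this code was constructed with.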
    fn encode_batch_inplace<P: PackedField<Scalar = F>, NTT: AdditiveNTT<F> + Sync>(
        &self,
        ntt: &NTT,
        code: &mut [P],
        log_batch_size: usize,
    ) -> Result<(), Error> {
        if ntt.subspace(ntt.log_domain_size() - self.log_len()) != self.subspace {
            bail!(Error::EncoderSubspaceMismatch);
        }
        let expected_buffer_len =
            1 << (self.log_len() + log_batch_size).saturating_sub(P::LOG_WIDTH);
        if code.len() != expected_buffer_len {
            bail!(Error::IncorrectBufferLength {
                expected: expected_buffer_len,
                actual: code.len(),
            });
        }

        let _scope = tracing::trace_span!(
            "Reed–Solomon encode",
            log_len = self.log_len(),
            log_batch_size = log_batch_size,
            symbol_bits = F::N_BITS,
        )
        .entered();

        // If the message does not fill a single packed element, repeat its scalars to fill one.
        if self.log_dim() + log_batch_size < P::LOG_WIDTH {
            let repeated_values = code[0]
                .into_iter()
                .take(1 << (self.log_dim() + log_batch_size))
                .cycle();
            code[0] = P::from_scalars(repeated_values);
        }

        // Repeat the message chunk to fill the entire buffer before transforming.
        let mut chunks =
            code.chunks_mut(1 << (self.log_dim() + log_batch_size).saturating_sub(P::LOG_WIDTH));
        let first_chunk = chunks.next().expect("code is not empty; checked above");
        for chunk in chunks {
            chunk.copy_from_slice(first_chunk);
        }

        // Apply the forward additive NTT to produce the interleaved codewords.
        let shape = NTTShape {
            log_x: log_batch_size,
            log_y: self.log_len(),
            ..Default::default()
        };
        ntt.forward_transform(code, shape, 0, 0, self.log_inv_rate)?;
        Ok(())
    }

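    /// Encodes a batch of interleaved messages over an extension field of `F` in place.
    ///
    /// Each extension-field scalar is cast to its base-field components, so the batch is encoded
    /// as `2^(log_batch_size + LOG_DEGREE)` interleaved messages over the base field `F`.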
    pub fn encode_ext_batch_inplace<PE: PackedExtension<F>, NTT: AdditiveNTT<F> + Sync>(
        &self,
        ntt: &NTT,
        code: &mut [PE],
        log_batch_size: usize,
    ) -> Result<(), Error> {
        self.encode_batch_inplace(
            ntt,
            PE::cast_bases_mut(code),
            log_batch_size + PE::Scalar::LOG_DEGREE,
        )
    }
}
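
// A minimal usage sketch, not part of the original module: it shows constructing a small code
// and encoding one message in place. It assumes `BinaryField32b` is available from
// `binius_field` and that scalar field elements act as width-1 packed fields (`LOG_WIDTH = 0`);
// real callers would typically use a wide packed field type instead. The module and test names
// are hypothetical.
#[cfg(test)]
mod example {
    use binius_field::{BinaryField32b, Field};
    use binius_ntt::SingleThreadedNTT;

    use super::ReedSolomonCode;

    #[test]
    fn encode_single_message() {
        // Rate-1/4 code: 2^4 message symbols, 2^6 codeword symbols.
        let rs_code = ReedSolomonCode::<BinaryField32b>::new(4, 2).unwrap();
        let ntt = SingleThreadedNTT::<BinaryField32b>::new(rs_code.log_len()).unwrap();

        // The buffer holds the full codeword; the message occupies the first `dim()` entries
        // and the remainder is overwritten by the encoder.
        let mut code = vec![BinaryField32b::ZERO; rs_code.len()];
        for (i, symbol) in code.iter_mut().take(rs_code.dim()).enumerate() {
            *symbol = if i % 2 == 0 {
                BinaryField32b::ONE
            } else {
                BinaryField32b::ZERO
            };
        }

        rs_code
            .encode_batch_inplace(&ntt, &mut code, 0)
            .expect("encoding a correctly sized buffer succeeds");
    }
}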