binius_core/reed_solomon/reed_solomon.rs

use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField};
use binius_math::BinarySubspace;
use binius_ntt::{AdditiveNTT, NTTShape};
use binius_utils::bail;
use getset::{CopyGetters, Getters};

use super::error::Error;

/// A Reed-Solomon code over a binary field, whose evaluation domain is a binary subspace.
///
/// The code is parameterized by the base-2 logarithms of its dimension and inverse rate.
/// Encoding is performed with an additive NTT over the evaluation domain.
#[derive(Debug, Getters, CopyGetters)]
pub struct ReedSolomonCode<F: BinaryField> {
    #[get = "pub"]
    subspace: BinarySubspace<F>,
    log_dimension: usize,
    #[get_copy = "pub"]
    log_inv_rate: usize,
}

impl<F: BinaryField> ReedSolomonCode<F> {
    pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
        Self::with_subspace(
            BinarySubspace::with_dim(log_dimension + log_inv_rate)?,
            log_dimension,
            log_inv_rate,
        )
    }

    pub fn with_subspace(
        subspace: BinarySubspace<F>,
        log_dimension: usize,
        log_inv_rate: usize,
    ) -> Result<Self, Error> {
        if subspace.dim() != log_dimension + log_inv_rate {
            return Err(Error::SubspaceDimensionMismatch);
        }
        Ok(Self {
            subspace,
            log_dimension,
            log_inv_rate,
        })
    }

    /// The dimension of the code, i.e. the number of message symbols.
    pub const fn dim(&self) -> usize {
        1 << self.dim_bits()
    }

    pub const fn log_dim(&self) -> usize {
        self.log_dimension
    }

    pub const fn log_len(&self) -> usize {
        self.log_dimension + self.log_inv_rate
    }

    /// The block length of the code, i.e. the number of codeword symbols.
    #[allow(clippy::len_without_is_empty)]
    pub const fn len(&self) -> usize {
        1 << (self.log_dimension + self.log_inv_rate)
    }

    const fn dim_bits(&self) -> usize {
        self.log_dimension
    }

    /// The reciprocal of the code rate, i.e. `len() / dim()`.
    pub const fn inv_rate(&self) -> usize {
        1 << self.log_inv_rate
    }
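    // Worked example (illustrative parameters, not taken from the source): a code built with
    // log_dimension = 10 and log_inv_rate = 2 has dim() = 1 << 10 = 1024 message symbols,
    // len() = 1 << 12 = 4096 codeword symbols, and inv_rate() = 4.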

    /// Encodes a batch of `2^log_batch_size` interleaved messages in-place.
    ///
    /// On input, the first `2^(log_dim() + log_batch_size)` scalars of `code` hold the
    /// interleaved messages; the buffer must be sized for the full batch of codewords.
    fn encode_batch_inplace<P: PackedField<Scalar = F>, NTT: AdditiveNTT<F> + Sync>(
        &self,
        ntt: &NTT,
        code: &mut [P],
        log_batch_size: usize,
    ) -> Result<(), Error> {
        if ntt.subspace(ntt.log_domain_size() - self.log_len()) != self.subspace {
            bail!(Error::EncoderSubspaceMismatch);
        }
        let expected_buffer_len =
            1 << (self.log_len() + log_batch_size).saturating_sub(P::LOG_WIDTH);
        if code.len() != expected_buffer_len {
            bail!(Error::IncorrectBufferLength {
                expected: expected_buffer_len,
                actual: code.len(),
            });
        }

        let _scope = tracing::trace_span!(
            "Reed–Solomon encode",
            log_len = self.log_len(),
            log_batch_size = log_batch_size,
            symbol_bits = F::N_BITS,
        )
        .entered();

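        // If the interleaved messages occupy fewer scalars than one packed element, cycle them
        // to fill code[0] so the chunk replication below copies whole packed elements.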
        if self.log_dim() + log_batch_size < P::LOG_WIDTH {
            let repeated_values = code[0]
                .into_iter()
                .take(1 << (self.log_dim() + log_batch_size))
                .cycle();
            code[0] = P::from_scalars(repeated_values);
        }

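        // Replicate the first chunk (the interleaved messages) to fill the rest of the buffer
        // ahead of the transform.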
        let mut chunks =
            code.chunks_mut(1 << (self.log_dim() + log_batch_size).saturating_sub(P::LOG_WIDTH));
        let first_chunk = chunks.next().expect("code is not empty; checked above");
        for chunk in chunks {
            chunk.copy_from_slice(first_chunk);
        }

        let shape = NTTShape {
            log_x: log_batch_size,
            log_y: self.log_len(),
            ..Default::default()
        };
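        // The forward additive NTT maps the replicated messages to Reed-Solomon codewords
        // over the full evaluation domain.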
        ntt.forward_transform(code, shape, 0, self.log_inv_rate)?;
        Ok(())
    }

    /// Encodes a batch of interleaved messages over an extension field in-place, by viewing
    /// each extension scalar as `2^LOG_DEGREE` base-field scalars.
    pub fn encode_ext_batch_inplace<PE: PackedExtension<F>, NTT: AdditiveNTT<F> + Sync>(
        &self,
        ntt: &NTT,
        code: &mut [PE],
        log_batch_size: usize,
    ) -> Result<(), Error> {
        self.encode_batch_inplace(
            ntt,
            PE::cast_bases_mut(code),
            log_batch_size + PE::Scalar::LOG_DEGREE,
        )
    }
}
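
// A minimal usage sketch (test-only). The concrete field type `BinaryField32b` from
// `binius_field` is an assumption chosen for illustration; the test only exercises the
// parameter accessors defined above.
#[cfg(test)]
mod tests {
    use binius_field::BinaryField32b;

    use super::*;

    #[test]
    fn code_parameters_are_consistent() {
        // log_dimension = 10, log_inv_rate = 2: 1024 message symbols, 4096 codeword symbols.
        let code = ReedSolomonCode::<BinaryField32b>::new(10, 2).unwrap();
        assert_eq!(code.dim(), 1 << 10);
        assert_eq!(code.len(), 1 << 12);
        assert_eq!(code.inv_rate(), 4);
        assert_eq!(code.log_len(), code.log_dim() + code.log_inv_rate());
        assert_eq!(code.subspace().dim(), code.log_len());
    }
}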