binius_core/reed_solomon/
reed_solomon.rs1use std::marker::PhantomData;
14
15use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension};
16use binius_maybe_rayon::prelude::*;
17use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings};
18use binius_utils::bail;
19use getset::CopyGetters;
20use tracing::instrument;
21
22#[derive(Debug, CopyGetters)]
23pub struct ReedSolomonCode<P>
24where
25 P: PackedField,
26 P::Scalar: BinaryField,
27{
28 ntt: DynamicDispatchNTT<P::Scalar>,
29 log_dimension: usize,
30 #[getset(get_copy = "pub")]
31 log_inv_rate: usize,
32 multithreaded: bool,
33 _p_marker: PhantomData<P>,
34}
35
36impl<P> ReedSolomonCode<P>
37where
38 P: PackedField<Scalar: BinaryField>,
39{
40 pub fn new(
41 log_dimension: usize,
42 log_inv_rate: usize,
43 ntt_options: &NTTOptions,
44 ) -> Result<Self, Error> {
45 let ntt_log_threads = ntt_options
47 .thread_settings
48 .log_threads_count()
49 .saturating_sub(log_inv_rate);
50 let ntt = DynamicDispatchNTT::new(
51 log_dimension + log_inv_rate,
52 &NTTOptions {
53 thread_settings: ThreadingSettings::ExplicitThreadsCount {
54 log_threads: ntt_log_threads,
55 },
56 precompute_twiddles: ntt_options.precompute_twiddles,
57 },
58 )?;
59
60 let multithreaded =
61 !matches!(ntt_options.thread_settings, ThreadingSettings::SingleThreaded);
62
63 Ok(Self {
64 ntt,
65 log_dimension,
66 log_inv_rate,
67 multithreaded,
68 _p_marker: PhantomData,
69 })
70 }
71
72 pub const fn get_ntt(&self) -> &impl AdditiveNTT<P::Scalar> {
73 &self.ntt
74 }
75
76 pub const fn dim(&self) -> usize {
78 1 << self.dim_bits()
79 }
80
81 pub const fn log_dim(&self) -> usize {
82 self.log_dimension
83 }
84
85 pub const fn log_len(&self) -> usize {
86 self.log_dimension + self.log_inv_rate
87 }
88
89 #[allow(clippy::len_without_is_empty)]
91 pub const fn len(&self) -> usize {
92 1 << (self.log_dimension + self.log_inv_rate)
93 }
94
95 const fn dim_bits(&self) -> usize {
97 self.log_dimension
98 }
99
100 pub const fn inv_rate(&self) -> usize {
102 1 << self.log_inv_rate
103 }
104
105 fn encode_batch_inplace(&self, code: &mut [P], log_batch_size: usize) -> Result<(), Error> {
116 let _scope = tracing::trace_span!(
117 "Reed–Solomon encode",
118 log_len = self.log_len(),
119 log_batch_size = log_batch_size,
120 symbol_bits = P::Scalar::N_BITS,
121 )
122 .entered();
123 if (code.len() << log_batch_size) < self.len() {
124 bail!(Error::BufferTooSmall {
125 log_code_len: self.len(),
126 });
127 }
128 if self.dim() % P::WIDTH != 0 {
129 bail!(Error::PackingWidthMustDivideDimension);
130 }
131
132 let msgs_len = (self.dim() / P::WIDTH) << log_batch_size;
133 for i in 1..(1 << self.log_inv_rate) {
134 code.copy_within(0..msgs_len, i * msgs_len);
135 }
136
137 if self.multithreaded {
138 (0..(1 << self.log_inv_rate))
139 .into_par_iter()
140 .zip(code.par_chunks_exact_mut(msgs_len))
141 .try_for_each(|(i, data)| {
142 self.ntt
143 .forward_transform(data, i, log_batch_size, self.log_dim())
144 })
145 } else {
146 (0..(1 << self.log_inv_rate))
147 .zip(code.chunks_exact_mut(msgs_len))
148 .try_for_each(|(i, data)| {
149 self.ntt
150 .forward_transform(data, i, log_batch_size, self.log_dim())
151 })
152 }
153 }
154
155 #[instrument(skip_all, level = "debug")]
169 pub fn encode_ext_batch_inplace<PE: RepackedExtension<P>>(
170 &self,
171 code: &mut [PE],
172 log_batch_size: usize,
173 ) -> Result<(), Error> {
174 self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + PE::Scalar::LOG_DEGREE)
175 }
176}