// binius_core/reed_solomon/reed_solomon.rs
use std::marker::PhantomData;
use binius_field::{BinaryField, PackedField};
use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings};
use binius_utils::bail;
use getset::CopyGetters;
use rayon::prelude::*;
use crate::linear_code::LinearCode;
/// A Reed–Solomon code over the binary field `P::Scalar`, encoded with an additive NTT.
///
/// Messages have `2^log_dimension` symbols and codewords have
/// `2^(log_dimension + log_inv_rate)` symbols; buffers are handled in packed form `P`.
#[derive(Debug, CopyGetters)]
pub struct ReedSolomonCode<P>
where
    P: PackedField<Scalar: BinaryField>,
{
    /// Additive NTT whose domain covers the full codeword length.
    ntt: DynamicDispatchNTT<P::Scalar>,
    /// Base-2 logarithm of the message length.
    log_dimension: usize,
    /// Base-2 logarithm of the inverse rate (codeword length / message length).
    #[getset(get_copy = "pub")]
    log_inv_rate: usize,
    /// Whether encoding transforms the per-coset chunks with rayon parallel iterators.
    multithreaded: bool,
    _p_marker: PhantomData<P>,
}
impl<P> ReedSolomonCode<P>
where
    P: PackedField<Scalar: BinaryField>,
{
    /// Builds a code with `2^log_dimension`-symbol messages and
    /// `2^(log_dimension + log_inv_rate)`-symbol codewords.
    ///
    /// Encoding runs one NTT per coset, and there are `2^log_inv_rate` of them.
    /// The log thread count requested in `ntt_options` is therefore reduced by
    /// `log_inv_rate` (saturating at zero) before being passed to each NTT.
    /// Any threading mode other than `SingleThreaded` enables parallel encoding.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `DynamicDispatchNTT::new`.
    pub fn new(
        log_dimension: usize,
        log_inv_rate: usize,
        ntt_options: NTTOptions,
    ) -> Result<Self, Error> {
        let multithreaded =
            !matches!(ntt_options.thread_settings, ThreadingSettings::SingleThreaded);

        // Cosets are encoded side by side, so each NTT gets a smaller slice of
        // the overall thread budget.
        let log_threads_per_coset = ntt_options
            .thread_settings
            .log_threads_count()
            .saturating_sub(log_inv_rate);
        let per_coset_options = NTTOptions {
            thread_settings: ThreadingSettings::ExplicitThreadsCount {
                log_threads: log_threads_per_coset,
            },
            ..ntt_options
        };
        let log_domain_size = log_dimension + log_inv_rate;
        let ntt = DynamicDispatchNTT::new(log_domain_size, per_coset_options)?;

        Ok(Self {
            ntt,
            log_dimension,
            log_inv_rate,
            multithreaded,
            _p_marker: PhantomData,
        })
    }

    /// Returns the additive NTT used by the encoder.
    pub fn get_ntt(&self) -> &impl AdditiveNTT<P> {
        &self.ntt
    }

    /// Base-2 logarithm of the message length.
    pub fn log_dim(&self) -> usize {
        self.log_dimension
    }

    /// Base-2 logarithm of the codeword length (`log_dim() + log_inv_rate`).
    pub fn log_len(&self) -> usize {
        let log_dim = self.log_dim();
        log_dim + self.log_inv_rate
    }
}
impl<P, F> LinearCode for ReedSolomonCode<P>
where
    P: PackedField<Scalar = F>,
    F: BinaryField,
{
    type P = P;
    type EncodeError = Error;

    /// Codeword length in symbols: `2^(log_dimension + log_inv_rate)`.
    fn len(&self) -> usize {
        1 << (self.log_dimension + self.log_inv_rate)
    }

    /// Base-2 logarithm of the message length.
    fn dim_bits(&self) -> usize {
        self.log_dimension
    }

    /// Minimum distance, computed as `len() - dim() + 1`.
    fn min_dist(&self) -> usize {
        self.len() - self.dim() + 1
    }

    /// Inverse rate `2^log_inv_rate`.
    fn inv_rate(&self) -> usize {
        1 << self.log_inv_rate
    }

    /// Encodes a batch of `2^log_batch_size` messages in place.
    ///
    /// On entry, the first `(dim() / P::WIDTH) << log_batch_size` packed elements
    /// of `code` hold the messages. The layout within that prefix is whatever
    /// `AdditiveNTT::forward_transform` expects for `log_batch_size`.
    /// The prefix is replicated into `2^log_inv_rate` consecutive chunks, and
    /// chunk `i` is then forward-transformed with coset index `i`.
    /// Packed elements past those chunks are left untouched.
    ///
    /// # Errors
    ///
    /// * `Error::BufferTooSmall` if `code` holds fewer than
    ///   `len() << log_batch_size` scalars.
    /// * `Error::PackingWidthMustDivideDimension` if `P::WIDTH` does not divide `dim()`.
    /// * Any error returned by the NTT's `forward_transform`.
    fn encode_batch_inplace(
        &self,
        code: &mut [Self::P],
        log_batch_size: usize,
    ) -> Result<(), Self::EncodeError> {
        let _scope = tracing::trace_span!(
            "Reed–Solomon encode",
            log_len = self.log_len(),
            log_batch_size = log_batch_size,
            symbol_bits = F::N_BITS,
        )
        .entered();

        // The batched codeword spans `len() << log_batch_size` scalars, and each
        // packed element holds `P::WIDTH` scalars. Compare in scalar units; the
        // previous check shifted the packed length instead and mixed the two units.
        // Saturating keeps an (impractically) huge buffer from wrapping around.
        let required_scalars = self.len() << log_batch_size;
        let available_scalars = code.len().saturating_mul(P::WIDTH);
        if available_scalars < required_scalars {
            bail!(Error::BufferTooSmall {
                log_code_len: self.log_len(),
            });
        }
        if self.dim() % P::WIDTH != 0 {
            bail!(Error::PackingWidthMustDivideDimension);
        }

        // Packed elements occupied by the whole batch of messages. The codeword
        // is `2^log_inv_rate` consecutive chunks of this size, which the size
        // check above guarantees fit inside `code`.
        let msgs_len = (self.dim() / P::WIDTH) << log_batch_size;

        // Seed every coset chunk with a copy of the messages.
        for i in 1..(1 << self.log_inv_rate) {
            code.copy_within(0..msgs_len, i * msgs_len);
        }

        // Transform each chunk on its own coset; `zip` stops after the last coset,
        // so any trailing chunks are ignored.
        if self.multithreaded {
            (0..(1 << self.log_inv_rate))
                .into_par_iter()
                .zip(code.par_chunks_exact_mut(msgs_len))
                .try_for_each(|(i, data)| self.ntt.forward_transform(data, i, log_batch_size))
        } else {
            (0..(1 << self.log_inv_rate))
                .zip(code.chunks_exact_mut(msgs_len))
                .try_for_each(|(i, data)| self.ntt.forward_transform(data, i, log_batch_size))
        }
    }
}