1use std::mem::MaybeUninit;
8
9use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField};
10use binius_utils::rayon::{
11 iter::{IntoParallelRefMutIterator, ParallelBridge, ParallelIterator},
12 slice::ParallelSliceMut,
13};
14use getset::{CopyGetters, Getters};
15
16use super::{binary_subspace::BinarySubspace, error::Error as MathError, ntt::AdditiveNTT};
17use crate::{FieldSliceMut, ntt::DomainContext};
18
19#[derive(Debug, Clone, Getters, CopyGetters)]
29pub struct ReedSolomonCode<F> {
30 #[get = "pub"]
31 subspace: BinarySubspace<F>,
32 log_dimension: usize,
33 #[get_copy = "pub"]
34 log_inv_rate: usize,
35}
36
37impl<F: BinaryField> ReedSolomonCode<F> {
38 pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
39 let subspace = BinarySubspace::with_dim(log_dimension + log_inv_rate)?;
40 Self::with_subspace(subspace, log_dimension, log_inv_rate)
41 }
42
43 pub fn with_ntt_subspace(
44 ntt: &impl AdditiveNTT<Field = F>,
45 log_dimension: usize,
46 log_inv_rate: usize,
47 ) -> Result<Self, Error> {
48 Self::with_domain_context_subspace(ntt.domain_context(), log_dimension, log_inv_rate)
49 }
50
51 pub fn with_domain_context_subspace(
52 domain_context: &impl DomainContext<Field = F>,
53 log_dimension: usize,
54 log_inv_rate: usize,
55 ) -> Result<Self, Error> {
56 let subspace_dim = log_dimension + log_inv_rate;
57 if subspace_dim > domain_context.log_domain_size() {
58 return Err(Error::SubspaceDimensionMismatch);
59 }
60 let subspace = domain_context.subspace(subspace_dim);
61 Self::with_subspace(subspace, log_dimension, log_inv_rate)
62 }
63
64 pub fn with_subspace(
65 subspace: BinarySubspace<F>,
66 log_dimension: usize,
67 log_inv_rate: usize,
68 ) -> Result<Self, Error> {
69 if subspace.dim() != log_dimension + log_inv_rate {
70 return Err(Error::SubspaceDimensionMismatch);
71 }
72 Ok(Self {
73 subspace,
74 log_dimension,
75 log_inv_rate,
76 })
77 }
78
79 pub const fn dim(&self) -> usize {
81 1 << self.dim_bits()
82 }
83
84 pub const fn log_dim(&self) -> usize {
85 self.log_dimension
86 }
87
88 pub const fn log_len(&self) -> usize {
89 self.log_dimension + self.log_inv_rate
90 }
91
92 #[allow(clippy::len_without_is_empty)]
94 pub const fn len(&self) -> usize {
95 1 << (self.log_dimension + self.log_inv_rate)
96 }
97
98 const fn dim_bits(&self) -> usize {
100 self.log_dimension
101 }
102
103 pub const fn inv_rate(&self) -> usize {
105 1 << self.log_inv_rate
106 }
107
108 pub fn encode_batch_inplace<P: PackedField<Scalar = F>, NTT: AdditiveNTT<Field = F> + Sync>(
118 &self,
119 ntt: &NTT,
120 code: &mut [P],
121 log_batch_size: usize,
122 ) -> Result<(), Error> {
123 if ntt.subspace(self.log_len()) != self.subspace {
124 return Err(Error::EncoderSubspaceMismatch);
125 }
126
127 let mut code = FieldSliceMut::from_slice(self.log_len() + log_batch_size, code)?;
128
129 let _scope = tracing::trace_span!(
130 "Reed-Solomon encode",
131 log_len = self.log_len(),
132 log_batch_size = log_batch_size,
133 symbol_bits = F::N_BITS,
134 )
135 .entered();
136
137 let chunk_size = self.log_dim() + log_batch_size;
142 let chunk_size = if chunk_size < P::LOG_WIDTH {
143 let elem_0 = &mut code.as_mut()[0];
144 let repeated_values = elem_0
145 .into_iter()
146 .take(1 << (self.log_dim() + log_batch_size))
147 .cycle();
148 *elem_0 = P::from_scalars(repeated_values);
149 P::LOG_WIDTH
150 } else {
151 chunk_size
152 };
153
154 if chunk_size < code.log_len() {
155 let mut chunks = code.chunks_mut(chunk_size).expect(
156 "chunk_size >= P::LOG_WIDTH from assignment above; \
157 chunk_size < code.log_len() in conditional",
158 );
159 let first_chunk = chunks.next().expect("chunks_mut cannot be empty");
160 chunks.par_bridge().for_each(|mut chunk| {
161 chunk.as_mut().copy_from_slice(first_chunk.as_ref());
162 });
163 }
164
165 let skip_early = self.log_inv_rate;
166 let skip_late = log_batch_size;
167 ntt.forward_transform(code, skip_early, skip_late);
168 Ok(())
169 }
170
171 pub fn encode_batch<P: PackedField<Scalar = F>, NTT: AdditiveNTT<Field = F> + Sync>(
191 &self,
192 ntt: &NTT,
193 data: &[P],
194 output: &mut [MaybeUninit<P>],
195 log_batch_size: usize,
196 ) -> Result<(), Error> {
197 if ntt.subspace(self.log_len()) != self.subspace {
198 return Err(Error::EncoderSubspaceMismatch);
199 }
200
201 let data_log_len = self.log_dim() + log_batch_size;
203 let output_log_len = self.log_len() + log_batch_size;
204
205 let expected_data_len = if data_log_len >= P::LOG_WIDTH {
206 1 << (data_log_len - P::LOG_WIDTH)
207 } else {
208 1
209 };
210
211 let expected_output_len = if output_log_len >= P::LOG_WIDTH {
212 1 << (output_log_len - P::LOG_WIDTH)
213 } else {
214 1
215 };
216
217 if data.len() != expected_data_len {
218 return Err(Error::Math(MathError::IncorrectArgumentLength {
219 arg: "data".to_string(),
220 expected: expected_data_len,
221 }));
222 }
223
224 if output.len() != expected_output_len {
225 return Err(Error::Math(MathError::IncorrectArgumentLength {
226 arg: "output".to_string(),
227 expected: expected_output_len,
228 }));
229 }
230
231 let _scope = tracing::trace_span!(
232 "Reed-Solomon encode",
233 log_len = self.log_len(),
234 log_batch_size = log_batch_size,
235 symbol_bits = F::N_BITS,
236 )
237 .entered();
238
239 let log_chunk_size = self.log_dim() + log_batch_size;
241 if log_chunk_size < P::LOG_WIDTH {
242 let repeated_values = data[0]
243 .into_iter()
244 .take(1 << (self.log_dim() + log_batch_size))
245 .cycle();
246 let elem_0 = P::from_scalars(repeated_values);
247 output.par_iter_mut().for_each(|elem| {
248 elem.write(elem_0);
249 });
250 } else {
251 output
252 .par_chunks_mut(1 << (log_chunk_size - P::LOG_WIDTH))
253 .for_each(|chunk| {
254 let out = uninit::out_ref::Out::from(chunk);
255 out.copy_from_slice(data);
256 });
257 };
258
259 let output_initialized = unsafe { uninit::out_ref::Out::<[P]>::from(output).assume_init() };
261 let code = FieldSliceMut::from_slice(self.log_len() + log_batch_size, output_initialized)?;
262
263 let skip_early = self.log_inv_rate;
264 let skip_late = log_batch_size;
265 ntt.forward_transform(code, skip_early, skip_late);
266 Ok(())
267 }
268
269 pub fn encode_ext_batch_inplace<PE: PackedExtension<F>, NTT: AdditiveNTT<Field = F> + Sync>(
283 &self,
284 ntt: &NTT,
285 code: &mut [PE],
286 log_batch_size: usize,
287 ) -> Result<(), Error> {
288 self.encode_batch_inplace(
289 ntt,
290 PE::cast_bases_mut(code),
291 log_batch_size + PE::Scalar::LOG_DEGREE,
292 )
293 }
294
295 pub fn encode_ext_batch<PE: PackedExtension<F>, NTT: AdditiveNTT<Field = F> + Sync>(
315 &self,
316 ntt: &NTT,
317 data: &[PE],
318 output: &mut [MaybeUninit<PE>],
319 log_batch_size: usize,
320 ) -> Result<(), Error> {
321 let output_bases = unsafe {
325 std::slice::from_raw_parts_mut(
326 output.as_mut_ptr() as *mut MaybeUninit<PE::PackedSubfield>,
327 output.len() * PE::Scalar::DEGREE,
328 )
329 };
330
331 self.encode_batch(
332 ntt,
333 PE::cast_bases(data),
334 output_bases,
335 log_batch_size + PE::Scalar::LOG_DEGREE,
336 )
337 }
338}
339
340#[derive(Debug, thiserror::Error)]
341pub enum Error {
342 #[error("the evaluation domain of the code does not match the subspace of the NTT encoder")]
343 EncoderSubspaceMismatch,
344 #[error("the dimension of the evaluation domain of the code does not match the parameters")]
345 SubspaceDimensionMismatch,
346 #[error("math error: {0}")]
347 Math(#[from] MathError),
348}
349
#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField, BinaryField128bGhash, PackedBinaryGhash1x128b, PackedBinaryGhash4x128b,
        PackedField,
    };
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        FieldBuffer,
        ntt::{NeighborsLastReference, domain_context::GenericPreExpanded},
        test_utils::random_field_buffer,
    };

    /// Checks `encode_batch_inplace` against a direct forward NTT of the zero-extended message.
    fn test_encode_batch_inplace_helper<P: PackedField>(
        log_dim: usize,
        log_inv_rate: usize,
        log_batch_size: usize,
    ) where
        P::Scalar: BinaryField,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        // Build an NTT whose domain is exactly the code's evaluation subspace.
        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let message = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);

        let log_encoded_len = rs_code.log_len() + log_batch_size;

        // Place the message at the front of a buffer sized for the whole codeword.
        let mut encoded_data = message.as_ref().to_vec();
        let encoded_capacity = 1 << log_encoded_len.saturating_sub(P::LOG_WIDTH);
        encoded_data.resize(encoded_capacity, P::zero());

        let mut encoded_buffer =
            FieldBuffer::new_truncated(log_dim + log_batch_size, encoded_data.into_boxed_slice())
                .expect("Failed to create encoded buffer");

        encoded_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend encoded buffer");

        rs_code
            .encode_batch_inplace(&ntt, encoded_buffer.as_mut(), log_batch_size)
            .expect("encode_batch_inplace failed");

        // Reference: full forward transform of the zero-padded message.
        let mut reference_data = message.as_ref().to_vec();
        reference_data.resize(encoded_capacity, P::zero());

        let mut reference_buffer =
            FieldBuffer::new_truncated(log_dim + log_batch_size, reference_data.into_boxed_slice())
                .expect("Failed to create reference buffer");

        reference_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend reference buffer");

        ntt.forward_transform(reference_buffer.to_mut(), 0, log_batch_size);

        assert_eq!(
            encoded_buffer.as_ref(),
            reference_buffer.as_ref(),
            "encode_batch_inplace result differs from reference NTT implementation"
        );
    }

    #[test]
    fn test_encode_batch_inplace() {
        test_encode_batch_inplace_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
        test_encode_batch_inplace_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
        test_encode_batch_inplace_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(8, 3, 2);

        // Message smaller than one packed element.
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
    }

    /// Checks that the out-of-place `encode_batch` agrees with `encode_batch_inplace`.
    fn test_encode_batch_helper<P: PackedField>(
        log_dim: usize,
        log_inv_rate: usize,
        log_batch_size: usize,
    ) where
        P::Scalar: BinaryField,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let message = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);

        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = 1 << log_encoded_len.saturating_sub(P::LOG_WIDTH);
        let mut encoded_output = Vec::<MaybeUninit<P>>::with_capacity(encoded_capacity);

        // SAFETY: `MaybeUninit<P>` requires no initialization and the capacity was reserved above.
        unsafe {
            encoded_output.set_len(encoded_capacity);
        }

        rs_code
            .encode_batch(&ntt, message.as_ref(), &mut encoded_output, log_batch_size)
            .expect("encode_batch failed");

        // SAFETY: `encode_batch` returned `Ok`, so every element has been initialized.
        let encoded_result: Vec<P> = unsafe {
            encoded_output
                .into_iter()
                .map(|x| x.assume_init())
                .collect()
        };

        let mut encoded_data = message.as_ref().to_vec();
        encoded_data.resize(encoded_capacity, P::zero());

        let mut reference_buffer =
            FieldBuffer::new_truncated(log_dim + log_batch_size, encoded_data.into_boxed_slice())
                .expect("Failed to create reference buffer");

        reference_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend reference buffer");

        rs_code
            .encode_batch_inplace(&ntt, reference_buffer.as_mut(), log_batch_size)
            .expect("encode_batch_inplace failed");

        assert_eq!(
            encoded_result,
            reference_buffer.as_ref(),
            "encode_batch result differs from encode_batch_inplace"
        );
    }

    #[test]
    fn test_encode_batch() {
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

        test_encode_batch_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(8, 3, 2);

        // Message smaller than one packed element.
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
    }

    #[test]
    fn test_encode_ext_batch() {
        let mut rng = StdRng::seed_from_u64(0);
        let log_dim = 6;
        let log_inv_rate = 2;
        let log_batch_size = 0;

        type PE = PackedBinaryGhash4x128b;
        type F = <PE as PackedField>::Scalar;

        let rs_code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<F>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let message = random_field_buffer::<PE>(&mut rng, log_dim + log_batch_size);

        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = if log_encoded_len >= PE::LOG_WIDTH {
            1 << (log_encoded_len - PE::LOG_WIDTH)
        } else {
            1
        };

        let mut encoded_output = Vec::<MaybeUninit<PE>>::with_capacity(encoded_capacity);
        // SAFETY: `MaybeUninit<PE>` requires no initialization and the capacity was reserved above.
        unsafe {
            encoded_output.set_len(encoded_capacity);
        }

        rs_code
            .encode_ext_batch(&ntt, message.as_ref(), &mut encoded_output, log_batch_size)
            .expect("encode_ext_batch failed");

        // SAFETY: `encode_ext_batch` returned `Ok`, so every element has been initialized.
        let encoded_result: Vec<PE> = unsafe {
            encoded_output
                .into_iter()
                .map(|x| x.assume_init())
                .collect()
        };

        assert_eq!(encoded_result.len(), encoded_capacity);

        // Cross-check the contents against the in-place extension encoder on the padded message.
        let mut reference_data = message.as_ref().to_vec();
        reference_data.resize(encoded_capacity, PE::zero());

        let mut reference_buffer =
            FieldBuffer::new_truncated(log_dim + log_batch_size, reference_data.into_boxed_slice())
                .expect("Failed to create reference buffer");

        reference_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend reference buffer");

        rs_code
            .encode_ext_batch_inplace(&ntt, reference_buffer.as_mut(), log_batch_size)
            .expect("encode_ext_batch_inplace failed");

        assert_eq!(
            encoded_result,
            reference_buffer.as_ref(),
            "encode_ext_batch result differs from encode_ext_batch_inplace"
        );
    }

    #[test]
    #[ignore = "Test setup hits edge case in NTT domain configuration - dimension validation logic is correct"]
    fn test_encode_batch_dimension_validation() {
        let mut rng = StdRng::seed_from_u64(0);
        let log_dim = 6;
        let log_inv_rate = 2;
        let log_batch_size = 1;

        type P = PackedBinaryGhash4x128b;
        type F = <P as PackedField>::Scalar;

        let rs_code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<F>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        // Input one size class too small.
        let wrong_data = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size - 1);
        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = 1 << log_encoded_len.saturating_sub(P::LOG_WIDTH);
        let mut output = Vec::<MaybeUninit<P>>::with_capacity(encoded_capacity);

        // SAFETY: `MaybeUninit<P>` requires no initialization and the capacity was reserved above.
        unsafe {
            output.set_len(encoded_capacity);
        }

        let result = rs_code.encode_batch(&ntt, wrong_data.as_ref(), &mut output, log_batch_size);
        assert!(result.is_err(), "Expected error for incorrect input data length");
        assert!(
            matches!(result, Err(Error::Math(MathError::IncorrectArgumentLength { arg, .. })) if arg == "data"),
            "Expected IncorrectArgumentLength error for data"
        );

        // Output buffer one element too short.
        let correct_data = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);
        let mut wrong_output = Vec::<MaybeUninit<P>>::with_capacity(encoded_capacity - 1);

        // SAFETY: `MaybeUninit<P>` requires no initialization and the capacity was reserved above.
        unsafe {
            wrong_output.set_len(encoded_capacity - 1);
        }

        let result =
            rs_code.encode_batch(&ntt, correct_data.as_ref(), &mut wrong_output, log_batch_size);
        assert!(result.is_err(), "Expected error for incorrect output buffer length");
        assert!(
            matches!(result, Err(Error::Math(MathError::IncorrectArgumentLength { arg, .. })) if arg == "output"),
            "Expected IncorrectArgumentLength error for output"
        );

        // Code whose evaluation domain differs from the NTT's.
        let wrong_rs_code = ReedSolomonCode::<BinaryField128bGhash>::new(log_dim + 1, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let mut correct_output = Vec::<MaybeUninit<P>>::with_capacity(encoded_capacity);
        // SAFETY: `MaybeUninit<P>` requires no initialization and the capacity was reserved above.
        unsafe {
            correct_output.set_len(encoded_capacity);
        }

        let result = wrong_rs_code.encode_batch(
            &ntt,
            correct_data.as_ref(),
            &mut correct_output,
            log_batch_size,
        );
        assert!(result.is_err(), "Expected error for NTT subspace mismatch");
        assert!(
            matches!(result, Err(Error::EncoderSubspaceMismatch)),
            "Expected EncoderSubspaceMismatch error"
        );
    }

    #[test]
    fn test_encode_ext_batch_dimension_validation() {
        let mut rng = StdRng::seed_from_u64(0);
        let log_dim = 6;
        let log_inv_rate = 2;
        let log_batch_size = 1;

        type PE = PackedBinaryGhash4x128b;
        type F = <PE as PackedField>::Scalar;

        let rs_code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<F>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        // Input one size class too small.
        let wrong_data = random_field_buffer::<PE>(&mut rng, log_dim + log_batch_size - 1);
        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = if log_encoded_len >= PE::LOG_WIDTH {
            1 << (log_encoded_len - PE::LOG_WIDTH)
        } else {
            1
        };
        let mut output = Vec::<MaybeUninit<PE>>::with_capacity(encoded_capacity);

        // SAFETY: `MaybeUninit<PE>` requires no initialization and the capacity was reserved above.
        unsafe {
            output.set_len(encoded_capacity);
        }

        let result =
            rs_code.encode_ext_batch(&ntt, wrong_data.as_ref(), &mut output, log_batch_size);
        assert!(result.is_err(), "Expected error for incorrect input data length");
        assert!(
            matches!(result, Err(Error::Math(MathError::IncorrectArgumentLength { arg, .. })) if arg == "data"),
            "Expected IncorrectArgumentLength error for data"
        );
    }
}