1use std::mem::MaybeUninit;
8
9use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField};
10use binius_utils::{
11 bail,
12 rayon::{
13 iter::{IntoParallelRefMutIterator, ParallelBridge, ParallelIterator},
14 slice::ParallelSliceMut,
15 },
16};
17use getset::{CopyGetters, Getters};
18
19use super::{binary_subspace::BinarySubspace, error::Error as MathError, ntt::AdditiveNTT};
20use crate::{FieldSliceMut, ntt::DomainContext};
21
22#[derive(Debug, Clone, Getters, CopyGetters)]
32pub struct ReedSolomonCode<F> {
33 #[get = "pub"]
34 subspace: BinarySubspace<F>,
35 log_dimension: usize,
36 #[get_copy = "pub"]
37 log_inv_rate: usize,
38}
39
40impl<F: BinaryField> ReedSolomonCode<F> {
41 pub fn new(log_dimension: usize, log_inv_rate: usize) -> Result<Self, Error> {
42 let subspace = BinarySubspace::with_dim(log_dimension + log_inv_rate)?;
43 Self::with_subspace(subspace, log_dimension, log_inv_rate)
44 }
45
46 pub fn with_ntt_subspace(
47 ntt: &impl AdditiveNTT<Field = F>,
48 log_dimension: usize,
49 log_inv_rate: usize,
50 ) -> Result<Self, Error> {
51 Self::with_domain_context_subspace(ntt.domain_context(), log_dimension, log_inv_rate)
52 }
53
54 pub fn with_domain_context_subspace(
55 domain_context: &impl DomainContext<Field = F>,
56 log_dimension: usize,
57 log_inv_rate: usize,
58 ) -> Result<Self, Error> {
59 let subspace_dim = log_dimension + log_inv_rate;
60 if subspace_dim > domain_context.log_domain_size() {
61 return Err(Error::SubspaceDimensionMismatch);
62 }
63 let subspace = domain_context.subspace(subspace_dim);
64 Self::with_subspace(subspace, log_dimension, log_inv_rate)
65 }
66
67 pub fn with_subspace(
68 subspace: BinarySubspace<F>,
69 log_dimension: usize,
70 log_inv_rate: usize,
71 ) -> Result<Self, Error> {
72 if subspace.dim() != log_dimension + log_inv_rate {
73 return Err(Error::SubspaceDimensionMismatch);
74 }
75 Ok(Self {
76 subspace,
77 log_dimension,
78 log_inv_rate,
79 })
80 }
81
82 pub const fn dim(&self) -> usize {
84 1 << self.dim_bits()
85 }
86
87 pub const fn log_dim(&self) -> usize {
88 self.log_dimension
89 }
90
91 pub const fn log_len(&self) -> usize {
92 self.log_dimension + self.log_inv_rate
93 }
94
95 #[allow(clippy::len_without_is_empty)]
97 pub const fn len(&self) -> usize {
98 1 << (self.log_dimension + self.log_inv_rate)
99 }
100
101 const fn dim_bits(&self) -> usize {
103 self.log_dimension
104 }
105
106 pub const fn inv_rate(&self) -> usize {
108 1 << self.log_inv_rate
109 }
110
111 pub fn encode_batch_inplace<P: PackedField<Scalar = F>, NTT: AdditiveNTT<Field = F> + Sync>(
121 &self,
122 ntt: &NTT,
123 code: &mut [P],
124 log_batch_size: usize,
125 ) -> Result<(), Error> {
126 if ntt.subspace(self.log_len()) != self.subspace {
127 bail!(Error::EncoderSubspaceMismatch);
128 }
129
130 let mut code = FieldSliceMut::from_slice(self.log_len() + log_batch_size, code)?;
131
132 let _scope = tracing::trace_span!(
133 "Reed-Solomon encode",
134 log_len = self.log_len(),
135 log_batch_size = log_batch_size,
136 symbol_bits = F::N_BITS,
137 )
138 .entered();
139
140 let chunk_size = self.log_dim() + log_batch_size;
145 let chunk_size = if chunk_size < P::LOG_WIDTH {
146 let elem_0 = &mut code.as_mut()[0];
147 let repeated_values = elem_0
148 .into_iter()
149 .take(1 << (self.log_dim() + log_batch_size))
150 .cycle();
151 *elem_0 = P::from_scalars(repeated_values);
152 P::LOG_WIDTH
153 } else {
154 chunk_size
155 };
156
157 if chunk_size < code.log_len() {
158 let mut chunks = code.chunks_mut(chunk_size).expect(
159 "chunk_size >= P::LOG_WIDTH from assignment above; \
160 chunk_size < code.log_len() in conditional",
161 );
162 let first_chunk = chunks.next().expect("chunks_mut cannot be empty");
163 chunks.par_bridge().for_each(|mut chunk| {
164 chunk.as_mut().copy_from_slice(first_chunk.as_ref());
165 });
166 }
167
168 let skip_early = self.log_inv_rate;
169 let skip_late = log_batch_size;
170 ntt.forward_transform(code, skip_early, skip_late);
171 Ok(())
172 }
173
174 pub fn encode_batch<P: PackedField<Scalar = F>, NTT: AdditiveNTT<Field = F> + Sync>(
194 &self,
195 ntt: &NTT,
196 data: &[P],
197 output: &mut [MaybeUninit<P>],
198 log_batch_size: usize,
199 ) -> Result<(), Error> {
200 if ntt.subspace(self.log_len()) != self.subspace {
201 bail!(Error::EncoderSubspaceMismatch);
202 }
203
204 let data_log_len = self.log_dim() + log_batch_size;
206 let output_log_len = self.log_len() + log_batch_size;
207
208 let expected_data_len = if data_log_len >= P::LOG_WIDTH {
209 1 << (data_log_len - P::LOG_WIDTH)
210 } else {
211 1
212 };
213
214 let expected_output_len = if output_log_len >= P::LOG_WIDTH {
215 1 << (output_log_len - P::LOG_WIDTH)
216 } else {
217 1
218 };
219
220 if data.len() != expected_data_len {
221 bail!(Error::Math(MathError::IncorrectArgumentLength {
222 arg: "data".to_string(),
223 expected: expected_data_len,
224 }));
225 }
226
227 if output.len() != expected_output_len {
228 bail!(Error::Math(MathError::IncorrectArgumentLength {
229 arg: "output".to_string(),
230 expected: expected_output_len,
231 }));
232 }
233
234 let _scope = tracing::trace_span!(
235 "Reed-Solomon encode",
236 log_len = self.log_len(),
237 log_batch_size = log_batch_size,
238 symbol_bits = F::N_BITS,
239 )
240 .entered();
241
242 let log_chunk_size = self.log_dim() + log_batch_size;
244 if log_chunk_size < P::LOG_WIDTH {
245 let repeated_values = data[0]
246 .into_iter()
247 .take(1 << (self.log_dim() + log_batch_size))
248 .cycle();
249 let elem_0 = P::from_scalars(repeated_values);
250 output.par_iter_mut().for_each(|elem| {
251 elem.write(elem_0);
252 });
253 } else {
254 output
255 .par_chunks_mut(1 << (log_chunk_size - P::LOG_WIDTH))
256 .for_each(|chunk| {
257 let out = uninit::out_ref::Out::from(chunk);
258 out.copy_from_slice(data);
259 });
260 };
261
262 let output_initialized = unsafe { uninit::out_ref::Out::<[P]>::from(output).assume_init() };
264 let code = FieldSliceMut::from_slice(self.log_len() + log_batch_size, output_initialized)?;
265
266 let skip_early = self.log_inv_rate;
267 let skip_late = log_batch_size;
268 ntt.forward_transform(code, skip_early, skip_late);
269 Ok(())
270 }
271
272 pub fn encode_ext_batch_inplace<PE: PackedExtension<F>, NTT: AdditiveNTT<Field = F> + Sync>(
286 &self,
287 ntt: &NTT,
288 code: &mut [PE],
289 log_batch_size: usize,
290 ) -> Result<(), Error> {
291 self.encode_batch_inplace(
292 ntt,
293 PE::cast_bases_mut(code),
294 log_batch_size + PE::Scalar::LOG_DEGREE,
295 )
296 }
297
298 pub fn encode_ext_batch<PE: PackedExtension<F>, NTT: AdditiveNTT<Field = F> + Sync>(
318 &self,
319 ntt: &NTT,
320 data: &[PE],
321 output: &mut [MaybeUninit<PE>],
322 log_batch_size: usize,
323 ) -> Result<(), Error> {
324 let output_bases = unsafe {
328 std::slice::from_raw_parts_mut(
329 output.as_mut_ptr() as *mut MaybeUninit<PE::PackedSubfield>,
330 output.len() * PE::Scalar::DEGREE,
331 )
332 };
333
334 self.encode_batch(
335 ntt,
336 PE::cast_bases(data),
337 output_bases,
338 log_batch_size + PE::Scalar::LOG_DEGREE,
339 )
340 }
341}
342
343#[derive(Debug, thiserror::Error)]
344pub enum Error {
345 #[error("the evaluation domain of the code does not match the subspace of the NTT encoder")]
346 EncoderSubspaceMismatch,
347 #[error("the dimension of the evaluation domain of the code does not match the parameters")]
348 SubspaceDimensionMismatch,
349 #[error("math error: {0}")]
350 Math(#[from] MathError),
351}
352
#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField, BinaryField128bGhash, PackedBinaryGhash1x128b, PackedBinaryGhash4x128b,
        PackedField,
    };
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        FieldBuffer,
        ntt::{NeighborsLastReference, domain_context::GenericPreExpanded},
        test_utils::random_field_buffer,
    };

    /// Number of packed elements of type `P` needed to store `2^log_len` scalars.
    ///
    /// Buffers shorter than one packed element still occupy a single element.
    fn packed_len<P: PackedField>(log_len: usize) -> usize {
        1 << log_len.saturating_sub(P::LOG_WIDTH)
    }

    /// Allocates `len` uninitialized slots without resorting to `Vec::set_len`.
    fn uninit_vec<P>(len: usize) -> Vec<MaybeUninit<P>> {
        std::iter::repeat_with(MaybeUninit::uninit)
            .take(len)
            .collect()
    }

    /// Checks `encode_batch_inplace` against a direct forward NTT of the zero-extended
    /// message.
    fn test_encode_batch_inplace_helper<P: PackedField>(
        log_dim: usize,
        log_inv_rate: usize,
        log_batch_size: usize,
    ) where
        P::Scalar: BinaryField,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        // Build an NTT over exactly the subspace the code uses.
        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let log_message_len = log_dim + log_batch_size;
        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = packed_len::<P>(log_encoded_len);

        let message = random_field_buffer::<P>(&mut rng, log_message_len);

        // Buffer under test: the message followed by zeros, encoded in place.
        let mut encoded_data = message.as_ref().to_vec();
        encoded_data.resize(encoded_capacity, P::zero());

        let mut encoded_buffer =
            FieldBuffer::new_truncated(log_message_len, encoded_data.into_boxed_slice())
                .expect("Failed to create encoded buffer");

        encoded_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend encoded buffer");

        rs_code
            .encode_batch_inplace(&ntt, encoded_buffer.as_mut(), log_batch_size)
            .expect("encode_batch_inplace failed");

        // Reference: the same zero-extended message run through the NTT with no early
        // layers skipped.
        let mut reference_data = message.as_ref().to_vec();
        reference_data.resize(encoded_capacity, P::zero());

        let mut reference_buffer =
            FieldBuffer::new_truncated(log_message_len, reference_data.into_boxed_slice())
                .expect("Failed to create reference buffer");

        reference_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend reference buffer");

        ntt.forward_transform(reference_buffer.to_mut(), 0, log_batch_size);

        assert_eq!(
            encoded_buffer.as_ref(),
            reference_buffer.as_ref(),
            "encode_batch_inplace result differs from reference NTT implementation"
        );
    }

    #[test]
    fn test_encode_batch_inplace() {
        // One scalar per packed element.
        test_encode_batch_inplace_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
        test_encode_batch_inplace_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
        test_encode_batch_inplace_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

        // Several scalars per packed element.
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(8, 3, 2);

        // Message shorter than a single packed element.
        test_encode_batch_inplace_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
    }

    /// Checks that `encode_batch` agrees with `encode_batch_inplace`.
    fn test_encode_batch_helper<P: PackedField>(
        log_dim: usize,
        log_inv_rate: usize,
        log_batch_size: usize,
    ) where
        P::Scalar: BinaryField,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let rs_code = ReedSolomonCode::<P::Scalar>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        // Build an NTT over exactly the subspace the code uses.
        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<P::Scalar>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let log_message_len = log_dim + log_batch_size;
        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = packed_len::<P>(log_encoded_len);

        let message = random_field_buffer::<P>(&mut rng, log_message_len);

        let mut encoded_output = uninit_vec::<P>(encoded_capacity);

        rs_code
            .encode_batch(&ntt, message.as_ref(), &mut encoded_output, log_batch_size)
            .expect("encode_batch failed");

        // SAFETY: `encode_batch` returned `Ok`, so it initialized every element of the
        // output buffer.
        let encoded_result: Vec<P> = unsafe {
            encoded_output
                .into_iter()
                .map(|x| x.assume_init())
                .collect()
        };

        // Reference: the in-place encoder applied to the zero-extended message.
        let mut encoded_data = message.as_ref().to_vec();
        encoded_data.resize(encoded_capacity, P::zero());

        let mut reference_buffer =
            FieldBuffer::new_truncated(log_message_len, encoded_data.into_boxed_slice())
                .expect("Failed to create reference buffer");

        reference_buffer
            .zero_extend(log_encoded_len)
            .expect("Failed to zero-extend reference buffer");

        rs_code
            .encode_batch_inplace(&ntt, reference_buffer.as_mut(), log_batch_size)
            .expect("encode_batch_inplace failed");

        assert_eq!(
            encoded_result,
            reference_buffer.as_ref(),
            "encode_batch result differs from encode_batch_inplace"
        );
    }

    #[test]
    fn test_encode_batch() {
        // One scalar per packed element.
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash1x128b>(8, 3, 2);

        // Several scalars per packed element.
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(4, 2, 0);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(6, 2, 1);
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(8, 3, 2);

        // Message shorter than a single packed element.
        test_encode_batch_helper::<PackedBinaryGhash4x128b>(1, 2, 0);
    }

    #[test]
    fn test_encode_ext_batch() {
        let mut rng = StdRng::seed_from_u64(0);
        let log_dim = 6;
        let log_inv_rate = 2;
        let log_batch_size = 0;

        type PE = PackedBinaryGhash4x128b;
        type F = <PE as PackedField>::Scalar;

        let rs_code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<F>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        let message = random_field_buffer::<PE>(&mut rng, log_dim + log_batch_size);

        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = packed_len::<PE>(log_encoded_len);

        let mut encoded_output = uninit_vec::<PE>(encoded_capacity);

        rs_code
            .encode_ext_batch(&ntt, message.as_ref(), &mut encoded_output, log_batch_size)
            .expect("encode_ext_batch failed");

        // SAFETY: `encode_ext_batch` returned `Ok`, so it initialized every element of the
        // output buffer.
        let encoded_result: Vec<PE> = unsafe {
            encoded_output
                .into_iter()
                .map(|x| x.assume_init())
                .collect()
        };

        assert_eq!(encoded_result.len(), encoded_capacity);

        // `PE` is viewed here as an extension of its own scalar field, so the extension
        // encoder is expected to produce the same codeword as the plain encoder.
        let mut expected_output = uninit_vec::<PE>(encoded_capacity);

        rs_code
            .encode_batch(&ntt, message.as_ref(), &mut expected_output, log_batch_size)
            .expect("encode_batch failed");

        // SAFETY: `encode_batch` returned `Ok`, so it initialized every element of the
        // output buffer.
        let expected_result: Vec<PE> = unsafe {
            expected_output
                .into_iter()
                .map(|x| x.assume_init())
                .collect()
        };

        assert_eq!(
            encoded_result, expected_result,
            "encode_ext_batch result differs from encode_batch"
        );
    }

    #[test]
    #[ignore = "Test setup hits edge case in NTT domain configuration - dimension validation logic is correct"]
    fn test_encode_batch_dimension_validation() {
        let mut rng = StdRng::seed_from_u64(0);
        let log_dim = 6;
        let log_inv_rate = 2;
        let log_batch_size = 1;

        type P = PackedBinaryGhash4x128b;
        type F = <P as PackedField>::Scalar;

        let rs_code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<F>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        // Case 1: the input is half as long as the encoder expects.
        let wrong_data = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size - 1);
        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = packed_len::<P>(log_encoded_len);
        let mut output = uninit_vec::<P>(encoded_capacity);

        let result = rs_code.encode_batch(&ntt, wrong_data.as_ref(), &mut output, log_batch_size);
        assert!(result.is_err(), "Expected error for incorrect input data length");
        assert!(
            matches!(result, Err(Error::Math(MathError::IncorrectArgumentLength { arg, .. })) if arg == "data"),
            "Expected IncorrectArgumentLength error for data"
        );

        // Case 2: the output buffer is one packed element too short.
        let correct_data = random_field_buffer::<P>(&mut rng, log_dim + log_batch_size);
        let mut wrong_output = uninit_vec::<P>(encoded_capacity - 1);

        let result =
            rs_code.encode_batch(&ntt, correct_data.as_ref(), &mut wrong_output, log_batch_size);
        assert!(result.is_err(), "Expected error for incorrect output buffer length");
        assert!(
            matches!(result, Err(Error::Math(MathError::IncorrectArgumentLength { arg, .. })) if arg == "output"),
            "Expected IncorrectArgumentLength error for output"
        );

        // Case 3: the code was built with different parameters than the NTT.
        let wrong_rs_code = ReedSolomonCode::<BinaryField128bGhash>::new(log_dim + 1, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let mut correct_output = uninit_vec::<P>(encoded_capacity);

        let result = wrong_rs_code.encode_batch(
            &ntt,
            correct_data.as_ref(),
            &mut correct_output,
            log_batch_size,
        );
        assert!(result.is_err(), "Expected error for NTT subspace mismatch");
        assert!(
            matches!(result, Err(Error::EncoderSubspaceMismatch)),
            "Expected EncoderSubspaceMismatch error"
        );
    }

    #[test]
    fn test_encode_ext_batch_dimension_validation() {
        let mut rng = StdRng::seed_from_u64(0);
        let log_dim = 6;
        let log_inv_rate = 2;
        let log_batch_size = 1;

        type PE = PackedBinaryGhash4x128b;
        type F = <PE as PackedField>::Scalar;

        let rs_code = ReedSolomonCode::<F>::new(log_dim, log_inv_rate)
            .expect("Failed to create Reed-Solomon code");

        let subspace = rs_code.subspace().clone();
        let domain_context = GenericPreExpanded::<F>::generate_from_subspace(&subspace);
        let ntt = NeighborsLastReference {
            domain_context: &domain_context,
        };

        // The input is half as long as the encoder expects.
        let wrong_data = random_field_buffer::<PE>(&mut rng, log_dim + log_batch_size - 1);
        let log_encoded_len = rs_code.log_len() + log_batch_size;
        let encoded_capacity = packed_len::<PE>(log_encoded_len);
        let mut output = uninit_vec::<PE>(encoded_capacity);

        let result =
            rs_code.encode_ext_batch(&ntt, wrong_data.as_ref(), &mut output, log_batch_size);
        assert!(result.is_err(), "Expected error for incorrect input data length");
    }
}