1use std::{
4 ops::{Deref, DerefMut},
5 slice,
6};
7
8use binius_field::{
9 Field, PackedField,
10 packed::{get_packed_slice_unchecked, set_packed_slice_unchecked},
11};
12use binius_utils::{
13 checked_arithmetics::{checked_log_2, strict_log_2},
14 rayon::{prelude::*, slice::ParallelSlice},
15};
16use bytemuck::zeroed_vec;
17
18use crate::Error;
19
/// A power-of-two-length sequence of scalars stored in packed form.
///
/// The buffer logically holds `2^log_len` scalars of type `P::Scalar`, packed `P::WIDTH` at a
/// time into `values`. The backing storage may be larger than the logical length: it can hold
/// `values.len() * P::WIDTH` scalars (see `log_cap`). Scalars past `2^log_len` are not part of
/// the buffer's logical contents.
///
/// Every constructor in this module maintains two invariants:
/// - `values.len()` is a power of two;
/// - `values.len() >= 2^(log_len - P::LOG_WIDTH)`, i.e. at least one packed element when
///   `log_len < P::LOG_WIDTH`.
///
/// `Eq` is derived, but `PartialEq` is implemented by hand so that only the logical contents are
/// compared.
#[derive(Debug, Clone, Eq)]
pub struct FieldBuffer<P: PackedField, Data: Deref<Target = [P]> = Box<[P]>> {
    /// Base-2 logarithm of the number of scalars in the buffer.
    log_len: usize,
    /// Packed backing storage; may extend past the logical length.
    values: Data,
}
32
33impl<P: PackedField, Data: Deref<Target = [P]>> PartialEq for FieldBuffer<P, Data> {
34 fn eq(&self, other: &Self) -> bool {
35 if self.log_len < P::LOG_WIDTH {
38 let iter_1 = self
39 .values
40 .first()
41 .expect("len >= 1")
42 .iter()
43 .take(1 << self.log_len);
44 let iter_2 = other
45 .values
46 .first()
47 .expect("len >= 1")
48 .iter()
49 .take(1 << self.log_len);
50 iter_1.eq(iter_2)
51 } else {
52 let prefix = 1 << (self.log_len - P::LOG_WIDTH);
53 self.log_len == other.log_len && self.values[..prefix] == other.values[..prefix]
54 }
55 }
56}
57
58impl<P: PackedField> FieldBuffer<P> {
59 pub fn from_values(values: &[P::Scalar]) -> Result<Self, Error> {
65 let Some(log_len) = strict_log_2(values.len()) else {
66 return Err(Error::PowerOfTwoLengthRequired);
67 };
68
69 Self::from_values_truncated(values, log_len)
70 }
71
72 pub fn from_values_truncated(values: &[P::Scalar], log_cap: usize) -> Result<Self, Error> {
81 if !values.len().is_power_of_two() {
82 return Err(Error::PowerOfTwoLengthRequired);
83 }
84
85 let log_len = values.len().ilog2() as usize;
86 if log_len > log_cap {
87 return Err(Error::IncorrectArgumentLength {
88 arg: "values".to_string(),
89 expected: 1 << log_cap,
90 });
91 }
92
93 let packed_cap = 1 << log_cap.saturating_sub(P::LOG_WIDTH);
94 let mut packed_values = Vec::with_capacity(packed_cap);
95 packed_values.extend(
96 values
97 .chunks(P::WIDTH)
98 .map(|chunk| P::from_scalars(chunk.iter().copied())),
99 );
100 packed_values.resize(packed_cap, P::zero());
101
102 Ok(Self {
103 log_len,
104 values: packed_values.into_boxed_slice(),
105 })
106 }
107
108 pub fn zeros(log_len: usize) -> Self {
110 Self::zeros_truncated(log_len, log_len).expect("log_len == log_cap")
111 }
112
113 pub fn zeros_truncated(log_len: usize, log_cap: usize) -> Result<Self, Error> {
117 if log_len > log_cap {
118 return Err(Error::IncorrectArgumentLength {
119 arg: "log_len".to_string(),
120 expected: log_cap,
121 });
122 }
123 let packed_len = 1 << log_cap.saturating_sub(P::LOG_WIDTH);
124 let values = zeroed_vec(packed_len).into_boxed_slice();
125 Ok(Self { log_len, values })
126 }
127}
128
129#[allow(clippy::len_without_is_empty)]
130impl<P: PackedField, Data: Deref<Target = [P]>> FieldBuffer<P, Data> {
131 pub fn new(log_len: usize, values: Data) -> Result<Self, Error> {
138 let expected_packed_len = 1 << log_len.saturating_sub(P::LOG_WIDTH);
139 if values.len() != expected_packed_len {
140 return Err(Error::IncorrectArgumentLength {
141 arg: "values".to_string(),
142 expected: expected_packed_len,
143 });
144 }
145 Self::new_truncated(log_len, values)
146 }
147
148 pub fn new_truncated(log_len: usize, values: Data) -> Result<Self, Error> {
155 let min_packed_len = 1 << log_len.saturating_sub(P::LOG_WIDTH);
156 if values.len() < min_packed_len {
157 return Err(Error::IncorrectArgumentLength {
158 arg: "values".to_string(),
159 expected: min_packed_len,
160 });
161 }
162
163 if !values.len().is_power_of_two() {
164 return Err(Error::PowerOfTwoLengthRequired);
165 }
166
167 Ok(Self { log_len, values })
168 }
169
170 pub fn log_cap(&self) -> usize {
172 checked_log_2(self.values.len()) + P::LOG_WIDTH
173 }
174
175 pub fn cap(&self) -> usize {
177 1 << self.log_cap()
178 }
179
180 pub fn log_len(&self) -> usize {
182 self.log_len
183 }
184
185 pub fn len(&self) -> usize {
187 1 << self.log_len
188 }
189
190 pub fn to_ref(&self) -> FieldSlice<'_, P> {
192 FieldSlice::from_slice(self.log_len, self.as_ref())
193 .expect("log_len matches values.len() by struct invariant")
194 }
195
196 pub fn get(&self, index: usize) -> Result<P::Scalar, Error> {
202 if index >= self.len() {
203 return Err(Error::ArgumentRangeError {
204 arg: "index".to_string(),
205 range: 0..self.len(),
206 });
207 }
208
209 let val = unsafe { get_packed_slice_unchecked(&self.values, index) };
212 Ok(val)
213 }
214
215 pub fn chunk(
221 &self,
222 log_chunk_size: usize,
223 chunk_index: usize,
224 ) -> Result<FieldSlice<'_, P>, Error> {
225 if log_chunk_size > self.log_len {
226 return Err(Error::ArgumentRangeError {
227 arg: "log_chunk_size".to_string(),
228 range: 0..self.log_len + 1,
229 });
230 }
231
232 let chunk_count = 1 << (self.log_len - log_chunk_size);
233 if chunk_index >= chunk_count {
234 return Err(Error::ArgumentRangeError {
235 arg: "chunk_index".to_string(),
236 range: 0..chunk_count,
237 });
238 }
239
240 let values = if log_chunk_size >= P::LOG_WIDTH {
241 let packed_log_chunk_size = log_chunk_size - P::LOG_WIDTH;
242 let chunk =
243 &self.values[chunk_index << packed_log_chunk_size..][..1 << packed_log_chunk_size];
244 FieldSliceData::Slice(chunk)
245 } else {
246 let packed_log_chunks = P::LOG_WIDTH - log_chunk_size;
247 let packed = self.values[chunk_index >> packed_log_chunks];
248 let chunk_subindex = chunk_index & ((1 << packed_log_chunks) - 1);
249 let chunk = P::from_scalars(
250 (0..1 << log_chunk_size).map(|i| packed.get(chunk_subindex << log_chunk_size | i)),
251 );
252 FieldSliceData::Single(chunk)
253 };
254
255 Ok(FieldBuffer {
256 log_len: log_chunk_size,
257 values,
258 })
259 }
260
261 pub fn chunks(
268 &self,
269 log_chunk_size: usize,
270 ) -> Result<impl Iterator<Item = FieldSlice<'_, P>>, Error> {
271 if log_chunk_size < P::LOG_WIDTH || log_chunk_size > self.log_len {
272 return Err(Error::ArgumentRangeError {
273 arg: "log_chunk_size".to_string(),
274 range: P::LOG_WIDTH..self.log_len + 1,
275 });
276 }
277
278 let chunk_count = 1 << (self.log_len - log_chunk_size);
279 let packed_chunk_size = 1 << (log_chunk_size - P::LOG_WIDTH);
280 let chunks = self
281 .values
282 .chunks(packed_chunk_size)
283 .take(chunk_count)
284 .map(move |chunk| FieldBuffer {
285 log_len: log_chunk_size,
286 values: FieldSliceData::Slice(chunk),
287 });
288
289 Ok(chunks)
290 }
291
292 pub fn chunks_par(
299 &self,
300 log_chunk_size: usize,
301 ) -> Result<impl IndexedParallelIterator<Item = FieldSlice<'_, P>>, Error> {
302 if log_chunk_size < P::LOG_WIDTH || log_chunk_size > self.log_len {
303 return Err(Error::ArgumentRangeError {
304 arg: "log_chunk_size".to_string(),
305 range: P::LOG_WIDTH..self.log_len + 1,
306 });
307 }
308
309 let log_len = log_chunk_size.min(self.log_len);
310 let packed_chunk_size = 1 << (log_chunk_size - P::LOG_WIDTH);
311 let chunks = self
312 .values
313 .par_chunks(packed_chunk_size)
314 .map(move |chunk| FieldBuffer {
315 log_len,
316 values: FieldSliceData::Slice(chunk),
317 });
318
319 Ok(chunks)
320 }
321
322 pub fn split_half(&self) -> Result<(FieldSlice<'_, P>, FieldSlice<'_, P>), Error> {
328 if self.log_len == 0 {
329 return Err(Error::CannotSplit);
330 }
331
332 let new_log_len = self.log_len - 1;
333 let (first, second) = if new_log_len < P::LOG_WIDTH {
334 let packed = self.values[0];
337 let zeros = P::default();
338
339 let (first_half, second_half) = packed.interleave(zeros, new_log_len);
340
341 let first = FieldBuffer {
342 log_len: new_log_len,
343 values: FieldSliceData::Single(first_half),
344 };
345 let second = FieldBuffer {
346 log_len: new_log_len,
347 values: FieldSliceData::Single(second_half),
348 };
349
350 (first, second)
351 } else {
352 let half_len = 1 << (new_log_len - P::LOG_WIDTH);
354 let (first_half, second_half) = self.values.split_at(half_len);
355 let second_half = &second_half[..half_len];
356
357 let first = FieldBuffer {
358 log_len: new_log_len,
359 values: FieldSliceData::Slice(first_half),
360 };
361 let second = FieldBuffer {
362 log_len: new_log_len,
363 values: FieldSliceData::Slice(second_half),
364 };
365
366 (first, second)
367 };
368
369 Ok((first, second))
370 }
371}
372
373impl<P: PackedField, Data: DerefMut<Target = [P]>> FieldBuffer<P, Data> {
374 pub fn to_mut(&mut self) -> FieldSliceMut<'_, P> {
376 FieldSliceMut::from_slice(self.log_len, self.as_mut())
377 .expect("log_len matches values.len() by struct invariant")
378 }
379
380 pub fn set(&mut self, index: usize, value: P::Scalar) -> Result<(), Error> {
386 if index >= self.len() {
387 return Err(Error::ArgumentRangeError {
388 arg: "index".to_string(),
389 range: 0..self.len(),
390 });
391 }
392
393 unsafe { set_packed_slice_unchecked(&mut self.values, index, value) };
396 Ok(())
397 }
398
399 pub fn truncate(&mut self, new_log_len: usize) {
403 self.log_len = self.log_len.min(new_log_len);
404 }
405
406 pub fn zero_extend(&mut self, new_log_len: usize) -> Result<(), Error> {
413 if new_log_len <= self.log_len {
414 return Ok(());
415 }
416
417 if new_log_len > self.log_cap() {
418 return Err(Error::IncorrectArgumentLength {
419 arg: "new_log_len".to_string(),
420 expected: self.log_cap(),
421 });
422 }
423
424 if self.log_len < P::LOG_WIDTH {
425 let first_elem = self.values.first_mut().expect("values.len() >= 1");
426 for i in 1 << self.log_len..(1 << new_log_len).min(P::WIDTH) {
427 first_elem.set(i, P::Scalar::ZERO);
428 }
429 }
430
431 let packed_start = 1 << self.log_len.saturating_sub(P::LOG_WIDTH);
432 let packed_end = 1 << new_log_len.saturating_sub(P::LOG_WIDTH);
433 self.values[packed_start..packed_end].fill(P::zero());
434
435 self.log_len = new_log_len;
436 Ok(())
437 }
438
439 pub fn resize(&mut self, new_log_len: usize) -> Result<(), Error> {
447 if new_log_len > self.log_cap() {
448 return Err(Error::IncorrectArgumentLength {
449 arg: "new_log_len".to_string(),
450 expected: self.log_cap(),
451 });
452 }
453
454 self.log_len = new_log_len;
455 Ok(())
456 }
457
458 pub fn chunks_mut(
465 &mut self,
466 log_chunk_size: usize,
467 ) -> Result<impl Iterator<Item = FieldSliceMut<'_, P>>, Error> {
468 if log_chunk_size < P::LOG_WIDTH || log_chunk_size > self.log_len {
469 return Err(Error::ArgumentRangeError {
470 arg: "log_chunk_size".to_string(),
471 range: P::LOG_WIDTH..self.log_len + 1,
472 });
473 }
474
475 let chunk_count = 1 << (self.log_len - log_chunk_size);
476 let packed_chunk_size = 1 << log_chunk_size.saturating_sub(P::LOG_WIDTH);
477 let chunks = self
478 .values
479 .chunks_mut(packed_chunk_size)
480 .take(chunk_count)
481 .map(move |chunk| FieldBuffer {
482 log_len: log_chunk_size,
483 values: FieldSliceDataMut::Slice(chunk),
484 });
485
486 Ok(chunks)
487 }
488
489 pub fn split_half_mut<F, R>(&mut self, f: F) -> Result<R, Error>
499 where
500 F: FnOnce(&mut FieldSliceMut<'_, P>, &mut FieldSliceMut<'_, P>) -> R,
501 {
502 if self.log_len == 0 {
503 return Err(Error::CannotSplit);
504 }
505
506 let new_log_len = self.log_len - 1;
507
508 if new_log_len < P::LOG_WIDTH {
509 let packed = self.values[0];
511 let zeros = P::default();
512 let (mut first_half, mut second_half) = packed.interleave(zeros, new_log_len);
513
514 let mut first = FieldBuffer {
521 log_len: new_log_len,
522 values: FieldSliceDataMut::Slice(slice::from_mut(&mut first_half)),
523 };
524 let mut second = FieldBuffer {
525 log_len: new_log_len,
526 values: FieldSliceDataMut::Slice(slice::from_mut(&mut second_half)),
527 };
528
529 let result = f(&mut first, &mut second);
531
532 (self.values[0], _) = first_half.interleave(second_half, new_log_len);
535
536 Ok(result)
537 } else {
538 let half_len = 1 << (new_log_len - P::LOG_WIDTH);
540 let (first_half, second_half) = self.values.split_at_mut(half_len);
541 let second_half = &mut second_half[..half_len];
542
543 let mut first = FieldBuffer {
544 log_len: new_log_len,
545 values: FieldSliceDataMut::Slice(first_half),
546 };
547 let mut second = FieldBuffer {
548 log_len: new_log_len,
549 values: FieldSliceDataMut::Slice(second_half),
550 };
551
552 Ok(f(&mut first, &mut second))
553 }
554 }
555
556 pub fn split_half_mut_no_closure(&mut self) -> Result<FieldBufferSplitMut<'_, P>, Error> {
568 if self.log_len == 0 {
569 return Err(Error::CannotSplit);
570 }
571
572 let new_log_len = self.log_len - 1;
573 if new_log_len < P::LOG_WIDTH {
574 let packed = self.values[0];
576 let zeros = P::default();
577 let (lo_half, hi_half) = packed.interleave(zeros, new_log_len);
578
579 Ok(FieldBufferSplitMut(FieldBufferSplitMutInner::Singles {
580 log_len: new_log_len,
581 lo_half,
582 hi_half,
583 parent: &mut self.values[0],
584 }))
585 } else {
586 let half_len = 1 << (new_log_len - P::LOG_WIDTH);
588 let (lo_half, hi_half) = self.values.split_at_mut(half_len);
589 let hi_half = &mut hi_half[..half_len];
590
591 Ok(FieldBufferSplitMut(FieldBufferSplitMutInner::Slices {
592 log_len: new_log_len,
593 lo_half,
594 hi_half,
595 }))
596 }
597 }
598}
599
600impl<P: PackedField, Data: Deref<Target = [P]>> AsRef<[P]> for FieldBuffer<P, Data> {
601 #[inline]
602 fn as_ref(&self) -> &[P] {
603 &self.values[..1 << self.log_len.saturating_sub(P::LOG_WIDTH)]
604 }
605}
606
607impl<P: PackedField, Data: DerefMut<Target = [P]>> AsMut<[P]> for FieldBuffer<P, Data> {
608 #[inline]
609 fn as_mut(&mut self) -> &mut [P] {
610 &mut self.values[..1 << self.log_len.saturating_sub(P::LOG_WIDTH)]
611 }
612}
613
/// A [`FieldBuffer`] view backed by borrowed packed storage, or by a single packed element
/// held by value when the view is narrower than one packed element.
pub type FieldSlice<'a, P> = FieldBuffer<P, FieldSliceData<'a, P>>;

/// A mutable [`FieldBuffer`] view backed by mutably borrowed packed storage.
pub type FieldSliceMut<'a, P> = FieldBuffer<P, FieldSliceDataMut<'a, P>>;
619
620impl<'a, P: PackedField> FieldSlice<'a, P> {
621 pub fn from_slice(log_len: usize, slice: &'a [P]) -> Result<Self, Error> {
628 FieldBuffer::new(log_len, FieldSliceData::Slice(slice))
629 }
630}
631
632impl<'a, P: PackedField, Data: Deref<Target = [P]>> From<&'a FieldBuffer<P, Data>>
633 for FieldSlice<'a, P>
634{
635 fn from(buffer: &'a FieldBuffer<P, Data>) -> Self {
636 buffer.to_ref()
637 }
638}
639
640impl<'a, P: PackedField> FieldSliceMut<'a, P> {
641 pub fn from_slice(log_len: usize, slice: &'a mut [P]) -> Result<Self, Error> {
648 FieldBuffer::new(log_len, FieldSliceDataMut::Slice(slice))
649 }
650}
651
652impl<'a, P: PackedField, Data: DerefMut<Target = [P]>> From<&'a mut FieldBuffer<P, Data>>
653 for FieldSliceMut<'a, P>
654{
655 fn from(buffer: &'a mut FieldBuffer<P, Data>) -> Self {
656 buffer.to_mut()
657 }
658}
659
/// Backing storage of a [`FieldSlice`].
#[derive(Debug)]
pub enum FieldSliceData<'a, P> {
    /// A single packed element held by value. `chunk` and `split_half` produce this variant
    /// when the view is narrower than one packed element.
    Single(P),
    /// A borrowed run of packed elements.
    Slice(&'a [P]),
}
665
666impl<'a, P> Deref for FieldSliceData<'a, P> {
667 type Target = [P];
668
669 fn deref(&self) -> &Self::Target {
670 match self {
671 FieldSliceData::Single(val) => slice::from_ref(val),
672 FieldSliceData::Slice(slice) => slice,
673 }
674 }
675}
676
/// Backing storage of a [`FieldSliceMut`].
#[derive(Debug)]
pub enum FieldSliceDataMut<'a, P> {
    /// A single packed element held by value; writes to it stay in this value.
    Single(P),
    /// A mutably borrowed run of packed elements; writes go straight to the borrowed storage.
    Slice(&'a mut [P]),
}
682
683impl<'a, P> Deref for FieldSliceDataMut<'a, P> {
684 type Target = [P];
685
686 fn deref(&self) -> &Self::Target {
687 match self {
688 FieldSliceDataMut::Single(val) => slice::from_ref(val),
689 FieldSliceDataMut::Slice(slice) => slice,
690 }
691 }
692}
693
694impl<'a, P> DerefMut for FieldSliceDataMut<'a, P> {
695 fn deref_mut(&mut self) -> &mut Self::Target {
696 match self {
697 FieldSliceDataMut::Single(val) => slice::from_mut(val),
698 FieldSliceDataMut::Slice(slice) => slice,
699 }
700 }
701}
702
/// Mutable split of a [`FieldBuffer`] into two halves, returned by
/// `FieldBuffer::split_half_mut_no_closure`.
///
/// Call [`FieldBufferSplitMut::halves`] to obtain the two mutable views. When the halves are
/// narrower than one packed element they are held as copies, and those copies are written back
/// into the parent buffer when this value is dropped.
#[derive(Debug)]
pub struct FieldBufferSplitMut<'a, P: PackedField>(FieldBufferSplitMutInner<'a, P>);
706
707impl<'a, P: PackedField> FieldBufferSplitMut<'a, P> {
708 pub fn halves(&mut self) -> (FieldSliceMut<'_, P>, FieldSliceMut<'_, P>) {
709 match &mut self.0 {
710 FieldBufferSplitMutInner::Singles {
711 log_len,
712 lo_half,
713 hi_half,
714 parent: _,
715 } => (
716 FieldBuffer {
717 log_len: *log_len,
718 values: FieldSliceDataMut::Slice(slice::from_mut(lo_half)),
719 },
720 FieldBuffer {
721 log_len: *log_len,
722 values: FieldSliceDataMut::Slice(slice::from_mut(hi_half)),
723 },
724 ),
725 FieldBufferSplitMutInner::Slices {
726 log_len,
727 lo_half,
728 hi_half,
729 } => (
730 FieldBuffer {
731 log_len: *log_len,
732 values: FieldSliceDataMut::Slice(lo_half),
733 },
734 FieldBuffer {
735 log_len: *log_len,
736 values: FieldSliceDataMut::Slice(hi_half),
737 },
738 ),
739 }
740 }
741}
742
/// State behind a [`FieldBufferSplitMut`].
#[derive(Debug)]
enum FieldBufferSplitMutInner<'a, P: PackedField> {
    /// Halves narrower than one packed element: copies of each half, plus the parent packed
    /// element they are interleaved back into when this value is dropped.
    Singles {
        /// Base-2 logarithm of each half's length.
        log_len: usize,
        lo_half: P,
        hi_half: P,
        parent: &'a mut P,
    },
    /// Halves spanning whole packed elements: disjoint mutable borrows of the parent storage.
    Slices {
        /// Base-2 logarithm of each half's length.
        log_len: usize,
        lo_half: &'a mut [P],
        hi_half: &'a mut [P],
    },
}
757
758impl<'a, P: PackedField> Drop for FieldBufferSplitMutInner<'a, P> {
759 fn drop(&mut self) {
760 match self {
761 Self::Singles {
762 log_len,
763 lo_half,
764 hi_half,
765 parent,
766 } => {
767 (**parent, _) = (*lo_half).interleave(*hi_half, *log_len);
770 }
771 Self::Slices { .. } => {}
772 }
773 }
774}
775
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{B128, Packed128b};

    type P = Packed128b;
    type F = B128;

    /// Shorthand for the scalar whose integer representation is `value`.
    fn scalar(value: usize) -> F {
        F::new(value as u128)
    }

    /// Asserts that `chunks` are four chunks of four scalars holding `0..16` in order.
    fn assert_four_sequential_chunks_of_four(chunks: Vec<FieldSlice<'_, P>>) {
        assert_eq!(chunks.len(), 4);
        for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
            assert_eq!(chunk.len(), 4);
            for i in 0..4 {
                assert_eq!(chunk.get(i).unwrap(), scalar(chunk_idx * 4 + i));
            }
        }
    }

    #[test]
    fn test_zeros() {
        // One length above and one below the packing width.
        for (log_len, len) in [(6, 64), (1, 2)] {
            let buffer = FieldBuffer::<P>::zeros(log_len);
            assert_eq!(buffer.log_len(), log_len);
            assert_eq!(buffer.len(), len);
            for i in 0..len {
                assert_eq!(buffer.get(i).unwrap(), F::ZERO);
            }
        }
    }

    #[test]
    fn test_from_values_below_packing_width() {
        let values = [F::new(1), F::new(2)];
        let buffer = FieldBuffer::<P>::from_values(&values).unwrap();

        assert_eq!(buffer.log_len(), 1);
        assert_eq!(buffer.len(), 2);
        for (i, expected) in values.iter().enumerate() {
            assert_eq!(buffer.get(i).unwrap(), *expected);
        }
    }

    #[test]
    fn test_from_values_above_packing_width() {
        let values: Vec<F> = (0..16).map(F::new).collect();
        let buffer = FieldBuffer::<P>::from_values(&values).unwrap();

        assert_eq!(buffer.log_len(), 4);
        assert_eq!(buffer.len(), 16);
        for (i, expected) in values.iter().enumerate() {
            assert_eq!(buffer.get(i).unwrap(), *expected);
        }
    }

    #[test]
    fn test_from_values_non_power_of_two() {
        let seven: Vec<F> = (0..7).map(F::new).collect();
        assert!(matches!(
            FieldBuffer::<P>::from_values(&seven),
            Err(Error::PowerOfTwoLengthRequired)
        ));

        // Zero is not a power of two either.
        let empty: Vec<F> = Vec::new();
        assert!(matches!(
            FieldBuffer::<P>::from_values(&empty),
            Err(Error::PowerOfTwoLengthRequired)
        ));
    }

    #[test]
    fn test_new_below_packing_width() {
        let mut packed = vec![P::default()];
        let mut buffer = FieldBuffer::new(1, packed.as_mut_slice()).unwrap();

        assert_eq!(buffer.log_len(), 1);
        assert_eq!(buffer.len(), 2);

        buffer.set(0, F::new(10)).unwrap();
        buffer.set(1, F::new(20)).unwrap();
        assert_eq!(buffer.get(0).unwrap(), F::new(10));
        assert_eq!(buffer.get(1).unwrap(), F::new(20));
    }

    #[test]
    fn test_new_above_packing_width() {
        let mut packed = vec![P::default(); 4];
        let mut buffer = FieldBuffer::new(4, packed.as_mut_slice()).unwrap();

        assert_eq!(buffer.log_len(), 4);
        assert_eq!(buffer.len(), 16);

        for i in 0..16 {
            buffer.set(i, scalar(i * 10)).unwrap();
        }
        for i in 0..16 {
            assert_eq!(buffer.get(i).unwrap(), scalar(i * 10));
        }
    }

    #[test]
    fn test_new_non_power_of_two() {
        // `new` requires an exact packed length, so both too-short and too-long inputs fail.
        for packed_len in [3, 5] {
            let packed = vec![P::default(); packed_len];
            let result = FieldBuffer::new(4, packed.as_slice());
            assert!(matches!(result, Err(Error::IncorrectArgumentLength { .. })));
        }
    }

    #[test]
    fn test_get_set() {
        let mut buffer = FieldBuffer::<P>::zeros(3);
        for i in 0..8 {
            buffer.set(i, scalar(i)).unwrap();
        }
        for i in 0..8 {
            assert_eq!(buffer.get(i).unwrap(), scalar(i));
        }

        // One past the end is rejected by both accessors.
        assert!(buffer.get(8).is_err());
        assert!(buffer.set(8, F::new(0)).is_err());
    }

    #[test]
    fn test_chunk() {
        let log_len = 8;
        let values: Vec<F> = (0..1 << log_len).map(F::new).collect();
        let buffer = FieldBuffer::<P>::from_values(&values).unwrap();

        assert!(buffer.chunk(log_len + 1, 0).is_err());

        for log_chunk_size in 0..=log_len {
            let chunk_count = 1 << (log_len - log_chunk_size);
            assert!(buffer.chunk(log_chunk_size, chunk_count).is_err());

            for chunk_index in 0..chunk_count {
                let chunk = buffer.chunk(log_chunk_size, chunk_index).unwrap();
                let offset = chunk_index << log_chunk_size;
                for i in 0..1 << log_chunk_size {
                    assert_eq!(chunk.get(i).unwrap(), buffer.get(offset | i).unwrap());
                }
            }
        }
    }

    #[test]
    fn test_chunks() {
        let values: Vec<F> = (0..16).map(F::new).collect();
        let buffer = FieldBuffer::<P>::from_values(&values).unwrap();

        assert_four_sequential_chunks_of_four(buffer.chunks(2).unwrap().collect());

        // Larger than the buffer.
        assert!(buffer.chunks(5).is_err());
        // Below `P::LOG_WIDTH`.
        assert!(buffer.chunks(0).is_err());
        assert!(buffer.chunks(1).is_err());
    }

    #[test]
    fn test_chunks_par() {
        let values: Vec<F> = (0..16).map(F::new).collect();
        let buffer = FieldBuffer::<P>::from_values(&values).unwrap();

        assert_four_sequential_chunks_of_four(buffer.chunks_par(2).unwrap().collect());

        // Larger than the buffer.
        assert!(buffer.chunks_par(5).is_err());
        // Below `P::LOG_WIDTH`.
        assert!(buffer.chunks_par(0).is_err());
        assert!(buffer.chunks_par(1).is_err());
    }

    #[test]
    fn test_chunks_mut() {
        let mut buffer = FieldBuffer::<P>::zeros(4);
        {
            let mut chunks: Vec<_> = buffer.chunks_mut(2).unwrap().collect();
            assert_eq!(chunks.len(), 4);

            for (chunk_idx, chunk) in chunks.iter_mut().enumerate() {
                for i in 0..chunk.len() {
                    chunk.set(i, scalar(chunk_idx * 10 + i)).unwrap();
                }
            }
        }

        // Writes through the chunks must land in the parent buffer.
        for chunk_idx in 0..4 {
            for i in 0..4 {
                assert_eq!(
                    buffer.get(chunk_idx * 4 + i).unwrap(),
                    scalar(chunk_idx * 10 + i)
                );
            }
        }

        assert!(buffer.chunks_mut(0).is_err());
        assert!(buffer.chunks_mut(1).is_err());
    }

    #[test]
    fn test_to_ref_to_mut() {
        let mut buffer = FieldBuffer::<P>::zeros_truncated(3, 5).unwrap();
        // Views expose only the packed elements covering the logical length, not the capacity.
        let expected_packed_len =
            |log_len: usize| -> usize { 1 << log_len.saturating_sub(P::LOG_WIDTH) };

        let slice_ref = buffer.to_ref();
        assert_eq!(slice_ref.len(), buffer.len());
        assert_eq!(slice_ref.log_len(), buffer.log_len());
        assert_eq!(slice_ref.as_ref().len(), expected_packed_len(slice_ref.log_len()));

        let mut slice_mut = buffer.to_mut();
        slice_mut.set(0, F::new(123)).unwrap();
        assert_eq!(slice_mut.as_mut().len(), expected_packed_len(slice_mut.log_len()));
        assert_eq!(buffer.get(0).unwrap(), F::new(123));
    }

    #[test]
    fn test_split_half() {
        // Halves spanning whole packed elements, in a buffer with spare capacity.
        let values: Vec<F> = (0..16).map(F::new).collect();
        let buffer = FieldBuffer::<P>::from_values_truncated(&values, 5).unwrap();
        let (first, second) = buffer.split_half().unwrap();
        assert_eq!(first.len(), 8);
        assert_eq!(second.len(), 8);
        for i in 0..8 {
            assert_eq!(first.get(i).unwrap(), scalar(i));
            assert_eq!(second.get(i).unwrap(), scalar(i + 8));
        }

        // Halves narrower than one packed element are materialized as `Single` values.
        let values: Vec<F> = (0..4).map(F::new).collect();
        let buffer = FieldBuffer::<P>::from_values_truncated(&values, 3).unwrap();
        let (first, second) = buffer.split_half().unwrap();
        assert_eq!(first.len(), 2);
        assert_eq!(second.len(), 2);
        assert!(
            matches!(first.values, FieldSliceData::Single(_)),
            "Expected Single variant for first half"
        );
        assert!(
            matches!(second.values, FieldSliceData::Single(_)),
            "Expected Single variant for second half"
        );
        assert_eq!(first.get(0).unwrap(), F::new(0));
        assert_eq!(first.get(1).unwrap(), F::new(1));
        assert_eq!(second.get(0).unwrap(), F::new(2));
        assert_eq!(second.get(1).unwrap(), F::new(3));

        // Single-scalar halves.
        let values = vec![F::new(10), F::new(20)];
        let buffer = FieldBuffer::<P>::from_values_truncated(&values, 3).unwrap();
        let (first, second) = buffer.split_half().unwrap();
        assert_eq!(first.len(), 1);
        assert_eq!(second.len(), 1);
        assert!(
            matches!(first.values, FieldSliceData::Single(_)),
            "Expected Single variant for first half"
        );
        assert!(
            matches!(second.values, FieldSliceData::Single(_)),
            "Expected Single variant for second half"
        );
        assert_eq!(first.get(0).unwrap(), F::new(10));
        assert_eq!(second.get(0).unwrap(), F::new(20));

        // A single scalar cannot be split.
        let values = vec![F::new(42)];
        let buffer = FieldBuffer::<P>::from_values(&values).unwrap();
        assert!(matches!(buffer.split_half(), Err(Error::CannotSplit)));
    }

    #[test]
    fn test_split_half_mut() {
        // Halves spanning whole packed elements.
        let mut buffer = FieldBuffer::<P>::zeros_truncated(4, 5).unwrap();
        for i in 0..16 {
            buffer.set(i, scalar(i)).unwrap();
        }
        buffer
            .split_half_mut(|first, second| {
                assert_eq!(first.len(), 8);
                assert_eq!(second.len(), 8);
                for i in 0..8 {
                    first.set(i, scalar(i * 10)).unwrap();
                    second.set(i, scalar(i * 20)).unwrap();
                }
            })
            .unwrap();
        for i in 0..8 {
            assert_eq!(buffer.get(i).unwrap(), scalar(i * 10));
            assert_eq!(buffer.get(i + 8).unwrap(), scalar(i * 20));
        }

        // Halves narrower than one packed element are written back after the closure returns.
        let mut buffer = FieldBuffer::<P>::zeros_truncated(2, 4).unwrap();
        for i in 0..4 {
            buffer.set(i, scalar(i)).unwrap();
        }
        buffer
            .split_half_mut(|first, second| {
                assert_eq!(first.len(), 2);
                assert_eq!(second.len(), 2);
                first.set(0, F::new(100)).unwrap();
                first.set(1, F::new(101)).unwrap();
                second.set(0, F::new(200)).unwrap();
                second.set(1, F::new(201)).unwrap();
            })
            .unwrap();
        for (i, value) in [100, 101, 200, 201].iter().enumerate() {
            assert_eq!(buffer.get(i).unwrap(), F::new(*value));
        }

        // Single-scalar halves.
        let mut buffer = FieldBuffer::<P>::zeros_truncated(1, 4).unwrap();
        buffer.set(0, F::new(10)).unwrap();
        buffer.set(1, F::new(20)).unwrap();
        buffer
            .split_half_mut(|first, second| {
                assert_eq!(first.len(), 1);
                assert_eq!(second.len(), 1);
                first.set(0, F::new(30)).unwrap();
                second.set(0, F::new(40)).unwrap();
            })
            .unwrap();
        assert_eq!(buffer.get(0).unwrap(), F::new(30));
        assert_eq!(buffer.get(1).unwrap(), F::new(40));

        // A single scalar cannot be split.
        let mut buffer = FieldBuffer::<P>::zeros(0);
        let result = buffer.split_half_mut(|_, _| {});
        assert!(matches!(result, Err(Error::CannotSplit)));
    }

    #[test]
    fn test_zero_extend() {
        let log_len = 10;
        let nonzero_scalars: Vec<F> = (0..1 << log_len).map(|i| F::new(i + 1)).collect();
        let mut buffer = FieldBuffer::<P>::from_values(&nonzero_scalars).unwrap();
        buffer.truncate(0);

        // The storage still holds non-zero values, so each growth step must clear the
        // newly exposed range explicitly.
        for new_log_len in 1..=log_len {
            buffer.zero_extend(new_log_len).unwrap();
            for j in 1 << (new_log_len - 1)..1 << new_log_len {
                assert!(buffer.get(j).unwrap().is_zero());
            }
        }
    }

    #[test]
    fn test_resize() {
        let mut buffer = FieldBuffer::<P>::zeros(4);
        for i in 0..16 {
            buffer.set(i, scalar(i)).unwrap();
        }

        buffer.resize(3).unwrap();
        assert_eq!(buffer.log_len(), 3);
        assert_eq!(buffer.get(7).unwrap(), F::new(7));

        // Growing back re-exposes the untouched storage.
        buffer.resize(4).unwrap();
        assert_eq!(buffer.log_len(), 4);
        assert_eq!(buffer.get(15).unwrap(), F::new(15));

        assert!(matches!(
            buffer.resize(5),
            Err(Error::IncorrectArgumentLength { arg, expected })
                if arg == "new_log_len" && expected == 4
        ));

        buffer.resize(2).unwrap();
        assert_eq!(buffer.log_len(), 2);
    }

    #[test]
    fn test_split_half_mut_no_closure() {
        // Halves spanning whole packed elements.
        let mut buffer = FieldBuffer::<P>::zeros(4);
        for i in 0..16 {
            buffer.set(i, scalar(i)).unwrap();
        }
        {
            let mut split = buffer.split_half_mut_no_closure().unwrap();
            let (mut first, mut second) = split.halves();
            assert_eq!(first.len(), 8);
            assert_eq!(second.len(), 8);
            for i in 0..8 {
                first.set(i, scalar(i * 10)).unwrap();
                second.set(i, scalar(i * 20)).unwrap();
            }
        }
        for i in 0..8 {
            assert_eq!(buffer.get(i).unwrap(), scalar(i * 10));
            assert_eq!(buffer.get(i + 8).unwrap(), scalar(i * 20));
        }

        // Halves narrower than one packed element are written back when the guard drops.
        let mut buffer = FieldBuffer::<P>::zeros(2);
        for i in 0..4 {
            buffer.set(i, scalar(i)).unwrap();
        }
        {
            let mut split = buffer.split_half_mut_no_closure().unwrap();
            let (mut first, mut second) = split.halves();
            assert_eq!(first.len(), 2);
            assert_eq!(second.len(), 2);
            first.set(0, F::new(100)).unwrap();
            first.set(1, F::new(101)).unwrap();
            second.set(0, F::new(200)).unwrap();
            second.set(1, F::new(201)).unwrap();
        }
        for (i, value) in [100, 101, 200, 201].iter().enumerate() {
            assert_eq!(buffer.get(i).unwrap(), F::new(*value));
        }

        // Single-scalar halves.
        let mut buffer = FieldBuffer::<P>::zeros(1);
        buffer.set(0, F::new(10)).unwrap();
        buffer.set(1, F::new(20)).unwrap();
        {
            let mut split = buffer.split_half_mut_no_closure().unwrap();
            let (mut first, mut second) = split.halves();
            assert_eq!(first.len(), 1);
            assert_eq!(second.len(), 1);
            first.set(0, F::new(30)).unwrap();
            second.set(0, F::new(40)).unwrap();
        }
        assert_eq!(buffer.get(0).unwrap(), F::new(30));
        assert_eq!(buffer.get(1).unwrap(), F::new(40));

        // A single scalar cannot be split.
        let mut buffer = FieldBuffer::<P>::zeros(0);
        let result = buffer.split_half_mut_no_closure();
        assert!(matches!(result, Err(Error::CannotSplit)));
    }
}
1343}