1use std::{
4 cell::{Ref, RefCell, RefMut},
5 iter,
6 ops::{Deref, DerefMut},
7 slice,
8 sync::Arc,
9};
10
11use binius_compute::alloc::{ComputeAllocator, HostBumpAllocator};
12use binius_core::witness::{MultilinearExtensionIndex, MultilinearWitness};
13use binius_fast_compute::arith_circuit::ArithCircuitPoly;
14use binius_field::{
15 ExtensionField, PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
16 arch::OptimalUnderlier,
17 as_packed_field::PackedType,
18 packed::{get_packed_slice, set_packed_slice},
19};
20use binius_math::{
21 ArithCircuit, CompositionPoly, MultilinearExtension, MultilinearPoly, RowsBatchRef,
22};
23use binius_maybe_rayon::prelude::*;
24use binius_utils::checked_arithmetics::checked_log_2;
25use bytemuck::{Pod, must_cast_slice, must_cast_slice_mut, zeroed_vec};
26use either::Either;
27use getset::CopyGetters;
28use itertools::Itertools;
29
30use super::{
31 ColumnDef, ColumnId, ColumnInfo, ConstraintSystem, Expr,
32 column::{Col, ColumnShape},
33 constraint_system::OracleMapping,
34 error::Error,
35 table::{self, Table, TableId},
36 types::{B1, B8, B16, B32, B64, B128},
37};
38use crate::builder::multi_iter::MultiIterator;
39
/// Witness data for all tables of a [`ConstraintSystem`].
///
/// There is one entry per table of `cs`, indexed by [`TableId`]. An entry starts out holding only
/// the table definition and is replaced by a [`TableWitnessIndex`] once the table is initialized
/// or filled.
pub struct WitnessIndex<'cs, 'alloc, P = PackedType<OptimalUnderlier, B128>>
where
    P: PackedField,
    P::Scalar: TowerField,
{
    /// Constraint system whose tables this index holds witnesses for.
    cs: &'cs ConstraintSystem<P::Scalar>,
    /// Allocator from which the column buffers of newly initialized tables are taken.
    allocator: &'alloc HostBumpAllocator<'alloc, P>,
    /// One entry per table, in `cs.tables` order: `Left` while the table is uninitialized,
    /// `Right` once its witness has been created.
    tables: Vec<Either<&'cs Table<P::Scalar>, TableWitnessIndex<'cs, 'alloc, P>>>,
}
57
58impl<'cs, 'alloc, F: TowerField, P: PackedField<Scalar = F>> WitnessIndex<'cs, 'alloc, P> {
59 pub fn new(
61 cs: &'cs ConstraintSystem<F>,
62 allocator: &'alloc HostBumpAllocator<'alloc, P>,
63 ) -> Self {
64 Self {
65 cs,
66 allocator,
67 tables: cs.tables.iter().map(Either::Left).collect(),
68 }
69 }
70
71 pub fn init_table(
72 &mut self,
73 table_id: TableId,
74 size: usize,
75 ) -> Result<&mut TableWitnessIndex<'cs, 'alloc, P>, Error> {
76 match self.tables.get_mut(table_id) {
77 Some(entry) => match entry {
78 Either::Left(table) => {
79 if size == 0 {
80 Err(Error::EmptyTable { table_id })
81 } else {
82 let table_witness = TableWitnessIndex::new(self.allocator, table, size)?;
83 *entry = Either::Right(table_witness);
84 let Either::Right(table_witness) = entry else {
85 unreachable!("entry is assigned to this pattern on the previous line")
86 };
87 Ok(table_witness)
88 }
89 }
90 Either::Right(_) => Err(Error::TableIndexAlreadyInitialized { table_id }),
91 },
92 None => Err(Error::MissingTable { table_id }),
93 }
94 }
95
96 pub fn get_table(
97 &mut self,
98 table_id: TableId,
99 ) -> Option<&mut TableWitnessIndex<'cs, 'alloc, P>> {
100 self.tables.get_mut(table_id).and_then(|table| match table {
101 Either::Left(_) => None,
102 Either::Right(index) => Some(index),
103 })
104 }
105
106 pub fn fill_table_sequential<T: TableFiller<P>>(
107 &mut self,
108 filler: &T,
109 rows: &[T::Event],
110 ) -> Result<(), Error> {
111 self.init_and_fill_table(
112 filler.id(),
113 |table_witness, rows| table_witness.fill_sequential(filler, rows),
114 rows,
115 )
116 }
117
118 pub fn fill_table_parallel<T>(&mut self, filler: &T, rows: &[T::Event]) -> Result<(), Error>
119 where
120 T: TableFiller<P> + Sync,
121 T::Event: Sync,
122 {
123 self.init_and_fill_table(
124 filler.id(),
125 |table_witness, rows| table_witness.fill_parallel(filler, rows),
126 rows,
127 )
128 }
129
130 fn init_and_fill_table<Event>(
131 &mut self,
132 table_id: TableId,
133 fill: impl FnOnce(&mut TableWitnessIndex<'cs, 'alloc, P>, &[Event]) -> Result<(), Error>,
134 rows: &[Event],
135 ) -> Result<(), Error> {
136 match self.tables.get_mut(table_id) {
137 Some(entry) => match entry {
138 Either::Right(witness) => fill(witness, rows),
139 Either::Left(table) => {
140 if rows.is_empty() {
141 Ok(())
142 } else {
143 let mut table_witness =
144 TableWitnessIndex::new(self.allocator, table, rows.len())?;
145 fill(&mut table_witness, rows)?;
146 *entry = Either::Right(table_witness);
147 Ok(())
148 }
149 }
150 },
151 None => Err(Error::MissingTable { table_id }),
152 }
153 }
154
155 pub fn table_sizes(&self) -> Vec<usize> {
157 self.tables
158 .iter()
159 .map(|entry| match entry {
160 Either::Left(_) => 0,
161 Either::Right(index) => index.size(),
162 })
163 .collect()
164 }
165
166 fn mk_column_witness<'a>(
167 log_capacity: usize,
168 shape: ColumnShape,
169 data: &'a [P],
170 ) -> MultilinearWitness<'a, P>
171 where
172 P: PackedExtension<B1>
173 + PackedExtension<B8>
174 + PackedExtension<B16>
175 + PackedExtension<B32>
176 + PackedExtension<B64>
177 + PackedExtension<B128>,
178 {
179 let n_vars = log_capacity + shape.log_values_per_row;
180 let underlier_count =
181 1 << (n_vars + shape.tower_height).saturating_sub(P::LOG_WIDTH + F::TOWER_LEVEL);
182 multilin_poly_from_underlier_data(&data[..underlier_count], n_vars, shape.tower_height)
183 }
184
185 pub fn into_multilinear_extension_index(self) -> MultilinearExtensionIndex<'alloc, P>
189 where
190 P: PackedExtension<B1>
191 + PackedExtension<B8>
192 + PackedExtension<B16>
193 + PackedExtension<B32>
194 + PackedExtension<B64>
195 + PackedExtension<B128>,
196 {
197 let oracle_lookup = self.cs.oracle_lookup();
198
199 let mut index = MultilinearExtensionIndex::new();
200
201 for table_witness in self.tables {
202 let Either::Right(table_witness) = table_witness else {
203 continue;
204 };
205 let cols = immutable_witness_index_columns(table_witness.cols);
206
207 for col in cols.into_iter() {
222 let oracle_mapping = *oracle_lookup.lookup(col.column_id);
223 match oracle_mapping {
224 OracleMapping::Regular(oracle_id) => index
225 .update_multilin_poly([(
226 oracle_id,
227 Self::mk_column_witness(
228 table_witness.log_capacity,
229 col.shape,
230 col.data,
231 ),
232 )])
233 .unwrap(),
234 OracleMapping::TransparentCompound {
235 original,
236 repeating,
237 } => {
238 let original_witness = Self::mk_column_witness(0, col.shape, col.data);
241 let repeating_witness = Self::mk_column_witness(
242 table_witness.log_capacity,
243 col.shape,
244 col.data,
245 );
246 index
247 .update_multilin_poly([
248 (original, original_witness),
249 (repeating, repeating_witness),
250 ])
251 .unwrap();
252 }
253 }
254 }
255 }
256 index
257 }
258}
259
260impl<'cs, 'alloc, P> WitnessIndex<'cs, 'alloc, P>
261where
262 P: PackedField<Scalar: TowerField>
263 + PackedExtension<B1>
264 + PackedExtension<B8>
265 + PackedExtension<B16>
266 + PackedExtension<B32>
267 + PackedExtension<B64>
268 + PackedExtension<B128>,
269{
270 pub fn fill_constant_cols(&mut self) -> Result<(), Error> {
273 for table in self.tables.iter_mut() {
274 match table.as_mut() {
275 Either::Left(_) => (),
276 Either::Right(table_witness_index) => {
278 let table = table_witness_index.table();
279 let segment = table_witness_index.full_segment();
280 for col in table.columns.iter() {
281 if let ColumnDef::Constant { data, .. } = &col.col {
282 let mut witness_data = segment.get_dyn_mut(col.id)?;
283 let len = witness_data.size();
284 for (i, scalar) in data.iter().cycle().take(len).enumerate() {
285 witness_data.set(i, *scalar)?
286 }
287 }
288 }
289 }
290 }
291 }
292 Ok(())
293 }
294}
295
296fn multilin_poly_from_underlier_data<P>(
297 data: &[P],
298 n_vars: usize,
299 tower_height: usize,
300) -> Arc<dyn MultilinearPoly<P> + Send + Sync + '_>
301where
302 P: PackedExtension<B1>
303 + PackedExtension<B8>
304 + PackedExtension<B16>
305 + PackedExtension<B32>
306 + PackedExtension<B64>
307 + PackedExtension<B128>,
308{
309 match tower_height {
310 0 => MultilinearExtension::new(n_vars, PackedExtension::<B1>::cast_bases(data))
311 .unwrap()
312 .specialize_arc_dyn(),
313 3 => MultilinearExtension::new(n_vars, PackedExtension::<B8>::cast_bases(data))
314 .unwrap()
315 .specialize_arc_dyn(),
316 4 => MultilinearExtension::new(n_vars, PackedExtension::<B16>::cast_bases(data))
317 .unwrap()
318 .specialize_arc_dyn(),
319 5 => MultilinearExtension::new(n_vars, PackedExtension::<B32>::cast_bases(data))
320 .unwrap()
321 .specialize_arc_dyn(),
322 6 => MultilinearExtension::new(n_vars, PackedExtension::<B64>::cast_bases(data))
323 .unwrap()
324 .specialize_arc_dyn(),
325 7 => MultilinearExtension::new(n_vars, PackedExtension::<B128>::cast_bases(data))
326 .unwrap()
327 .specialize_arc_dyn(),
328 _ => {
329 panic!("Unsupported tower height: {tower_height}");
330 }
331 }
332}
333
/// Witness data for a single table.
///
/// Holds the storage for every column, sized for a padded capacity of `1 << log_capacity` rows.
/// The table was created for `size` rows of events.
#[derive(Debug, CopyGetters)]
pub struct TableWitnessIndex<'cs, 'alloc, P = PackedType<OptimalUnderlier, B128>>
where
    P: PackedField,
    P::Scalar: TowerField,
{
    /// Definition of the table this witness belongs to.
    #[get_copy = "pub"]
    table: &'cs Table<P::Scalar>,
    /// Per-column storage, in the order of `table.columns`.
    cols: Vec<WitnessIndexColumn<'alloc, P>>,
    /// Number of rows (events) the table was created with.
    #[get_copy = "pub"]
    size: usize,
    /// Base-2 logarithm of the padded row capacity.
    #[get_copy = "pub"]
    log_capacity: usize,
    /// Base-2 logarithm of the smallest segment size (in rows) for which every column's
    /// segment spans at least one whole packed element; never larger than `log_capacity`.
    #[get_copy = "pub"]
    min_log_segment_size: usize,
}
356
/// Mutable storage for one column of a [`TableWitnessIndex`].
#[derive(Debug)]
pub struct WitnessIndexColumn<'a, P: PackedField> {
    /// Shape of the column (tower height, values per row).
    shape: ColumnShape,
    /// The column's own buffer, or the index of the column whose buffer it shares.
    data: WitnessDataMut<'a, P>,
    /// Column identifier, used to look up the column's oracle mapping.
    column_id: ColumnId,
}
363
/// Storage descriptor for a column.
///
/// A column either owns its data, or refers by index (within the table's column list) to
/// another column whose data it shares.
#[derive(Debug, Clone)]
enum WitnessColumnInfo<T> {
    /// The column owns its data.
    Owned(T),
    /// The column shares the data of the column at this index; used for `ColumnDef::Packed`
    /// columns, which alias their source column.
    SameAsIndex(usize),
}
370
/// Column storage that exclusively borrows its packed data.
type WitnessDataMut<'a, P> = WitnessColumnInfo<&'a mut [P]>;
372
373impl<'a, P: PackedField> WitnessDataMut<'a, P> {
374 pub fn new_owned(allocator: &'a HostBumpAllocator<'a, P>, log_underlier_count: usize) -> Self {
375 let slice = allocator
376 .alloc(1 << log_underlier_count)
377 .expect("failed to allocate witness data slice");
378
379 Self::Owned(slice)
380 }
381}
382
/// Column storage of a segment. Owned data sits behind a `RefCell`, so that different columns
/// can be borrowed (mutably or not) through a shared segment reference.
type RefCellData<'a, P> = WitnessColumnInfo<RefCell<&'a mut [P]>>;
384
/// Read-only view of one column. Shared (`SameAsIndex`) columns are resolved to the data of
/// the column they alias.
#[derive(Debug)]
struct ImmutableWitnessIndexColumn<'a, P: PackedField> {
    /// Shape of the column.
    shape: ColumnShape,
    /// The column's packed data.
    data: &'a [P],
    /// Column identifier.
    column_id: ColumnId,
}
391
392fn immutable_witness_index_columns<P: PackedField>(
395 cols: Vec<WitnessIndexColumn<P>>,
396) -> Vec<ImmutableWitnessIndexColumn<P>> {
397 let mut result = Vec::<ImmutableWitnessIndexColumn<_>>::with_capacity(cols.len());
398 for col in cols {
399 result.push(ImmutableWitnessIndexColumn {
400 shape: col.shape,
401 data: match col.data {
402 WitnessDataMut::Owned(data) => data,
403 WitnessDataMut::SameAsIndex(index) => result[index].data,
404 },
405 column_id: col.column_id,
406 });
407 }
408 result
409}
410
411impl<'cs, 'alloc, F: TowerField, P: PackedField<Scalar = F>> TableWitnessIndex<'cs, 'alloc, P> {
412 pub(crate) fn new(
413 allocator: &'alloc HostBumpAllocator<'alloc, P>,
414 table: &'cs Table<F>,
415 size: usize,
416 ) -> Result<Self, Error> {
417 if size == 0 {
418 return Err(Error::EmptyTable {
419 table_id: table.id(),
420 });
421 }
422
423 let log_capacity = table::log_capacity(size);
424 let packed_elem_log_bits = P::LOG_WIDTH + F::TOWER_LEVEL;
425
426 let mut cols = Vec::with_capacity(table.columns.len());
427 for ColumnInfo { id, col, shape, .. } in &table.columns {
428 let data: WitnessDataMut<P> = if let ColumnDef::Packed { col: source, .. } = col {
429 WitnessDataMut::SameAsIndex(source.table_index.0)
431 } else {
432 WitnessDataMut::new_owned(
434 allocator,
435 (shape.log_cell_size() + log_capacity).saturating_sub(packed_elem_log_bits),
436 )
437 };
438 cols.push(WitnessIndexColumn {
439 shape: *shape,
440 data,
441 column_id: *id,
442 });
443 }
444
445 let min_log_segment_size = packed_elem_log_bits
448 - table
449 .columns
450 .iter()
451 .map(|col| col.shape.log_cell_size())
452 .fold(packed_elem_log_bits, |a, b| a.min(b));
453
454 let min_log_segment_size = min_log_segment_size.min(log_capacity);
458
459 Ok(Self {
460 table,
461 cols,
462 size,
463 log_capacity,
464 min_log_segment_size,
465 })
466 }
467
468 pub fn table_id(&self) -> TableId {
469 self.table.id
470 }
471
472 pub fn capacity(&self) -> usize {
473 1 << self.log_capacity
474 }
475
476 pub fn full_segment(&mut self) -> TableWitnessSegment<P> {
478 let cols = self
479 .cols
480 .iter_mut()
481 .map(|col| match &mut col.data {
482 WitnessDataMut::SameAsIndex(id) => RefCellData::SameAsIndex(*id),
483 WitnessDataMut::Owned(data) => RefCellData::Owned(RefCell::new(data)),
484 })
485 .collect();
486 TableWitnessSegment {
487 table: self.table,
488 cols,
489 log_size: self.log_capacity,
490 index: 0,
491 }
492 }
493
494 pub fn fill_sequential<T: TableFiller<P>>(
498 &mut self,
499 table: &T,
500 rows: &[T::Event],
501 ) -> Result<(), Error> {
502 let log_size = self.optimal_segment_size_heuristic();
503 self.fill_sequential_with_segment_size(table, rows, log_size)
504 }
505
506 pub fn fill_parallel<T>(&mut self, table: &T, rows: &[T::Event]) -> Result<(), Error>
510 where
511 T: TableFiller<P> + Sync,
512 T::Event: Sync,
513 {
514 let log_size = self.optimal_segment_size_heuristic();
515 self.fill_parallel_with_segment_size(table, rows, log_size)
516 }
517
518 fn optimal_segment_size_heuristic(&self) -> usize {
519 const TARGET_SEGMENT_LOG_BITS: usize = 12 + 3;
521
522 let n_cols = self.table.columns.len();
523 let median_col_log_bits = self
524 .table
525 .columns
526 .iter()
527 .map(|col| col.shape.log_cell_size())
528 .sorted()
529 .nth(n_cols / 2)
530 .unwrap_or_default();
531
532 TARGET_SEGMENT_LOG_BITS.saturating_sub(median_col_log_bits)
533 }
534
535 pub fn fill_sequential_with_segment_size<T: TableFiller<P>>(
539 &mut self,
540 table: &T,
541 rows: &[T::Event],
542 log_size: usize,
543 ) -> Result<(), Error> {
544 if rows.len() != self.size {
545 return Err(Error::IncorrectNumberOfTableEvents {
546 expected: self.size,
547 actual: rows.len(),
548 });
549 }
550
551 let mut segmented_view = TableWitnessSegmentedView::new(self, log_size);
552
553 let log_size = segmented_view.log_segment_size;
555 let segment_size = 1 << log_size;
556
557 debug_assert_ne!(rows.len(), 0);
559 let n_chunks = (rows.len() - 1) / segment_size + 1;
561
562 let (full_chunk_segments, mut rest_segments) = segmented_view.split_at(n_chunks - 1);
563
564 full_chunk_segments
566 .into_iter()
567 .zip(rows.chunks(segment_size).take(n_chunks - 1))
569 .try_for_each(|(mut witness_segment, row_chunk)| {
570 table
571 .fill(row_chunk, &mut witness_segment)
572 .map_err(Error::TableFill)
573 })?;
574
575 let row_chunk = &rows[(n_chunks - 1) * segment_size..];
580 let mut padded_row_chunk = Vec::new();
581 let row_chunk = if row_chunk.len() != segment_size {
582 padded_row_chunk.reserve(segment_size);
583 padded_row_chunk.extend_from_slice(row_chunk);
584 let last_event = row_chunk
585 .last()
586 .expect("row_chunk must be non-empty because of how n_chunk is calculated")
587 .clone();
588 padded_row_chunk.resize(segment_size, last_event);
589 &padded_row_chunk
590 } else {
591 row_chunk
592 };
593
594 let (partial_chunk_segments, rest_segments) = rest_segments.split_at(1);
595 let mut partial_chunk_segment_iter = partial_chunk_segments.into_iter();
596 let mut witness_segment = partial_chunk_segment_iter.next().expect(
597 "segmented_view.split_at called with 1 must return a view with exactly one segment",
598 );
599 table
600 .fill(row_chunk, &mut witness_segment)
601 .map_err(Error::TableFill)?;
602 assert!(partial_chunk_segment_iter.next().is_none());
603
604 let last_segment_cols = witness_segment
607 .cols
608 .iter_mut()
609 .map(|col| match col {
610 RefCellData::Owned(data) => WitnessColumnInfo::Owned(data.get_mut()),
611 RefCellData::SameAsIndex(idx) => WitnessColumnInfo::SameAsIndex(*idx),
612 })
613 .collect::<Vec<_>>();
614
615 rest_segments.into_iter().for_each(|mut segment| {
616 for (dst_col, src_col) in iter::zip(&mut segment.cols, &last_segment_cols) {
617 if let (RefCellData::Owned(dst), WitnessColumnInfo::Owned(src)) = (dst_col, src_col)
618 {
619 dst.get_mut().copy_from_slice(src)
620 }
621 }
622 });
623
624 Ok(())
625 }
626
627 pub fn fill_parallel_with_segment_size<T>(
631 &mut self,
632 table: &T,
633 rows: &[T::Event],
634 log_size: usize,
635 ) -> Result<(), Error>
636 where
637 T: TableFiller<P> + Sync,
638 T::Event: Sync,
639 {
640 if rows.len() != self.size {
641 return Err(Error::IncorrectNumberOfTableEvents {
642 expected: self.size,
643 actual: rows.len(),
644 });
645 }
646
647 let mut segmented_view = TableWitnessSegmentedView::new(self, log_size);
652
653 let log_size = segmented_view.log_segment_size;
655 let segment_size = 1 << log_size;
656
657 debug_assert_ne!(rows.len(), 0);
659 let n_chunks = (rows.len() - 1) / segment_size + 1;
661
662 let (full_chunk_segments, mut rest_segments) = segmented_view.split_at(n_chunks - 1);
663
664 full_chunk_segments
666 .into_par_iter()
667 .zip(rows.par_chunks(segment_size).take(n_chunks - 1))
669 .try_for_each(|(mut witness_segment, row_chunk)| {
670 table
671 .fill(row_chunk, &mut witness_segment)
672 .map_err(Error::TableFill)
673 })?;
674
675 let row_chunk = &rows[(n_chunks - 1) * segment_size..];
680 let mut padded_row_chunk = Vec::new();
681 let row_chunk = if row_chunk.len() != segment_size {
682 padded_row_chunk.reserve(segment_size);
683 padded_row_chunk.extend_from_slice(row_chunk);
684 let last_event = row_chunk
685 .last()
686 .expect("row_chunk must be non-empty because of how n_chunk is calculated")
687 .clone();
688 padded_row_chunk.resize(segment_size, last_event);
689 &padded_row_chunk
690 } else {
691 row_chunk
692 };
693
694 let (partial_chunk_segments, rest_segments) = rest_segments.split_at(1);
695 let mut partial_chunk_segment_iter = partial_chunk_segments.into_iter();
696 let mut witness_segment = partial_chunk_segment_iter.next().expect(
697 "segmented_view.split_at called with 1 must return a view with exactly one segment",
698 );
699 table
700 .fill(row_chunk, &mut witness_segment)
701 .map_err(Error::TableFill)?;
702 assert!(partial_chunk_segment_iter.next().is_none());
703
704 let last_segment_cols = witness_segment
707 .cols
708 .iter_mut()
709 .map(|col| match col {
710 RefCellData::Owned(data) => WitnessColumnInfo::Owned(data.get_mut()),
711 RefCellData::SameAsIndex(idx) => WitnessColumnInfo::SameAsIndex(*idx),
712 })
713 .collect::<Vec<_>>();
714
715 rest_segments.into_par_iter().for_each(|mut segment| {
716 for (dst_col, src_col) in iter::zip(&mut segment.cols, &last_segment_cols) {
717 if let (RefCellData::Owned(dst), WitnessColumnInfo::Owned(src)) = (dst_col, src_col)
718 {
719 dst.get_mut().copy_from_slice(src)
720 }
721 }
722 });
723
724 Ok(())
725 }
726
727 pub fn segments(&mut self, log_size: usize) -> impl Iterator<Item = TableWitnessSegment<P>> {
733 TableWitnessSegmentedView::new(self, log_size).into_iter()
734 }
735
736 pub fn par_segments(
737 &mut self,
738 log_size: usize,
739 ) -> impl IndexedParallelIterator<Item = TableWitnessSegment<'_, P>> {
740 TableWitnessSegmentedView::new(self, log_size).into_par_iter()
741 }
742}
743
/// A table witness (or a contiguous range of it) split into equally sized segments.
///
/// Views are produced from a [`TableWitnessIndex`] and can be further split with
/// `split_at` before being turned into sequential or parallel segment iterators.
#[derive(Debug)]
struct TableWitnessSegmentedView<'a, P = PackedType<OptimalUnderlier, B128>>
where
    P: PackedField,
    P::Scalar: TowerField,
{
    /// Definition of the underlying table.
    table: &'a Table<P::Scalar>,
    /// Per column: the remaining owned buffer together with the number of packed elements
    /// per segment, or the index of the column whose data it shares.
    cols: Vec<WitnessColumnInfo<(&'a mut [P], usize)>>,
    /// Base-2 logarithm of the number of rows per segment.
    log_segment_size: usize,
    /// Index, within the whole table, of the first segment in this view.
    start_index: usize,
    /// Number of segments in this view.
    n_segments: usize,
}
761
762impl<'a, F: TowerField, P: PackedField<Scalar = F>> TableWitnessSegmentedView<'a, P> {
763 fn new(witness: &'a mut TableWitnessIndex<P>, log_segment_size: usize) -> Self {
764 let log_segment_size = log_segment_size
766 .min(witness.log_capacity)
767 .max(witness.min_log_segment_size);
768
769 let cols = witness
770 .cols
771 .iter_mut()
772 .map(|col| match &mut col.data {
773 WitnessColumnInfo::Owned(data) => {
774 let chunk_size = (log_segment_size + col.shape.log_cell_size())
775 .saturating_sub(P::LOG_WIDTH + F::TOWER_LEVEL);
776 WitnessColumnInfo::Owned((&mut **data, 1 << chunk_size))
777 }
778 WitnessColumnInfo::SameAsIndex(id) => WitnessColumnInfo::SameAsIndex(*id),
779 })
780 .collect::<Vec<_>>();
781 Self {
782 table: witness.table,
783 cols,
784 log_segment_size,
785 start_index: 0,
786 n_segments: 1 << (witness.log_capacity - log_segment_size),
787 }
788 }
789
790 fn split_at(
791 &mut self,
792 index: usize,
793 ) -> (TableWitnessSegmentedView<P>, TableWitnessSegmentedView<P>) {
794 assert!(index <= self.n_segments);
795 let (cols_0, cols_1) = self
796 .cols
797 .iter_mut()
798 .map(|col| match col {
799 WitnessColumnInfo::Owned((data, chunk_size)) => {
800 let (data_0, data_1) = data.split_at_mut(*chunk_size * index);
801 (
802 WitnessColumnInfo::Owned((data_0, *chunk_size)),
803 WitnessColumnInfo::Owned((data_1, *chunk_size)),
804 )
805 }
806 WitnessColumnInfo::SameAsIndex(id) => {
807 (WitnessColumnInfo::SameAsIndex(*id), WitnessColumnInfo::SameAsIndex(*id))
808 }
809 })
810 .unzip();
811 (
812 TableWitnessSegmentedView {
813 table: self.table,
814 cols: cols_0,
815 log_segment_size: self.log_segment_size,
816 start_index: self.start_index,
817 n_segments: index,
818 },
819 TableWitnessSegmentedView {
820 table: self.table,
821 cols: cols_1,
822 log_segment_size: self.log_segment_size,
823 start_index: self.start_index + index,
824 n_segments: self.n_segments - index,
825 },
826 )
827 }
828
829 fn into_iter(self) -> impl Iterator<Item = TableWitnessSegment<'a, P>> {
830 let TableWitnessSegmentedView {
831 table,
832 cols,
833 log_segment_size,
834 start_index,
835 n_segments,
836 } = self;
837
838 if n_segments == 0 {
839 itertools::Either::Left(iter::empty())
840 } else {
841 let iter = MultiIterator::new(
842 cols.into_iter()
843 .map(|col| match col {
844 WitnessColumnInfo::Owned((data, chunk_size)) => itertools::Either::Left(
845 data.chunks_mut(chunk_size)
846 .map(|chunk| RefCellData::Owned(RefCell::new(chunk))),
847 ),
848 WitnessColumnInfo::SameAsIndex(id) => itertools::Either::Right(
849 iter::repeat_n(id, n_segments).map(RefCellData::SameAsIndex),
850 ),
851 })
852 .collect(),
853 )
854 .enumerate()
855 .map(move |(index, cols)| TableWitnessSegment {
856 table,
857 cols,
858 log_size: log_segment_size,
859 index: start_index + index,
860 });
861 itertools::Either::Right(iter)
862 }
863 }
864
865 fn into_par_iter(self) -> impl IndexedParallelIterator<Item = TableWitnessSegment<'a, P>> {
866 let TableWitnessSegmentedView {
867 table,
868 cols,
869 log_segment_size,
870 start_index,
871 n_segments,
872 } = self;
873
874 #[allow(clippy::mut_from_ref)]
880 unsafe fn cast_slice_ref_to_mut<T>(slice: &[T]) -> &mut [T] {
881 unsafe { slice::from_raw_parts_mut(slice.as_ptr() as *mut T, slice.len()) }
882 }
883
884 let cols = cols
887 .into_iter()
888 .map(|col| -> WitnessColumnInfo<(&'a [P], usize)> {
889 match col {
890 WitnessColumnInfo::Owned((data, chunk_size)) => {
891 WitnessColumnInfo::Owned((data, chunk_size))
892 }
893 WitnessColumnInfo::SameAsIndex(id) => WitnessColumnInfo::SameAsIndex(id),
894 }
895 })
896 .collect::<Vec<_>>();
897
898 (0..n_segments).into_par_iter().map(move |i| {
899 let col_strides = cols
900 .iter()
901 .map(|col| match col {
902 WitnessColumnInfo::SameAsIndex(id) => RefCellData::SameAsIndex(*id),
903 WitnessColumnInfo::Owned((data, chunk_size)) => {
904 RefCellData::Owned(RefCell::new(unsafe {
905 cast_slice_ref_to_mut(&data[i * chunk_size..(i + 1) * chunk_size])
910 }))
911 }
912 })
913 .collect();
914 TableWitnessSegment {
915 table,
916 cols: col_strides,
917 log_size: log_segment_size,
918 index: start_index + i,
919 }
920 })
921 }
922}
923
/// A contiguous range of `2^log_size` rows of a table witness.
///
/// Column data is wrapped in `RefCell`s, so that several columns can be borrowed at once
/// through `&self`.
#[derive(Debug, CopyGetters)]
pub struct TableWitnessSegment<'a, P = PackedType<OptimalUnderlier, B128>>
where
    P: PackedField,
    P::Scalar: TowerField,
{
    /// Definition of the underlying table.
    table: &'a Table<P::Scalar>,
    /// Per-column data of this segment, in table column order.
    cols: Vec<RefCellData<'a, P>>,
    /// Base-2 logarithm of the number of rows in the segment.
    #[get_copy = "pub"]
    log_size: usize,
    /// Position of this segment within the table, counted in segments.
    #[get_copy = "pub"]
    index: usize,
}
945
946impl<'a, F: TowerField, P: PackedField<Scalar = F>> TableWitnessSegment<'a, P> {
947 pub fn get<FSub: TowerField, const V: usize>(
948 &self,
949 col: Col<FSub, V>,
950 ) -> Result<Ref<[PackedSubfield<P, FSub>]>, Error>
951 where
952 P: PackedExtension<FSub>,
953 {
954 if col.table_id != self.table.id() {
955 return Err(Error::TableMismatch {
956 column_table_id: col.table_id,
957 witness_table_id: self.table.id(),
958 });
959 }
960
961 let col = self
962 .get_col_data(col.id())
963 .ok_or_else(|| Error::MissingColumn(col.id()))?;
964 let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
965 Ok(Ref::map(col_ref, |packed| PackedExtension::cast_bases(packed)))
966 }
967
968 pub fn get_mut<FSub: TowerField, const V: usize>(
969 &self,
970 col: Col<FSub, V>,
971 ) -> Result<RefMut<[PackedSubfield<P, FSub>]>, Error>
972 where
973 P: PackedExtension<FSub>,
974 F: ExtensionField<FSub>,
975 {
976 if col.table_id != self.table.id() {
977 return Err(Error::TableMismatch {
978 column_table_id: col.table_id,
979 witness_table_id: self.table.id(),
980 });
981 }
982
983 let col = self
984 .get_col_data(col.id())
985 .ok_or_else(|| Error::MissingColumn(col.id()))?;
986 let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
987 Ok(RefMut::map(col_ref, |packed| PackedExtension::cast_bases_mut(packed)))
988 }
989
990 pub fn get_scalars<FSub: TowerField, const V: usize>(
991 &self,
992 col: Col<FSub, V>,
993 ) -> Result<Ref<[FSub]>, Error>
994 where
995 P: PackedExtension<FSub>,
996 F: ExtensionField<FSub>,
997 PackedSubfield<P, FSub>: PackedFieldIndexable,
998 {
999 self.get(col)
1000 .map(|packed| Ref::map(packed, <PackedSubfield<P, FSub>>::unpack_scalars))
1001 }
1002
1003 pub fn get_scalars_mut<FSub: TowerField, const V: usize>(
1004 &self,
1005 col: Col<FSub, V>,
1006 ) -> Result<RefMut<[FSub]>, Error>
1007 where
1008 P: PackedExtension<FSub>,
1009 F: ExtensionField<FSub>,
1010 PackedSubfield<P, FSub>: PackedFieldIndexable,
1011 {
1012 self.get_mut(col)
1013 .map(|packed| RefMut::map(packed, <PackedSubfield<P, FSub>>::unpack_scalars_mut))
1014 }
1015
1016 pub fn get_as<T: Pod, FSub: TowerField, const V: usize>(
1017 &self,
1018 col: Col<FSub, V>,
1019 ) -> Result<Ref<[T]>, Error>
1020 where
1021 P: PackedExtension<FSub> + PackedFieldIndexable,
1022 F: ExtensionField<FSub> + Pod,
1023 {
1024 let col = self
1025 .get_col_data(col.id())
1026 .ok_or_else(|| Error::MissingColumn(col.id()))?;
1027 let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
1028 Ok(Ref::map(col_ref, |col| must_cast_slice(P::unpack_scalars(col))))
1029 }
1030
1031 pub fn get_mut_as<T: Pod, FSub: TowerField, const V: usize>(
1032 &self,
1033 col: Col<FSub, V>,
1034 ) -> Result<RefMut<[T]>, Error>
1035 where
1036 P: PackedExtension<FSub> + PackedFieldIndexable,
1037 F: ExtensionField<FSub> + Pod,
1038 {
1039 if col.table_id != self.table.id() {
1040 return Err(Error::TableMismatch {
1041 column_table_id: col.table_id,
1042 witness_table_id: self.table.id(),
1043 });
1044 }
1045
1046 let col = self
1047 .get_col_data(col.id())
1048 .ok_or_else(|| Error::MissingColumn(col.id()))?;
1049 let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
1050 Ok(RefMut::map(col_ref, |col| must_cast_slice_mut(P::unpack_scalars_mut(col))))
1051 }
1052
1053 pub fn eval_expr<FSub: TowerField, const V: usize>(
1060 &self,
1061 expr: &Expr<FSub, V>,
1062 ) -> Result<impl Iterator<Item = PackedSubfield<P, FSub>> + use<FSub, V, F, P>, Error>
1063 where
1064 P: PackedExtension<FSub>,
1065 {
1066 let log_vals_per_row = checked_log_2(V);
1067
1068 let partition =
1069 self.table
1070 .partitions
1071 .get(log_vals_per_row)
1072 .ok_or_else(|| Error::MissingPartition {
1073 table_id: self.table.id(),
1074 log_vals_per_row,
1075 })?;
1076 let expr_circuit = ArithCircuit::from(expr.expr());
1077
1078 let col_refs = partition
1079 .columns
1080 .iter()
1081 .zip(expr_circuit.vars_usage())
1082 .map(|(col_id, used)| {
1083 used.then(|| {
1084 assert_eq!(FSub::TOWER_LEVEL, self.table[*col_id].shape.tower_height);
1085 let col = self
1086 .get_col_data(*col_id)
1087 .ok_or_else(|| Error::MissingColumn(*col_id))?;
1088 let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
1089 Ok::<_, Error>(Ref::map(col_ref, |packed| PackedExtension::cast_bases(packed)))
1090 })
1091 .transpose()
1092 })
1093 .collect::<Result<Vec<_>, _>>()?;
1094
1095 let log_packed_elems =
1096 (self.log_size + log_vals_per_row).saturating_sub(<PackedSubfield<P, FSub>>::LOG_WIDTH);
1097
1098 let dummy_col = zeroed_vec(1 << log_packed_elems);
1100
1101 let cols = col_refs
1102 .iter()
1103 .map(|col| col.as_ref().map(|col_ref| &**col_ref).unwrap_or(&dummy_col))
1104 .collect::<Vec<_>>();
1105 let cols = RowsBatchRef::new(&cols, 1 << log_packed_elems);
1106
1107 let mut evals = zeroed_vec(1 << log_packed_elems);
1112 ArithCircuitPoly::new(expr_circuit).batch_evaluate(&cols, &mut evals)?;
1113 Ok(evals.into_iter())
1114 }
1115
1116 pub fn size(&self) -> usize {
1117 1 << self.log_size
1118 }
1119
1120 fn get_col_data(&self, column_id: ColumnId) -> Option<&RefCell<&'a mut [P]>> {
1121 let column_index = column_id.table_index.0;
1122 self.get_col_data_by_index(column_index)
1123 }
1124
1125 fn get_col_data_by_index(&self, index: usize) -> Option<&RefCell<&'a mut [P]>> {
1126 match self.cols.get(index) {
1127 Some(RefCellData::Owned(data)) => Some(data),
1128 Some(RefCellData::SameAsIndex(index)) => self.get_col_data_by_index(*index),
1129 None => None,
1130 }
1131 }
1132}
1133
1134impl<'a, P> TableWitnessSegment<'a, P>
1135where
1136 P: PackedField<Scalar: TowerField>
1137 + PackedExtension<B1>
1138 + PackedExtension<B8>
1139 + PackedExtension<B16>
1140 + PackedExtension<B32>
1141 + PackedExtension<B64>
1142 + PackedExtension<B128>,
1143{
1144 pub fn get_dyn(
1146 &self,
1147 col_id: ColumnId,
1148 ) -> Result<Box<dyn WitnessColView<P::Scalar> + '_>, Error> {
1149 let col = self
1150 .get_col_data(col_id)
1151 .ok_or_else(|| Error::MissingColumn(col_id))?;
1152 let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
1153 let tower_level = self.table[col_id].shape.tower_height;
1154 let ret: Box<dyn WitnessColView<_>> = match tower_level {
1155 0 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
1156 PackedExtension::<B1>::cast_bases(packed)
1157 }))),
1158 3 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
1159 PackedExtension::<B8>::cast_bases(packed)
1160 }))),
1161 4 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
1162 PackedExtension::<B16>::cast_bases(packed)
1163 }))),
1164 5 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
1165 PackedExtension::<B32>::cast_bases(packed)
1166 }))),
1167 6 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
1168 PackedExtension::<B64>::cast_bases(packed)
1169 }))),
1170 7 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
1171 PackedExtension::<B128>::cast_bases(packed)
1172 }))),
1173 _ => panic!("tower_level must be in the range [0, 7]"),
1174 };
1175 Ok(ret)
1176 }
1177
1178 pub fn get_dyn_mut(
1180 &self,
1181 col_id: ColumnId,
1182 ) -> Result<Box<dyn WitnessColViewMut<P::Scalar> + '_>, Error> {
1183 let col = self
1184 .get_col_data(col_id)
1185 .ok_or_else(|| Error::MissingColumn(col_id))?;
1186 let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
1187 let tower_level = self.table[col_id].shape.tower_height;
1188 let ret: Box<dyn WitnessColViewMut<_>> = match tower_level {
1189 0 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
1190 PackedExtension::<B1>::cast_bases_mut(packed)
1191 }))),
1192 3 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
1193 PackedExtension::<B8>::cast_bases_mut(packed)
1194 }))),
1195 4 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
1196 PackedExtension::<B16>::cast_bases_mut(packed)
1197 }))),
1198 5 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
1199 PackedExtension::<B32>::cast_bases_mut(packed)
1200 }))),
1201 6 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
1202 PackedExtension::<B64>::cast_bases_mut(packed)
1203 }))),
1204 7 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
1205 PackedExtension::<B128>::cast_bases_mut(packed)
1206 }))),
1207 _ => panic!("tower_level must be in the range [0, 7]"),
1208 };
1209 Ok(ret)
1210 }
1211}
1212
/// Type-erased read access to the scalars of a witness column, expressed in field `F`.
pub trait WitnessColView<F> {
    /// Returns the scalar at `index`, lifted into `F`.
    fn get(&self, index: usize) -> F;

    /// Returns the number of scalars in the view.
    fn size(&self) -> usize;
}
1222
/// Type-erased write access to the scalars of a witness column, expressed in field `F`.
pub trait WitnessColViewMut<F>: WitnessColView<F> {
    /// Stores `val` at `index`. Fails if `val` cannot be represented in the column's field
    /// (the implementation in this module returns [`Error::FieldElementTooBig`]).
    fn set(&mut self, index: usize, val: F) -> Result<(), Error>;
}
1228
/// Adapter exposing a slice of packed field elements through [`WitnessColView`] and
/// [`WitnessColViewMut`].
///
/// `Data` is any pointer-like type that dereferences to `[P]`, such as the [`Ref`] / [`RefMut`]
/// guards obtained when borrowing column data.
struct WitnessColViewImpl<Data>(Data);
1230
1231impl<F, P, Data> WitnessColView<F> for WitnessColViewImpl<Data>
1232where
1233 F: ExtensionField<P::Scalar>,
1234 P: PackedField,
1235 Data: Deref<Target = [P]>,
1236{
1237 fn get(&self, index: usize) -> F {
1238 get_packed_slice(&self.0, index).into()
1239 }
1240
1241 fn size(&self) -> usize {
1242 self.0.len() << P::LOG_WIDTH
1243 }
1244}
1245
1246impl<F, P, Data> WitnessColViewMut<F> for WitnessColViewImpl<Data>
1247where
1248 F: ExtensionField<P::Scalar>,
1249 P: PackedField,
1250 Data: DerefMut<Target = [P]>,
1251{
1252 fn set(&mut self, index: usize, val: F) -> Result<(), Error> {
1253 let subfield_val = val.try_into().map_err(|_| Error::FieldElementTooBig)?;
1254 set_packed_slice(&mut self.0, index, subfield_val);
1255 Ok(())
1256 }
1257}
1258
/// Populates the witness data of one table from a batch of events.
pub trait TableFiller<P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	/// The input record consumed by [`TableFiller::fill`].
	///
	/// NOTE(review): presumably one event per table row; confirm against the fill drivers.
	type Event: Clone;

	/// Returns the identifier of the table this filler populates.
	fn id(&self) -> TableId;

	/// Writes the witness values derived from `rows` into `witness`.
	///
	/// # Errors
	///
	/// Implementations may return any error encountered while accessing or computing the
	/// witness columns.
	fn fill(
		&self,
		rows: &[Self::Event],
		witness: &mut TableWitnessSegment<P>,
	) -> anyhow::Result<()>;
}
1282
1283#[cfg(test)]
1284mod tests {
1285 use std::{array, iter::repeat_with};
1286
1287 use assert_matches::assert_matches;
1288 use binius_compute::cpu::alloc::CpuComputeAllocator;
1289 use binius_core::oracle::{OracleId, SymbolicMultilinearOracleSet};
1290 use binius_field::{
1291 arch::{OptimalUnderlier128b, OptimalUnderlier256b},
1292 packed::{len_packed_slice, set_packed_slice},
1293 };
1294 use rand::{Rng, SeedableRng, rngs::StdRng};
1295
1296 use super::*;
1297 use crate::builder::{
1298 ConstraintSystem, TableBuilder,
1299 types::{B1, B8, B16, B32},
1300 };
1301
1302 #[test]
1303 fn test_table_witness_borrows() {
1304 let table_id = 0;
1305 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1306 let mut table = TableBuilder::new(&mut inner_table);
1307 let col0 = table.add_committed::<B1, 8>("col0");
1308 let col1 = table.add_committed::<B1, 32>("col1");
1309 let col2 = table.add_committed::<B8, 1>("col2");
1310 let col3 = table.add_committed::<B32, 1>("col3");
1311
1312 let mut allocator = CpuComputeAllocator::new(1 << 12);
1313 let allocator = allocator.into_bump_allocator();
1314 let table_size = 64;
1315 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1316 &allocator,
1317 &inner_table,
1318 table_size,
1319 )
1320 .unwrap();
1321 let segment = index.full_segment();
1322
1323 {
1324 let col0_ref0 = segment.get(col0).unwrap();
1325 let _col0_ref1 = segment.get(col0).unwrap();
1326 assert_matches!(segment.get_mut(col0), Err(Error::WitnessBorrowMut(_)));
1327 drop(col0_ref0);
1328
1329 let col1_ref = segment.get_mut(col1).unwrap();
1330 assert_matches!(segment.get(col1), Err(Error::WitnessBorrow(_)));
1331 drop(col1_ref);
1332 }
1333
1334 assert_eq!(len_packed_slice(&segment.get_mut(col0).unwrap()), 1 << 9);
1335 assert_eq!(len_packed_slice(&segment.get_mut(col1).unwrap()), 1 << 11);
1336 assert_eq!(len_packed_slice(&segment.get_mut(col2).unwrap()), 1 << 6);
1337 assert_eq!(len_packed_slice(&segment.get_mut(col3).unwrap()), 1 << 6);
1338 }
1339
1340 #[test]
1341 fn test_table_witness_segments() {
1342 let table_id = 0;
1343 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1344 let mut table = TableBuilder::new(&mut inner_table);
1345 let col0 = table.add_committed::<B1, 8>("col0");
1346 let col1 = table.add_committed::<B1, 32>("col1");
1347 let col2 = table.add_committed::<B8, 1>("col2");
1348 let col3 = table.add_committed::<B32, 1>("col3");
1349
1350 let mut allocator = CpuComputeAllocator::new(1 << 12);
1351 let allocator = allocator.into_bump_allocator();
1352 let table_size = 64;
1353 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1354 &allocator,
1355 &inner_table,
1356 table_size,
1357 )
1358 .unwrap();
1359
1360 assert_eq!(index.min_log_segment_size(), 4);
1361 let mut iter = index.segments(5);
1362 let seg0 = iter.next().unwrap();
1363 let seg1 = iter.next().unwrap();
1364 assert!(iter.next().is_none());
1365
1366 assert_eq!(len_packed_slice(&seg0.get_mut(col0).unwrap()), 1 << 8);
1367 assert_eq!(len_packed_slice(&seg0.get_mut(col1).unwrap()), 1 << 10);
1368 assert_eq!(len_packed_slice(&seg0.get_mut(col2).unwrap()), 1 << 5);
1369 assert_eq!(len_packed_slice(&seg0.get_mut(col3).unwrap()), 1 << 5);
1370
1371 assert_eq!(len_packed_slice(&seg1.get_mut(col0).unwrap()), 1 << 8);
1372 assert_eq!(len_packed_slice(&seg1.get_mut(col1).unwrap()), 1 << 10);
1373 assert_eq!(len_packed_slice(&seg1.get_mut(col2).unwrap()), 1 << 5);
1374 assert_eq!(len_packed_slice(&seg1.get_mut(col3).unwrap()), 1 << 5);
1375 }
1376
1377 #[test]
1378 fn test_eval_expr() {
1379 let table_id = 0;
1380 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1381 let mut table = TableBuilder::new(&mut inner_table);
1382 let col0 = table.add_committed::<B8, 2>("col0");
1383 let col1 = table.add_committed::<B8, 2>("col1");
1384 let col2 = table.add_committed::<B8, 2>("col2");
1385
1386 let mut allocator = CpuComputeAllocator::new(1 << 12);
1387 let allocator = allocator.into_bump_allocator();
1388 let table_size = 1 << 6;
1389 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1390 &allocator,
1391 &inner_table,
1392 table_size,
1393 )
1394 .unwrap();
1395
1396 let segment = index.full_segment();
1397 assert_eq!(segment.log_size(), 6);
1398
1399 {
1401 let mut col0 = segment.get_mut(col0).unwrap();
1402 let mut col1 = segment.get_mut(col1).unwrap();
1403 let mut col2 = segment.get_mut(col2).unwrap();
1404
1405 let expected_slice_len = 1 << 3;
1407 assert_eq!(col0.len(), expected_slice_len);
1408 assert_eq!(col0.len(), expected_slice_len);
1409 assert_eq!(col0.len(), expected_slice_len);
1410
1411 for i in 0..expected_slice_len * <PackedType<OptimalUnderlier128b, B8>>::WIDTH {
1412 set_packed_slice(&mut col0, i, B8::new(i as u8) + B8::new(0x00));
1413 set_packed_slice(&mut col1, i, B8::new(i as u8) + B8::new(0x40));
1414 set_packed_slice(&mut col2, i, B8::new(i as u8) + B8::new(0x80));
1415 }
1416 }
1417
1418 let evals = segment.eval_expr(&(col0 * col1 - col2)).unwrap();
1419 for (i, eval_i) in evals
1420 .into_iter()
1421 .flat_map(PackedField::into_iter)
1422 .enumerate()
1423 {
1424 let col0_val = B8::new(i as u8) + B8::new(0x00);
1425 let col1_val = B8::new(i as u8) + B8::new(0x40);
1426 let col2_val = B8::new(i as u8) + B8::new(0x80);
1427 assert_eq!(eval_i, col0_val * col1_val - col2_val);
1428 }
1429 }
1430
1431 #[test]
1432 fn test_eval_expr_different_cols() {
1433 let table_id = 0;
1434 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1435 let mut table = TableBuilder::new(&mut inner_table);
1436 let col0 = table.add_committed::<B32, 1>("col1");
1437 let col1 = table.add_committed::<B8, 8>("col2");
1438 let col2 = table.add_committed::<B16, 4>("col3");
1439 let col3 = table.add_committed::<B8, 8>("col4");
1440
1441 let mut allocator = CpuComputeAllocator::new(1 << 12);
1442 let allocator = allocator.into_bump_allocator();
1443 let table_size = 4;
1444 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1445 &allocator,
1446 &inner_table,
1447 table_size,
1448 )
1449 .unwrap();
1450
1451 let segment = index.full_segment();
1452
1453 {
1455 let mut col0: RefMut<'_, [u32]> = segment.get_mut_as(col0).unwrap();
1456 let mut col1: RefMut<'_, [[u8; 8]]> = segment.get_mut_as(col1).unwrap();
1457 let mut col2: RefMut<'_, [[u16; 4]]> = segment.get_mut_as(col2).unwrap();
1458 let mut col3: RefMut<'_, [[u8; 8]]> = segment.get_mut_as(col3).unwrap();
1459
1460 col0[0] = 0x40;
1461 col1[0] = [0x45, 0, 0, 0, 0, 0, 0, 0];
1462 col2[0] = [0, 0, 0, 0x80];
1463 col3[0] = [0x85, 0, 0, 0, 0, 0, 0, 00];
1464 }
1465
1466 let evals = segment.eval_expr(&(col1 + col3)).unwrap();
1467 for (i, eval_i) in evals
1468 .into_iter()
1469 .flat_map(PackedField::into_iter)
1470 .enumerate()
1471 {
1472 if i == 0 {
1473 assert_eq!(eval_i, B8::new(0x45) + B8::new(0x85));
1474 } else {
1475 assert_eq!(eval_i, B8::new(0x00));
1476 }
1477 }
1478 }
1479
1480 #[test]
1481 fn test_small_tables() {
1482 let table_id = 0;
1483 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1484 let mut table = TableBuilder::new(&mut inner_table);
1485 let _col0 = table.add_committed::<B1, 8>("col0");
1486 let _col1 = table.add_committed::<B1, 32>("col1");
1487 let _col2 = table.add_committed::<B8, 1>("col2");
1488 let _col3 = table.add_committed::<B32, 1>("col3");
1489
1490 let mut allocator = CpuComputeAllocator::new(1 << 12);
1491 let allocator = allocator.into_bump_allocator();
1492 let table_size = 7;
1493 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier256b, B128>>::new(
1494 &allocator,
1495 &inner_table,
1496 table_size,
1497 )
1498 .unwrap();
1499
1500 assert_eq!(index.log_capacity(), 3);
1501 assert_eq!(index.min_log_segment_size(), 3);
1502
1503 let mut iter = index.segments(5);
1504 assert_eq!(iter.next().unwrap().log_size(), 3);
1506 assert!(iter.next().is_none());
1507 drop(iter);
1508
1509 let mut iter = index.segments(2);
1510 assert_eq!(iter.next().unwrap().log_size(), 3);
1512 assert!(iter.next().is_none());
1513 drop(iter);
1514 }
1515
1516 struct TestTable {
1517 id: TableId,
1518 col0: Col<B32>,
1519 col1: Col<B32>,
1520 }
1521
1522 impl TestTable {
1523 fn new(cs: &mut ConstraintSystem) -> Self {
1524 let mut table = cs.add_table("test");
1525
1526 let col0 = table.add_committed("col0");
1527 let col1 = table.add_computed("col1", col0 * col0 + B32::new(0x03));
1528
1529 Self {
1530 id: table.id(),
1531 col0,
1532 col1,
1533 }
1534 }
1535 }
1536
1537 impl TableFiller<PackedType<OptimalUnderlier128b, B128>> for TestTable {
1538 type Event = u32;
1539
1540 fn id(&self) -> TableId {
1541 self.id
1542 }
1543
1544 fn fill(
1545 &self,
1546 rows: &[Self::Event],
1547 witness: &mut TableWitnessSegment<PackedType<OptimalUnderlier128b, B128>>,
1548 ) -> anyhow::Result<()> {
1549 let mut col0 = witness.get_scalars_mut(self.col0)?;
1550 let mut col1 = witness.get_scalars_mut(self.col1)?;
1551 for (i, &val) in rows.iter().enumerate() {
1552 col0[i] = B32::new(val);
1553 col1[i] = col0[i].pow(2) + B32::new(0x03);
1554 }
1555 Ok(())
1556 }
1557 }
1558
1559 #[test]
1560 fn test_fill_sequential_with_incomplete_events() {
1561 let mut cs = ConstraintSystem::new();
1562 let test_table = TestTable::new(&mut cs);
1563
1564 let mut allocator = CpuComputeAllocator::new(1 << 12);
1565 let allocator = allocator.into_bump_allocator();
1566
1567 let table_size = 11;
1568 let mut index = WitnessIndex::new(&cs, &allocator);
1569 let table_index = index.init_table(test_table.id(), table_size).unwrap();
1570
1571 let mut rng = StdRng::seed_from_u64(0);
1572 let rows = repeat_with(|| rng.random())
1573 .take(table_size)
1574 .collect::<Vec<_>>();
1575
1576 assert_eq!(table_index.log_capacity(), 4);
1578
1579 assert_eq!(table_index.min_log_segment_size(), 2);
1581
1582 assert_matches!(
1584 table_index.fill_sequential_with_segment_size(&test_table, &rows[1..], 2),
1585 Err(Error::IncorrectNumberOfTableEvents { .. })
1586 );
1587
1588 table_index
1589 .fill_sequential_with_segment_size(&test_table, &rows, 2)
1590 .unwrap();
1591
1592 let segment = table_index.full_segment();
1593 let col0 = segment.get_scalars(test_table.col0).unwrap();
1594 for i in 0..11 {
1595 assert_eq!(col0[i].val(), rows[i]);
1596 }
1597 assert_eq!(col0[11].val(), rows[10]);
1598 assert_eq!(col0[12].val(), rows[8]);
1599 assert_eq!(col0[13].val(), rows[9]);
1600 assert_eq!(col0[14].val(), rows[10]);
1601 assert_eq!(col0[15].val(), rows[10]);
1602 }
1603
1604 #[test]
1605 fn test_fill_parallel_with_incomplete_events() {
1606 let mut cs = ConstraintSystem::new();
1607 let test_table = TestTable::new(&mut cs);
1608
1609 let mut allocator = CpuComputeAllocator::new(1 << 12);
1610 let allocator = allocator.into_bump_allocator();
1611
1612 let table_size = 11;
1613 let mut index = WitnessIndex::new(&cs, &allocator);
1614 let table_index = index.init_table(test_table.id(), table_size).unwrap();
1615
1616 let mut rng = StdRng::seed_from_u64(0);
1617 let rows = repeat_with(|| rng.random())
1618 .take(table_size)
1619 .collect::<Vec<_>>();
1620
1621 assert_eq!(table_index.log_capacity(), 4);
1623
1624 assert_eq!(table_index.min_log_segment_size(), 2);
1626
1627 assert_matches!(
1629 table_index.fill_parallel_with_segment_size(&test_table, &rows[1..], 2),
1630 Err(Error::IncorrectNumberOfTableEvents { .. })
1631 );
1632
1633 table_index
1634 .fill_parallel_with_segment_size(&test_table, &rows, 2)
1635 .unwrap();
1636
1637 let segment = table_index.full_segment();
1638 let col0 = segment.get_scalars(test_table.col0).unwrap();
1639 for i in 0..11 {
1640 assert_eq!(col0[i].val(), rows[i]);
1641 }
1642 assert_eq!(col0[11].val(), rows[10]);
1643 assert_eq!(col0[12].val(), rows[8]);
1644 assert_eq!(col0[13].val(), rows[9]);
1645 assert_eq!(col0[14].val(), rows[10]);
1646 assert_eq!(col0[15].val(), rows[10]);
1647 }
1648
1649 #[test]
1650 fn test_fill_empty_rows_non_empty_table() {
1651 let mut cs = ConstraintSystem::new();
1652 let test_table = TestTable::new(&mut cs);
1653
1654 let mut allocator = CpuComputeAllocator::new(1 << 12);
1655 let allocator = allocator.into_bump_allocator();
1656
1657 let table_size = 11;
1658 let mut index = WitnessIndex::new(&cs, &allocator);
1659
1660 index.init_table(test_table.id(), table_size).unwrap();
1661
1662 assert_matches!(
1663 index.fill_table_sequential(&test_table, &[]),
1664 Err(Error::IncorrectNumberOfTableEvents { .. })
1665 );
1666 }
1667
1668 #[test]
1669 fn test_dyn_witness() {
1670 let mut cs = ConstraintSystem::new();
1671 let mut test_table = cs.add_table("test");
1672 let test_col: Col<B32, 4> = test_table.add_committed("col");
1673 let table_id = test_table.id();
1674
1675 let mut allocator = CpuComputeAllocator::new(1 << 12);
1676 let allocator = allocator.into_bump_allocator();
1677
1678 let table_size = 11;
1679 let mut index = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);
1680
1681 let table_index = index.init_table(table_id, table_size).unwrap();
1682 let segment = table_index.full_segment();
1683 let mut rng = StdRng::seed_from_u64(0);
1684 let row = repeat_with(|| B32::random(&mut rng))
1685 .take(table_size * 4)
1686 .collect::<Vec<_>>();
1687 {
1688 let mut data: Box<dyn WitnessColViewMut<_>> =
1689 segment.get_dyn_mut(test_col.id()).unwrap();
1690 row.iter().enumerate().for_each(|(i, val)| {
1691 data.set(i, (*val).into()).unwrap();
1692 })
1693 }
1694 let data = segment.get_dyn(test_col.id()).unwrap();
1695 row.iter().enumerate().for_each(|(i, val)| {
1696 let down_cast: B32 = data.get(i).try_into().unwrap();
1697 assert_eq!(down_cast, *val)
1698 })
1699 }
1700
1701 fn find_oracle_id_with_name(
1702 oracles: &SymbolicMultilinearOracleSet<B128>,
1703 name: &str,
1704 ) -> Option<OracleId> {
1705 oracles
1706 .iter()
1707 .find(|(_, oracle)| oracle.name.as_deref() == Some(name))
1708 .map(|(id, _)| id)
1709 }
1710
1711 #[test]
1712 fn test_constant_filling() {
1713 let mut cs = ConstraintSystem::new();
1714
1715 let mut test_table = cs.add_table("test");
1716 let mut rng = StdRng::seed_from_u64(0);
1717 let unpack_value = B16::random(&mut rng);
1718 let pack_const_arr: [B32; 4] = array::from_fn(|_| B32::random(&mut rng));
1719
1720 let _ = test_table.add_constant("unpacked_col", [unpack_value]);
1721 let _ = test_table.add_constant("packed_col", pack_const_arr);
1722 let table_id = test_table.id();
1723
1724 let mut allocator = CpuComputeAllocator::new(1 << 12);
1725 let allocator = allocator.into_bump_allocator();
1726
1727 let table_size = 123;
1728 let ccs = cs.compile().unwrap();
1729 let mut index = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);
1730
1731 {
1732 let _ = index.init_table(table_id, table_size).unwrap();
1733 index.fill_constant_cols().unwrap();
1734 }
1735
1736 let witness = index.into_multilinear_extension_index();
1737 let non_packed_col_id = find_oracle_id_with_name(&ccs.oracles, "unpacked_col").unwrap();
1738
1739 let non_pack_witness = witness.get_multilin_poly(non_packed_col_id).unwrap();
1741 for index in 0..non_pack_witness.size() {
1742 let got = non_pack_witness.evaluate_on_hypercube(index).unwrap();
1743 assert_eq!(got, unpack_value.into());
1744 }
1745 let packed_col_id = find_oracle_id_with_name(&ccs.oracles, "packed_col").unwrap();
1746
1747 let pack_witness = witness.get_multilin_poly(packed_col_id).unwrap();
1748 for index in 0..pack_witness.size() {
1749 let got = pack_witness.evaluate_on_hypercube(index).unwrap();
1750
1751 assert_eq!(got, pack_const_arr[index % 4].into());
1752 }
1753 }
1754}