1use std::{
4 cell::{Ref, RefCell, RefMut},
5 iter, slice,
6 sync::Arc,
7};
8
9use binius_core::{
10 oracle::OracleId, polynomial::ArithCircuitPoly, transparent::step_down::StepDown,
11 witness::MultilinearExtensionIndex,
12};
13use binius_field::{
14 arch::OptimalUnderlier, as_packed_field::PackedType, ExtensionField, PackedExtension,
15 PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
16};
17use binius_math::{CompositionPoly, MultilinearExtension, MultilinearPoly, RowsBatchRef};
18use binius_maybe_rayon::prelude::*;
19use binius_utils::checked_arithmetics::checked_log_2;
20use bytemuck::{must_cast_slice, must_cast_slice_mut, zeroed_vec, Pod};
21use getset::CopyGetters;
22use itertools::Itertools;
23
24use super::{
25 column::{Col, ColumnShape},
26 error::Error,
27 table::{Table, TableId},
28 types::{B1, B128, B16, B32, B64, B8},
29 ColumnDef, ColumnId, ColumnIndex, Expr,
30};
31use crate::builder::multi_iter::MultiIterator;
32
/// Witness data for all tables of a constraint system.
///
/// Tables are addressed by [`TableId`], which is used directly as the position in `tables`.
#[derive(Debug, Default, CopyGetters)]
pub struct WitnessIndex<'cs, 'alloc, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	// Position `i` holds the witness for table ID `i`; `None` means no witness is present for that
	// table (such entries are skipped when building the multilinear extension index).
	pub tables: Vec<Option<TableWitnessIndex<'cs, 'alloc, P>>>,
}
48
49impl<'cs, 'alloc, F: TowerField, P: PackedField<Scalar = F>> WitnessIndex<'cs, 'alloc, P> {
50 pub fn get_table(
51 &mut self,
52 table_id: TableId,
53 ) -> Option<&mut TableWitnessIndex<'cs, 'alloc, P>> {
54 self.tables
55 .get_mut(table_id)
56 .and_then(|inner| inner.as_mut())
57 }
58
59 pub fn fill_table_sequential<T: TableFiller<P>>(
60 &mut self,
61 table: &T,
62 rows: &[T::Event],
63 ) -> Result<(), Error> {
64 if !rows.is_empty() {
65 let table_id = table.id();
66 let witness = self
67 .get_table(table_id)
68 .ok_or(Error::MissingTable { table_id })?;
69 witness.fill_sequential(table, rows)?;
70 }
71 Ok(())
72 }
73
74 pub fn fill_table_parallel<T>(&mut self, table: &T, rows: &[T::Event]) -> Result<(), Error>
75 where
76 T: TableFiller<P> + Sync,
77 T::Event: Sync,
78 {
79 let table_id = table.id();
80 let witness = self
81 .get_table(table_id)
82 .ok_or(Error::MissingTable { table_id })?;
83 witness.fill_parallel(table, rows)?;
84 Ok(())
85 }
86
87 pub fn into_multilinear_extension_index(self) -> MultilinearExtensionIndex<'alloc, P>
88 where
89 P: PackedExtension<B1>
90 + PackedExtension<B8>
91 + PackedExtension<B16>
92 + PackedExtension<B32>
93 + PackedExtension<B64>
94 + PackedExtension<B128>,
95 {
96 let mut index = MultilinearExtensionIndex::new();
97 let mut first_oracle_id_in_table = 0;
98 for table_witness in self.tables {
99 let Some(table_witness) = table_witness else {
100 continue;
101 };
102 let table = table_witness.table();
103 let cols = immutable_witness_index_columns(table_witness.cols);
104
105 let mut count = 0;
107 for (oracle_id_offset, col) in cols.into_iter().enumerate() {
108 let oracle_id = first_oracle_id_in_table + oracle_id_offset;
109 let log_capacity = if col.is_single_row {
110 0
111 } else {
112 table_witness.log_capacity
113 };
114 let n_vars = log_capacity + col.shape.log_values_per_row;
115 let underlier_count = 1
116 << (n_vars + col.shape.tower_height)
117 .saturating_sub(P::LOG_WIDTH + F::TOWER_LEVEL);
118 let witness = multilin_poly_from_underlier_data(
119 &col.data[..underlier_count],
120 n_vars,
121 col.shape.tower_height,
122 );
123 index.update_multilin_poly([(oracle_id, witness)]).unwrap();
124 count += 1;
125 }
126
127 for log_values_per_row in table.partitions.keys() {
130 let oracle_id = first_oracle_id_in_table + count;
131 let size = table_witness.size << log_values_per_row;
132 let log_size = table_witness.log_capacity + log_values_per_row;
133 let witness = StepDown::new(log_size, size)
134 .unwrap()
135 .multilinear_extension::<PackedSubfield<P, B1>>()
136 .unwrap()
137 .specialize_arc_dyn();
138 index.update_multilin_poly([(oracle_id, witness)]).unwrap();
139 count += 1;
140 }
141
142 first_oracle_id_in_table += count;
143 }
144 index
145 }
146}
147
148fn multilin_poly_from_underlier_data<P>(
149 data: &[P],
150 n_vars: usize,
151 tower_height: usize,
152) -> Arc<dyn MultilinearPoly<P> + Send + Sync + '_>
153where
154 P: PackedExtension<B1>
155 + PackedExtension<B8>
156 + PackedExtension<B16>
157 + PackedExtension<B32>
158 + PackedExtension<B64>
159 + PackedExtension<B128>,
160{
161 match tower_height {
162 0 => MultilinearExtension::new(n_vars, PackedExtension::<B1>::cast_bases(data))
163 .unwrap()
164 .specialize_arc_dyn(),
165 3 => MultilinearExtension::new(n_vars, PackedExtension::<B8>::cast_bases(data))
166 .unwrap()
167 .specialize_arc_dyn(),
168 4 => MultilinearExtension::new(n_vars, PackedExtension::<B16>::cast_bases(data))
169 .unwrap()
170 .specialize_arc_dyn(),
171 5 => MultilinearExtension::new(n_vars, PackedExtension::<B32>::cast_bases(data))
172 .unwrap()
173 .specialize_arc_dyn(),
174 6 => MultilinearExtension::new(n_vars, PackedExtension::<B64>::cast_bases(data))
175 .unwrap()
176 .specialize_arc_dyn(),
177 7 => MultilinearExtension::new(n_vars, PackedExtension::<B128>::cast_bases(data))
178 .unwrap()
179 .specialize_arc_dyn(),
180 _ => {
181 panic!("Unsupported tower height: {tower_height}");
182 }
183 }
184}
185
/// Witness storage for a single table.
///
/// Storage is allocated for `2^log_capacity` rows, of which the first `size` correspond to
/// events.
#[derive(Debug, CopyGetters)]
pub struct TableWitnessIndex<'cs, 'alloc, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	/// The table this witness belongs to.
	#[get_copy = "pub"]
	table: &'cs Table<P::Scalar>,
	// Number of single-row backing columns (one per constant column) stored ahead of the
	// table's own columns in `cols`; table column `i` lives at `cols[oracle_offset + i]`.
	oracle_offset: usize,
	// Backing columns first, then one entry per table column; entries may alias others.
	cols: Vec<WitnessIndexColumn<'alloc, P>>,
	/// Number of events (real rows) in the table.
	#[get_copy = "pub"]
	size: usize,
	/// Base-2 logarithm of the number of allocated rows.
	#[get_copy = "pub"]
	log_capacity: usize,
	/// Smallest log segment size at which each column's segment covers at least one whole
	/// packed element `P`, capped at `log_capacity`.
	#[get_copy = "pub"]
	min_log_segment_size: usize,
}
209
/// One column of a [`TableWitnessIndex`]: its shape and its (possibly aliased) mutable storage.
#[derive(Debug)]
pub struct WitnessIndexColumn<'a, P: PackedField> {
	// Shape of the values stored in the column.
	shape: ColumnShape,
	// Owned packed storage, or the position of another column whose storage this one shares.
	data: WitnessDataMut<'a, P>,
	// True for the single-row backing of constant columns; such columns are exported with a
	// log capacity of 0.
	is_single_row: bool,
}
216
/// Storage descriptor for a witness column.
#[derive(Debug, Clone)]
enum WitnessColumnInfo<T> {
	/// The column has its own storage `T`.
	Owned(T),
	/// The column shares the storage of the column at this position in the witness column list.
	/// Lookups follow these links recursively until an `Owned` entry is reached.
	SameAsOracleIndex(usize),
}
222
223type WitnessDataMut<'a, P> = WitnessColumnInfo<&'a mut [P]>;
224
225impl<'a, P: PackedField> WitnessDataMut<'a, P> {
226 pub fn new_owned(allocator: &'a bumpalo::Bump, log_underlier_count: usize) -> Self {
227 Self::Owned(allocator.alloc_slice_fill_default(1 << log_underlier_count))
228 }
229}
230
/// Column storage within a [`TableWitnessSegment`].
///
/// Owned columns sit behind a [`RefCell`] so that the segment can hand out dynamically checked
/// shared or mutable borrows from `&self`.
type RefCellData<'a, P> = WitnessColumnInfo<RefCell<&'a mut [P]>>;

/// Read-only view of a witness column, with any alias already resolved to its backing storage.
#[derive(Debug)]
pub struct ImmutableWitnessIndexColumn<'a, P: PackedField> {
	/// Shape of the values stored in the column.
	pub shape: ColumnShape,
	/// Packed column storage (for aliased columns, the storage of the aliased column).
	pub data: &'a [P],
	/// Whether the column is a single-row column (exported with a log capacity of 0).
	pub is_single_row: bool,
}
239
240fn immutable_witness_index_columns<P: PackedField>(
243 cols: Vec<WitnessIndexColumn<P>>,
244) -> Vec<ImmutableWitnessIndexColumn<P>> {
245 let mut result = Vec::<ImmutableWitnessIndexColumn<_>>::with_capacity(cols.len());
246 for col in cols {
247 result.push(ImmutableWitnessIndexColumn {
248 shape: col.shape,
249 data: match col.data {
250 WitnessDataMut::Owned(data) => data,
251 WitnessDataMut::SameAsOracleIndex(index) => result[index].data,
252 },
253 is_single_row: col.is_single_row,
254 });
255 }
256 result
257}
258
259impl<'cs, 'alloc, F: TowerField, P: PackedField<Scalar = F>> TableWitnessIndex<'cs, 'alloc, P> {
260 pub fn new(
261 allocator: &'alloc bumpalo::Bump,
262 table: &'cs Table<F>,
263 size: usize,
264 ) -> Result<Self, Error> {
265 if size == 0 {
266 return Err(Error::EmptyTable {
267 table_id: table.id(),
268 });
269 }
270
271 let log_capacity = table.log_capacity(size);
272 let packed_elem_log_bits = P::LOG_WIDTH + F::TOWER_LEVEL;
273
274 let mut cols = Vec::new();
275 let mut oracle_offset = 0;
276 let mut transparent_single_backing = vec![None; table.columns.len()];
277
278 for col in &table.columns {
279 if matches!(col.col, ColumnDef::Constant { .. }) {
280 transparent_single_backing[col.id.table_index] = Some(oracle_offset);
281 cols.push(WitnessIndexColumn {
282 shape: col.shape,
283 data: WitnessDataMut::new_owned(
284 allocator,
285 (col.shape.log_cell_size() + log_capacity)
286 .saturating_sub(packed_elem_log_bits),
287 ),
288 is_single_row: true,
289 });
290 oracle_offset += 1;
291 }
292 }
293
294 cols.extend(table.columns.iter().map(|col| WitnessIndexColumn {
295 shape: col.shape,
296 data: match col.col {
297 ColumnDef::Packed { col: inner_col, .. } => {
298 WitnessDataMut::SameAsOracleIndex(oracle_offset + inner_col.table_index)
299 }
300 ColumnDef::Constant { .. } => WitnessDataMut::SameAsOracleIndex(
301 transparent_single_backing[col.id.table_index].unwrap(),
302 ),
303 _ => WitnessDataMut::new_owned(
304 allocator,
305 (col.shape.log_cell_size() + log_capacity).saturating_sub(packed_elem_log_bits),
306 ),
307 },
308 is_single_row: false,
309 }));
310
311 let min_log_segment_size = packed_elem_log_bits
314 - table
315 .columns
316 .iter()
317 .map(|col| col.shape.log_cell_size())
318 .fold(packed_elem_log_bits, |a, b| a.min(b));
319
320 let min_log_segment_size = min_log_segment_size.min(log_capacity);
324
325 Ok(Self {
326 table,
327 cols,
328 size,
329 log_capacity,
330 min_log_segment_size,
331 oracle_offset,
332 })
333 }
334
335 pub fn table_id(&self) -> TableId {
336 self.table.id
337 }
338
339 pub fn capacity(&self) -> usize {
340 1 << self.log_capacity
341 }
342
343 pub fn full_segment(&mut self) -> TableWitnessSegment<P> {
345 let cols = self
346 .cols
347 .iter_mut()
348 .map(|col| match &mut col.data {
349 WitnessDataMut::SameAsOracleIndex(index) => RefCellData::SameAsOracleIndex(*index),
350 WitnessDataMut::Owned(data) => RefCellData::Owned(RefCell::new(data)),
351 })
352 .collect();
353 TableWitnessSegment {
354 table: self.table,
355 cols,
356 log_size: self.log_capacity,
357 oracle_offset: self.oracle_offset,
358 }
359 }
360
361 pub fn fill_sequential<T: TableFiller<P>>(
365 &mut self,
366 table: &T,
367 rows: &[T::Event],
368 ) -> Result<(), Error> {
369 let log_size = self.optimal_segment_size_heuristic();
370 self.fill_sequential_with_segment_size(table, rows, log_size)
371 }
372
373 pub fn fill_parallel<T>(&mut self, table: &T, rows: &[T::Event]) -> Result<(), Error>
377 where
378 T: TableFiller<P> + Sync,
379 T::Event: Sync,
380 {
381 let log_size = self.optimal_segment_size_heuristic();
382 self.fill_parallel_with_segment_size(table, rows, log_size)
383 }
384
385 fn optimal_segment_size_heuristic(&self) -> usize {
386 const TARGET_SEGMENT_LOG_BITS: usize = 12 + 3;
388
389 let n_cols = self.table.columns.len();
390 let median_col_log_bits = self
391 .table
392 .columns
393 .iter()
394 .map(|col| col.shape.log_cell_size())
395 .sorted()
396 .nth(n_cols / 2)
397 .unwrap_or_default();
398
399 TARGET_SEGMENT_LOG_BITS.saturating_sub(median_col_log_bits)
400 }
401
402 pub fn fill_sequential_with_segment_size<T: TableFiller<P>>(
406 &mut self,
407 table: &T,
408 rows: &[T::Event],
409 log_size: usize,
410 ) -> Result<(), Error> {
411 if rows.len() != self.size {
412 return Err(Error::IncorrectNumberOfTableEvents {
413 expected: self.size,
414 actual: rows.len(),
415 });
416 }
417
418 let mut segmented_view = TableWitnessSegmentedView::new(self, log_size);
419
420 let log_size = segmented_view.log_segment_size;
422 let segment_size = 1 << log_size;
423
424 debug_assert_ne!(rows.len(), 0);
426 let n_chunks = (rows.len() - 1) / segment_size + 1;
428
429 let (full_chunk_segments, mut rest_segments) = segmented_view.split_at(n_chunks - 1);
430
431 full_chunk_segments
433 .into_iter()
434 .zip(rows.chunks(segment_size).take(n_chunks - 1))
436 .try_for_each(|(mut witness_segment, row_chunk)| {
437 table
438 .fill(row_chunk.iter(), &mut witness_segment)
439 .map_err(Error::TableFill)
440 })?;
441
442 let row_chunk = &rows[(n_chunks - 1) * segment_size..];
447 let mut padded_row_chunk = Vec::new();
448 let row_chunk = if row_chunk.len() != segment_size {
449 padded_row_chunk.reserve(segment_size);
450 padded_row_chunk.extend_from_slice(row_chunk);
451 let last_event = row_chunk
452 .last()
453 .expect("row_chunk must be non-empty because of how n_chunk is calculated")
454 .clone();
455 padded_row_chunk.resize(segment_size, last_event);
456 &padded_row_chunk
457 } else {
458 row_chunk
459 };
460
461 let (partial_chunk_segments, rest_segments) = rest_segments.split_at(1);
462 let mut partial_chunk_segment_iter = partial_chunk_segments.into_iter();
463 let mut witness_segment = partial_chunk_segment_iter.next().expect(
464 "segmented_view.split_at called with 1 must return a view with exactly one segment",
465 );
466 table
467 .fill(row_chunk.iter(), &mut witness_segment)
468 .map_err(Error::TableFill)?;
469 assert!(partial_chunk_segment_iter.next().is_none());
470
471 let last_segment_cols = witness_segment
474 .cols
475 .iter_mut()
476 .map(|col| match col {
477 RefCellData::Owned(data) => WitnessColumnInfo::Owned(data.get_mut()),
478 RefCellData::SameAsOracleIndex(idx) => WitnessColumnInfo::SameAsOracleIndex(*idx),
479 })
480 .collect::<Vec<_>>();
481
482 rest_segments.into_iter().for_each(|mut segment| {
483 for (dst_col, src_col) in iter::zip(&mut segment.cols, &last_segment_cols) {
484 if let (RefCellData::Owned(dst), WitnessColumnInfo::Owned(src)) = (dst_col, src_col)
485 {
486 dst.get_mut().copy_from_slice(src)
487 }
488 }
489 });
490
491 Ok(())
492 }
493
494 pub fn fill_parallel_with_segment_size<T>(
498 &mut self,
499 table: &T,
500 rows: &[T::Event],
501 log_size: usize,
502 ) -> Result<(), Error>
503 where
504 T: TableFiller<P> + Sync,
505 T::Event: Sync,
506 {
507 if rows.len() != self.size {
508 return Err(Error::IncorrectNumberOfTableEvents {
509 expected: self.size,
510 actual: rows.len(),
511 });
512 }
513
514 let mut segmented_view = TableWitnessSegmentedView::new(self, log_size);
519
520 let log_size = segmented_view.log_segment_size;
522 let segment_size = 1 << log_size;
523
524 debug_assert_ne!(rows.len(), 0);
526 let n_chunks = (rows.len() - 1) / segment_size + 1;
528
529 let (full_chunk_segments, mut rest_segments) = segmented_view.split_at(n_chunks - 1);
530
531 full_chunk_segments
533 .into_par_iter()
534 .zip(rows.par_chunks(segment_size).take(n_chunks - 1))
536 .try_for_each(|(mut witness_segment, row_chunk)| {
537 table
538 .fill(row_chunk.iter(), &mut witness_segment)
539 .map_err(Error::TableFill)
540 })?;
541
542 let row_chunk = &rows[(n_chunks - 1) * segment_size..];
547 let mut padded_row_chunk = Vec::new();
548 let row_chunk = if row_chunk.len() != segment_size {
549 padded_row_chunk.reserve(segment_size);
550 padded_row_chunk.extend_from_slice(row_chunk);
551 let last_event = row_chunk
552 .last()
553 .expect("row_chunk must be non-empty because of how n_chunk is calculated")
554 .clone();
555 padded_row_chunk.resize(segment_size, last_event);
556 &padded_row_chunk
557 } else {
558 row_chunk
559 };
560
561 let (partial_chunk_segments, rest_segments) = rest_segments.split_at(1);
562 let mut partial_chunk_segment_iter = partial_chunk_segments.into_iter();
563 let mut witness_segment = partial_chunk_segment_iter.next().expect(
564 "segmented_view.split_at called with 1 must return a view with exactly one segment",
565 );
566 table
567 .fill(row_chunk.iter(), &mut witness_segment)
568 .map_err(Error::TableFill)?;
569 assert!(partial_chunk_segment_iter.next().is_none());
570
571 let last_segment_cols = witness_segment
574 .cols
575 .iter_mut()
576 .map(|col| match col {
577 RefCellData::Owned(data) => WitnessColumnInfo::Owned(data.get_mut()),
578 RefCellData::SameAsOracleIndex(idx) => WitnessColumnInfo::SameAsOracleIndex(*idx),
579 })
580 .collect::<Vec<_>>();
581
582 rest_segments.into_par_iter().for_each(|mut segment| {
583 for (dst_col, src_col) in iter::zip(&mut segment.cols, &last_segment_cols) {
584 if let (RefCellData::Owned(dst), WitnessColumnInfo::Owned(src)) = (dst_col, src_col)
585 {
586 dst.get_mut().copy_from_slice(src)
587 }
588 }
589 });
590
591 Ok(())
592 }
593
594 pub fn segments(&mut self, log_size: usize) -> impl Iterator<Item = TableWitnessSegment<P>> {
600 TableWitnessSegmentedView::new(self, log_size).into_iter()
601 }
602
603 pub fn par_segments(
604 &mut self,
605 log_size: usize,
606 ) -> impl IndexedParallelIterator<Item = TableWitnessSegment<'_, P>> {
607 TableWitnessSegmentedView::new(self, log_size).into_par_iter()
608 }
609}
610
/// Mutable view of (part of) a table witness, split into `n_segments` segments of
/// `2^log_segment_size` rows each.
#[derive(Debug)]
struct TableWitnessSegmentedView<'a, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	table: &'a Table<P::Scalar>,
	// Copied from the witness: number of backing columns ahead of the table's own columns.
	oracle_offset: usize,
	// Per column: the viewed storage plus the number of packed elements per segment, or an alias
	// to another column position.
	cols: Vec<WitnessColumnInfo<(&'a mut [P], usize)>>,
	log_segment_size: usize,
	n_segments: usize,
}
628
629impl<'a, F: TowerField, P: PackedField<Scalar = F>> TableWitnessSegmentedView<'a, P> {
630 fn new(witness: &'a mut TableWitnessIndex<P>, log_segment_size: usize) -> Self {
631 let log_segment_size = log_segment_size
633 .min(witness.log_capacity)
634 .max(witness.min_log_segment_size);
635
636 let cols = witness
637 .cols
638 .iter_mut()
639 .map(|col| match &mut col.data {
640 WitnessColumnInfo::Owned(data) => {
641 let chunk_size = (log_segment_size + col.shape.log_cell_size())
642 .saturating_sub(P::LOG_WIDTH + F::TOWER_LEVEL);
643 WitnessColumnInfo::Owned((&mut **data, 1 << chunk_size))
644 }
645 WitnessColumnInfo::SameAsOracleIndex(idx) => {
646 WitnessColumnInfo::SameAsOracleIndex(*idx)
647 }
648 })
649 .collect::<Vec<_>>();
650 Self {
651 table: witness.table,
652 oracle_offset: witness.oracle_offset,
653 cols,
654 log_segment_size,
655 n_segments: 1 << (witness.log_capacity - log_segment_size),
656 }
657 }
658
659 fn split_at(
660 &mut self,
661 index: usize,
662 ) -> (TableWitnessSegmentedView<P>, TableWitnessSegmentedView<P>) {
663 assert!(index <= self.n_segments);
664 let (cols_0, cols_1) = self
665 .cols
666 .iter_mut()
667 .map(|col| match col {
668 WitnessColumnInfo::Owned((data, chunk_size)) => {
669 let (data_0, data_1) = data.split_at_mut(*chunk_size * index);
670 (
671 WitnessColumnInfo::Owned((data_0, *chunk_size)),
672 WitnessColumnInfo::Owned((data_1, *chunk_size)),
673 )
674 }
675 WitnessColumnInfo::SameAsOracleIndex(idx) => (
676 WitnessColumnInfo::SameAsOracleIndex(*idx),
677 WitnessColumnInfo::SameAsOracleIndex(*idx),
678 ),
679 })
680 .unzip();
681 (
682 TableWitnessSegmentedView {
683 table: self.table,
684 oracle_offset: self.oracle_offset,
685 cols: cols_0,
686 log_segment_size: self.log_segment_size,
687 n_segments: index,
688 },
689 TableWitnessSegmentedView {
690 table: self.table,
691 oracle_offset: self.oracle_offset,
692 cols: cols_1,
693 log_segment_size: self.log_segment_size,
694 n_segments: self.n_segments - index,
695 },
696 )
697 }
698
699 fn into_iter(self) -> impl Iterator<Item = TableWitnessSegment<'a, P>> {
700 let TableWitnessSegmentedView {
701 table,
702 oracle_offset,
703 cols,
704 log_segment_size,
705 n_segments,
706 } = self;
707
708 if n_segments == 0 {
709 itertools::Either::Left(iter::empty())
710 } else {
711 let iter = MultiIterator::new(
712 cols.into_iter()
713 .map(|col| match col {
714 WitnessColumnInfo::Owned((data, chunk_size)) => itertools::Either::Left(
715 data.chunks_mut(chunk_size)
716 .map(|chunk| RefCellData::Owned(RefCell::new(chunk))),
717 ),
718 WitnessColumnInfo::SameAsOracleIndex(index) => itertools::Either::Right(
719 iter::repeat_n(index, n_segments).map(RefCellData::SameAsOracleIndex),
720 ),
721 })
722 .collect(),
723 )
724 .map(move |cols| TableWitnessSegment {
725 table,
726 cols,
727 log_size: log_segment_size,
728 oracle_offset,
729 });
730 itertools::Either::Right(iter)
731 }
732 }
733
734 fn into_par_iter(self) -> impl IndexedParallelIterator<Item = TableWitnessSegment<'a, P>> {
735 let TableWitnessSegmentedView {
736 table,
737 oracle_offset,
738 cols,
739 log_segment_size,
740 n_segments: _,
741 } = self;
742
743 #[allow(clippy::mut_from_ref)]
749 unsafe fn cast_slice_ref_to_mut<T>(slice: &[T]) -> &mut [T] {
750 slice::from_raw_parts_mut(slice.as_ptr() as *mut T, slice.len())
751 }
752
753 let cols = cols
756 .into_iter()
757 .map(|col| -> WitnessColumnInfo<(&'a [P], usize)> {
758 match col {
759 WitnessColumnInfo::Owned((data, chunk_size)) => {
760 WitnessColumnInfo::Owned((data, chunk_size))
761 }
762 WitnessColumnInfo::SameAsOracleIndex(index) => {
763 WitnessColumnInfo::SameAsOracleIndex(index)
764 }
765 }
766 })
767 .collect::<Vec<_>>();
768
769 (0..self.n_segments).into_par_iter().map(move |i| {
770 let col_strides = cols
771 .iter()
772 .map(|col| match col {
773 WitnessColumnInfo::SameAsOracleIndex(index) => {
774 RefCellData::SameAsOracleIndex(*index)
775 }
776 WitnessColumnInfo::Owned((data, chunk_size)) => {
777 RefCellData::Owned(RefCell::new(unsafe {
778 cast_slice_ref_to_mut(&data[i * chunk_size..(i + 1) * chunk_size])
783 }))
784 }
785 })
786 .collect();
787 TableWitnessSegment {
788 table,
789 cols: col_strides,
790 log_size: log_segment_size,
791 oracle_offset,
792 }
793 })
794 }
795}
796
/// A run of `2^log_size` consecutive rows of a table witness.
///
/// Column data is reached through [`RefCell`]s, so individual columns can be borrowed shared or
/// mutably from `&self`, with conflicts reported as errors at runtime.
#[derive(Debug, CopyGetters)]
pub struct TableWitnessSegment<'a, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	table: &'a Table<P::Scalar>,
	// Same layout as the witness column list: backing columns first, then table columns.
	cols: Vec<RefCellData<'a, P>>,
	/// Base-2 logarithm of the number of rows in this segment.
	#[get_copy = "pub"]
	log_size: usize,
	// Number of backing columns ahead of the table's own columns in `cols`.
	oracle_offset: usize,
}
813
814impl<'a, F: TowerField, P: PackedField<Scalar = F>> TableWitnessSegment<'a, P> {
815 pub fn get<FSub: TowerField, const V: usize>(
816 &self,
817 col: Col<FSub, V>,
818 ) -> Result<Ref<[PackedSubfield<P, FSub>]>, Error>
819 where
820 P: PackedExtension<FSub>,
821 {
822 if col.table_id != self.table.id() {
823 return Err(Error::TableMismatch {
824 column_table_id: col.table_id,
825 witness_table_id: self.table.id(),
826 });
827 }
828
829 let col = self
830 .get_col_data(col.table_index)
831 .ok_or_else(|| Error::MissingColumn(col.id()))?;
832 let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
833 Ok(Ref::map(col_ref, |packed| PackedExtension::cast_bases(packed)))
834 }
835
836 pub fn get_mut<FSub: TowerField, const V: usize>(
837 &self,
838 col: Col<FSub, V>,
839 ) -> Result<RefMut<[PackedSubfield<P, FSub>]>, Error>
840 where
841 P: PackedExtension<FSub>,
842 F: ExtensionField<FSub>,
843 {
844 if col.table_id != self.table.id() {
845 return Err(Error::TableMismatch {
846 column_table_id: col.table_id,
847 witness_table_id: self.table.id(),
848 });
849 }
850
851 let col = self
852 .get_col_data(col.table_index)
853 .ok_or_else(|| Error::MissingColumn(col.id()))?;
854 let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
855 Ok(RefMut::map(col_ref, |packed| PackedExtension::cast_bases_mut(packed)))
856 }
857
858 pub fn get_scalars<FSub: TowerField, const V: usize>(
859 &self,
860 col: Col<FSub, V>,
861 ) -> Result<Ref<[FSub]>, Error>
862 where
863 P: PackedExtension<FSub>,
864 F: ExtensionField<FSub>,
865 PackedSubfield<P, FSub>: PackedFieldIndexable,
866 {
867 self.get(col)
868 .map(|packed| Ref::map(packed, <PackedSubfield<P, FSub>>::unpack_scalars))
869 }
870
871 pub fn get_scalars_mut<FSub: TowerField, const V: usize>(
872 &self,
873 col: Col<FSub, V>,
874 ) -> Result<RefMut<[FSub]>, Error>
875 where
876 P: PackedExtension<FSub>,
877 F: ExtensionField<FSub>,
878 PackedSubfield<P, FSub>: PackedFieldIndexable,
879 {
880 self.get_mut(col)
881 .map(|packed| RefMut::map(packed, <PackedSubfield<P, FSub>>::unpack_scalars_mut))
882 }
883
884 pub fn get_as<T: Pod, FSub: TowerField, const V: usize>(
885 &self,
886 col: Col<FSub, V>,
887 ) -> Result<Ref<[T]>, Error>
888 where
889 P: PackedExtension<FSub> + PackedFieldIndexable,
890 F: ExtensionField<FSub> + Pod,
891 {
892 let col = self
893 .get_col_data(col.table_index)
894 .ok_or_else(|| Error::MissingColumn(col.id()))?;
895 let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
896 Ok(Ref::map(col_ref, |col| must_cast_slice(P::unpack_scalars(col))))
897 }
898
899 pub fn get_mut_as<T: Pod, FSub: TowerField, const V: usize>(
900 &self,
901 col: Col<FSub, V>,
902 ) -> Result<RefMut<[T]>, Error>
903 where
904 P: PackedExtension<FSub> + PackedFieldIndexable,
905 F: ExtensionField<FSub> + Pod,
906 {
907 if col.table_id != self.table.id() {
908 return Err(Error::TableMismatch {
909 column_table_id: col.table_id,
910 witness_table_id: self.table.id(),
911 });
912 }
913
914 let col = self
915 .get_col_data(col.table_index)
916 .ok_or_else(|| Error::MissingColumn(col.id()))?;
917 let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
918 Ok(RefMut::map(col_ref, |col| must_cast_slice_mut(P::unpack_scalars_mut(col))))
919 }
920
921 pub fn eval_expr<FSub: TowerField, const V: usize>(
928 &self,
929 expr: &Expr<FSub, V>,
930 ) -> Result<impl Iterator<Item = PackedSubfield<P, FSub>>, Error>
931 where
932 P: PackedExtension<FSub>,
933 {
934 let log_vals_per_row = checked_log_2(V);
935
936 let partition =
937 self.table
938 .partitions
939 .get(log_vals_per_row)
940 .ok_or_else(|| Error::MissingPartition {
941 table_id: self.table.id(),
942 log_vals_per_row,
943 })?;
944 let col_refs = partition
945 .columns
946 .iter()
947 .zip(expr.expr().vars_usage())
948 .map(|(col_index, used)| {
949 used.then(|| {
950 self.get(Col::<FSub, V>::new(
951 ColumnId {
952 table_id: self.table.id(),
953 table_index: partition.columns[*col_index],
954 },
955 *col_index,
956 ))
957 })
958 .transpose()
959 })
960 .collect::<Result<Vec<_>, _>>()?;
961
962 let log_packed_elems =
963 (self.log_size + log_vals_per_row).saturating_sub(<PackedSubfield<P, FSub>>::LOG_WIDTH);
964
965 let dummy_col = zeroed_vec(1 << log_packed_elems);
967
968 let cols = col_refs
969 .iter()
970 .map(|col| col.as_ref().map(|col_ref| &**col_ref).unwrap_or(&dummy_col))
971 .collect::<Vec<_>>();
972 let cols = RowsBatchRef::new(&cols, 1 << log_packed_elems);
973
974 let mut evals = zeroed_vec(1 << log_packed_elems);
979 ArithCircuitPoly::new(expr.expr().clone()).batch_evaluate(&cols, &mut evals)?;
980 Ok(evals.into_iter())
981 }
982
983 pub fn size(&self) -> usize {
984 1 << self.log_size
985 }
986
987 fn get_col_data(&self, table_index: ColumnIndex) -> Option<&RefCell<&'a mut [P]>> {
988 self.get_col_data_by_oracle_offset(self.oracle_offset + table_index)
989 }
990
991 fn get_col_data_by_oracle_offset(&self, oracle_id: OracleId) -> Option<&RefCell<&'a mut [P]>> {
992 match self.cols.get(oracle_id) {
993 Some(RefCellData::Owned(data)) => Some(data),
994 Some(RefCellData::SameAsOracleIndex(id)) => self.get_col_data_by_oracle_offset(*id),
995 None => None,
996 }
997 }
998}
999
/// Populates the witness of one table from a sequence of events.
pub trait TableFiller<P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	/// The event type; each event fills one table row. `Clone` is required because the final
	/// partial segment is padded by repeating its last event.
	type Event: Clone;

	/// Returns the ID of the table this filler populates.
	fn id(&self) -> TableId;

	/// Fills `witness` from `rows`.
	///
	/// The `fill_*_with_segment_size` methods always pass exactly one event per segment row,
	/// padding the last chunk when needed.
	fn fill<'a>(
		&'a self,
		rows: impl Iterator<Item = &'a Self::Event> + Clone,
		witness: &'a mut TableWitnessSegment<P>,
	) -> anyhow::Result<()>;
}
1023
1024#[cfg(test)]
1025mod tests {
1026 use std::iter::repeat_with;
1027
1028 use assert_matches::assert_matches;
1029 use binius_field::{
1030 arch::{OptimalUnderlier128b, OptimalUnderlier256b},
1031 packed::{len_packed_slice, set_packed_slice},
1032 };
1033 use rand::{rngs::StdRng, Rng, SeedableRng};
1034
1035 use super::*;
1036 use crate::builder::{
1037 types::{B1, B32, B8},
1038 ConstraintSystem, Statement, TableBuilder,
1039 };
1040
1041 #[test]
1042 fn test_table_witness_borrows() {
1043 let table_id = 0;
1044 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1045 let mut table = TableBuilder::new(&mut inner_table);
1046 let col0 = table.add_committed::<B1, 8>("col0");
1047 let col1 = table.add_committed::<B1, 32>("col1");
1048 let col2 = table.add_committed::<B8, 1>("col2");
1049 let col3 = table.add_committed::<B32, 1>("col3");
1050
1051 let allocator = bumpalo::Bump::new();
1052 let table_size = 64;
1053 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1054 &allocator,
1055 &inner_table,
1056 table_size,
1057 )
1058 .unwrap();
1059 let segment = index.full_segment();
1060
1061 {
1062 let col0_ref0 = segment.get(col0).unwrap();
1063 let _col0_ref1 = segment.get(col0).unwrap();
1064 assert_matches!(segment.get_mut(col0), Err(Error::WitnessBorrowMut(_)));
1065 drop(col0_ref0);
1066
1067 let col1_ref = segment.get_mut(col1).unwrap();
1068 assert_matches!(segment.get(col1), Err(Error::WitnessBorrow(_)));
1069 drop(col1_ref);
1070 }
1071
1072 assert_eq!(len_packed_slice(&segment.get_mut(col0).unwrap()), 1 << 9);
1073 assert_eq!(len_packed_slice(&segment.get_mut(col1).unwrap()), 1 << 11);
1074 assert_eq!(len_packed_slice(&segment.get_mut(col2).unwrap()), 1 << 6);
1075 assert_eq!(len_packed_slice(&segment.get_mut(col3).unwrap()), 1 << 6);
1076 }
1077
1078 #[test]
1079 fn test_table_witness_segments() {
1080 let table_id = 0;
1081 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1082 let mut table = TableBuilder::new(&mut inner_table);
1083 let col0 = table.add_committed::<B1, 8>("col0");
1084 let col1 = table.add_committed::<B1, 32>("col1");
1085 let col2 = table.add_committed::<B8, 1>("col2");
1086 let col3 = table.add_committed::<B32, 1>("col3");
1087
1088 let allocator = bumpalo::Bump::new();
1089 let table_size = 64;
1090 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1091 &allocator,
1092 &inner_table,
1093 table_size,
1094 )
1095 .unwrap();
1096
1097 assert_eq!(index.min_log_segment_size(), 4);
1098 let mut iter = index.segments(5);
1099 let seg0 = iter.next().unwrap();
1100 let seg1 = iter.next().unwrap();
1101 assert!(iter.next().is_none());
1102
1103 assert_eq!(len_packed_slice(&seg0.get_mut(col0).unwrap()), 1 << 8);
1104 assert_eq!(len_packed_slice(&seg0.get_mut(col1).unwrap()), 1 << 10);
1105 assert_eq!(len_packed_slice(&seg0.get_mut(col2).unwrap()), 1 << 5);
1106 assert_eq!(len_packed_slice(&seg0.get_mut(col3).unwrap()), 1 << 5);
1107
1108 assert_eq!(len_packed_slice(&seg1.get_mut(col0).unwrap()), 1 << 8);
1109 assert_eq!(len_packed_slice(&seg1.get_mut(col1).unwrap()), 1 << 10);
1110 assert_eq!(len_packed_slice(&seg1.get_mut(col2).unwrap()), 1 << 5);
1111 assert_eq!(len_packed_slice(&seg1.get_mut(col3).unwrap()), 1 << 5);
1112 }
1113
1114 #[test]
1115 fn test_eval_expr() {
1116 let table_id = 0;
1117 let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
1118 let mut table = TableBuilder::new(&mut inner_table);
1119 let col0 = table.add_committed::<B8, 2>("col0");
1120 let col1 = table.add_committed::<B8, 2>("col0");
1121 let col2 = table.add_committed::<B8, 2>("col0");
1122
1123 let allocator = bumpalo::Bump::new();
1124 let table_size = 1 << 6;
1125 let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
1126 &allocator,
1127 &inner_table,
1128 table_size,
1129 )
1130 .unwrap();
1131
1132 let segment = index.full_segment();
1133 assert_eq!(segment.log_size(), 6);
1134
1135 {
1137 let mut col0 = segment.get_mut(col0).unwrap();
1138 let mut col1 = segment.get_mut(col1).unwrap();
1139 let mut col2 = segment.get_mut(col2).unwrap();
1140
1141 let expected_slice_len = 1 << 3;
1143 assert_eq!(col0.len(), expected_slice_len);
1144 assert_eq!(col0.len(), expected_slice_len);
1145 assert_eq!(col0.len(), expected_slice_len);
1146
1147 for i in 0..expected_slice_len * <PackedType<OptimalUnderlier128b, B8>>::WIDTH {
1148 set_packed_slice(&mut col0, i, B8::new(i as u8) + B8::new(0x00));
1149 set_packed_slice(&mut col1, i, B8::new(i as u8) + B8::new(0x40));
1150 set_packed_slice(&mut col2, i, B8::new(i as u8) + B8::new(0x80));
1151 }
1152 }
1153
1154 let evals = segment.eval_expr(&(col0 * col1 - col2)).unwrap();
1155 for (i, eval_i) in evals
1156 .into_iter()
1157 .flat_map(PackedField::into_iter)
1158 .enumerate()
1159 {
1160 let col0_val = B8::new(i as u8) + B8::new(0x00);
1161 let col1_val = B8::new(i as u8) + B8::new(0x40);
1162 let col2_val = B8::new(i as u8) + B8::new(0x80);
1163 assert_eq!(eval_i, col0_val * col1_val - col2_val);
1164 }
1165 }
1166
#[test]
fn test_small_tables() {
    // Build a table whose columns are all narrow (bit and small-field columns).
    let mut inner_table = Table::<B128>::new(0, "table".to_string());
    let mut builder = TableBuilder::new(&mut inner_table);
    let _ = builder.add_committed::<B1, 8>("col0");
    let _ = builder.add_committed::<B1, 32>("col1");
    let _ = builder.add_committed::<B8, 1>("col2");
    let _ = builder.add_committed::<B32, 1>("col3");

    // Only 7 rows are requested; the assertions below expect the capacity to
    // equal the minimum segment size of 2^4 rows regardless.
    let n_rows = 7;
    let bump = bumpalo::Bump::new();
    let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier256b, B128>>::new(
        &bump,
        &inner_table,
        n_rows,
    )
    .unwrap();

    assert_eq!(index.log_capacity(), 4);
    assert_eq!(index.min_log_segment_size(), 4);

    // Whether the requested segment size is above (2^5) or below (2^2) the
    // minimum, the table is yielded as exactly one 2^4-row segment.
    for requested_log_size in [5, 2] {
        let mut segments = index.segments(requested_log_size);
        assert_eq!(segments.next().map(|seg| seg.log_size()), Some(4));
        assert!(segments.next().is_none());
    }
}
1201
1202 struct TestTable {
1203 id: TableId,
1204 col0: Col<B32>,
1205 col1: Col<B32>,
1206 }
1207
1208 impl TestTable {
1209 fn new(cs: &mut ConstraintSystem) -> Self {
1210 let mut table = cs.add_table("test");
1211
1212 let col0 = table.add_committed("col0");
1213 let col1 = table.add_computed("col1", col0 * col0 + B32::new(0x03));
1214
1215 Self {
1216 id: table.id(),
1217 col0,
1218 col1,
1219 }
1220 }
1221 }
1222
1223 impl TableFiller<PackedType<OptimalUnderlier128b, B128>> for TestTable {
1224 type Event = u32;
1225
1226 fn id(&self) -> TableId {
1227 self.id
1228 }
1229
1230 fn fill<'a>(
1231 &'a self,
1232 rows: impl Iterator<Item = &'a Self::Event> + Clone,
1233 witness: &'a mut TableWitnessSegment<PackedType<OptimalUnderlier128b, B128>>,
1234 ) -> anyhow::Result<()> {
1235 let mut col0 = witness.get_scalars_mut(self.col0)?;
1236 let mut col1 = witness.get_scalars_mut(self.col1)?;
1237 for (i, &val) in rows.enumerate() {
1238 col0[i] = B32::new(val);
1239 col1[i] = col0[i].pow(2) + B32::new(0x03);
1240 }
1241 Ok(())
1242 }
1243 }
1244
/// Fills a table whose event count (11) is not a multiple of the segment size
/// and checks both the event-count validation and the resulting row layout.
#[test]
fn test_fill_sequential_with_incomplete_events() {
    let mut cs = ConstraintSystem::new();
    let test_table = TestTable::new(&mut cs);

    let allocator = bumpalo::Bump::new();

    let table_size = 11;
    let statement = Statement {
        boundaries: vec![],
        table_sizes: vec![table_size],
    };
    let mut index = cs.build_witness(&allocator, &statement).unwrap();
    let table_index = index.get_table(test_table.id()).unwrap();

    let mut rng = StdRng::seed_from_u64(0);
    let rows = repeat_with(|| rng.gen())
        .take(table_size)
        .collect::<Vec<_>>();

    // The witness holds 2^4 = 16 rows and can be split into segments of 2^2 = 4 rows.
    assert_eq!(table_index.log_capacity(), 4);
    assert_eq!(table_index.min_log_segment_size(), 2);

    // Supplying fewer events than the declared table size must be rejected.
    assert_matches!(
        table_index.fill_sequential_with_segment_size(&test_table, &rows[1..], 2),
        Err(Error::IncorrectNumberOfTableEvents { .. })
    );

    table_index
        .fill_sequential_with_segment_size(&test_table, &rows, 2)
        .unwrap();

    let segment = table_index.full_segment();
    let col0 = segment.get_scalars(test_table.col0).unwrap();

    // The first `table_size` rows hold the events in order. Iterating over
    // `rows` (rather than a hard-coded bound) keeps this in sync with `table_size`.
    for (i, &row) in rows.iter().enumerate() {
        assert_eq!(col0[i].val(), row, "mismatch at row {i}");
    }
    // The partial segment [8, 12) is completed by repeating the last event...
    assert_eq!(col0[11].val(), rows[10]);
    // ...and the unused trailing segment [12, 16) holds the same rows as [8, 12).
    assert_eq!(col0[12].val(), rows[8]);
    assert_eq!(col0[13].val(), rows[9]);
    assert_eq!(col0[14].val(), rows[10]);
    assert_eq!(col0[15].val(), rows[10]);
}
1292
/// Parallel counterpart of the sequential incomplete-events test: the same
/// validation and row layout are expected when segments are filled in parallel.
#[test]
fn test_fill_parallel_with_incomplete_events() {
    let mut cs = ConstraintSystem::new();
    let test_table = TestTable::new(&mut cs);

    let allocator = bumpalo::Bump::new();

    let table_size = 11;
    let statement = Statement {
        boundaries: vec![],
        table_sizes: vec![table_size],
    };
    let mut index = cs.build_witness(&allocator, &statement).unwrap();
    let table_index = index.get_table(test_table.id()).unwrap();

    let mut rng = StdRng::seed_from_u64(0);
    let rows = repeat_with(|| rng.gen())
        .take(table_size)
        .collect::<Vec<_>>();

    // The witness holds 2^4 = 16 rows and can be split into segments of 2^2 = 4 rows.
    assert_eq!(table_index.log_capacity(), 4);
    assert_eq!(table_index.min_log_segment_size(), 2);

    // Supplying fewer events than the declared table size must be rejected.
    assert_matches!(
        table_index.fill_parallel_with_segment_size(&test_table, &rows[1..], 2),
        Err(Error::IncorrectNumberOfTableEvents { .. })
    );

    table_index
        .fill_parallel_with_segment_size(&test_table, &rows, 2)
        .unwrap();

    let segment = table_index.full_segment();
    let col0 = segment.get_scalars(test_table.col0).unwrap();

    // The first `table_size` rows hold the events in order. Iterating over
    // `rows` (rather than a hard-coded bound) keeps this in sync with `table_size`.
    for (i, &row) in rows.iter().enumerate() {
        assert_eq!(col0[i].val(), row, "mismatch at row {i}");
    }
    // The partial segment [8, 12) is completed by repeating the last event...
    assert_eq!(col0[11].val(), rows[10]);
    // ...and the unused trailing segment [12, 16) holds the same rows as [8, 12).
    assert_eq!(col0[12].val(), rows[8]);
    assert_eq!(col0[13].val(), rows[9]);
    assert_eq!(col0[14].val(), rows[10]);
    assert_eq!(col0[15].val(), rows[10]);
}
1340}