binius_m3/builder/witness.rs

// Copyright 2025 Irreducible Inc.

use std::{
	cell::{Ref, RefCell, RefMut},
	iter,
	ops::{Deref, DerefMut},
	slice,
	sync::Arc,
};

use binius_compute::alloc::{ComputeAllocator, HostBumpAllocator};
use binius_core::witness::{MultilinearExtensionIndex, MultilinearWitness};
use binius_fast_compute::arith_circuit::ArithCircuitPoly;
use binius_field::{
	ExtensionField, PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
	arch::OptimalUnderlier,
	as_packed_field::PackedType,
	packed::{get_packed_slice, set_packed_slice},
};
use binius_math::{
	ArithCircuit, CompositionPoly, MultilinearExtension, MultilinearPoly, RowsBatchRef,
};
use binius_maybe_rayon::prelude::*;
use binius_utils::checked_arithmetics::checked_log_2;
use bytemuck::{Pod, must_cast_slice, must_cast_slice_mut, zeroed_vec};
use either::Either;
use getset::CopyGetters;
use itertools::Itertools;

use super::{
	ColumnDef, ColumnId, ColumnInfo, ConstraintSystem, Expr,
	column::{Col, ColumnShape},
	constraint_system::OracleMapping,
	error::Error,
	table::{self, Table, TableId},
	types::{B1, B8, B16, B32, B64, B128},
};
use crate::builder::multi_iter::MultiIterator;

/// Holds witness column data for all tables in a constraint system, indexed by table ID.
///
/// The struct has two lifetimes: `'cs` is the lifetime of the constraint system, and `'alloc` is
/// the lifetime of the bump allocator. The reason these must be separate is that the witness index
/// gets converted into a multilinear extension index, which maintains references to the data
/// allocated by the allocator, but does not need to maintain a reference to the constraint system,
/// which can then be dropped.
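///
/// # Example
///
/// A minimal usage sketch (not a compiled doctest); `cs`, `allocator`, `my_filler`, and `events`
/// are assumed to be defined elsewhere:
///
/// ```ignore
/// // Build the index over all tables of the constraint system.
/// let mut witness = WitnessIndex::new(&cs, &allocator);
/// // Fill one table from its event rows using a `TableFiller` implementation.
/// witness.fill_table_sequential(&my_filler, &events)?;
/// // Populate constant columns and hand the data over to the prover
/// // (assumes the constraint system has already been compiled).
/// witness.fill_constant_cols()?;
/// let index = witness.into_multilinear_extension_index();
/// ```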
pub struct WitnessIndex<'cs, 'alloc, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	cs: &'cs ConstraintSystem<P::Scalar>,
	allocator: &'alloc HostBumpAllocator<'alloc, P>,
	/// Each entry is `Left` if the index hasn't been initialized and filled, and `Right` if it
	/// has.
	tables: Vec<Either<&'cs Table<P::Scalar>, TableWitnessIndex<'cs, 'alloc, P>>>,
}

impl<'cs, 'alloc, F: TowerField, P: PackedField<Scalar = F>> WitnessIndex<'cs, 'alloc, P> {
	/// Creates and allocates the witness index for a constraint system.
	pub fn new(
		cs: &'cs ConstraintSystem<F>,
		allocator: &'alloc HostBumpAllocator<'alloc, P>,
	) -> Self {
		Self {
			cs,
			allocator,
			tables: cs.tables.iter().map(Either::Left).collect(),
		}
	}

	pub fn init_table(
		&mut self,
		table_id: TableId,
		size: usize,
	) -> Result<&mut TableWitnessIndex<'cs, 'alloc, P>, Error> {
		match self.tables.get_mut(table_id) {
			Some(entry) => match entry {
				Either::Left(table) => {
					if size == 0 {
						Err(Error::EmptyTable { table_id })
					} else {
						let table_witness = TableWitnessIndex::new(self.allocator, table, size)?;
						*entry = Either::Right(table_witness);
						let Either::Right(table_witness) = entry else {
							unreachable!("entry is assigned to this pattern on the previous line")
						};
						Ok(table_witness)
					}
				}
				Either::Right(_) => Err(Error::TableIndexAlreadyInitialized { table_id }),
			},
			None => Err(Error::MissingTable { table_id }),
		}
	}

	pub fn get_table(
		&mut self,
		table_id: TableId,
	) -> Option<&mut TableWitnessIndex<'cs, 'alloc, P>> {
		self.tables.get_mut(table_id).and_then(|table| match table {
			Either::Left(_) => None,
			Either::Right(index) => Some(index),
		})
	}

	pub fn fill_table_sequential<T: TableFiller<P>>(
		&mut self,
		filler: &T,
		rows: &[T::Event],
	) -> Result<(), Error> {
		self.init_and_fill_table(
			filler.id(),
			|table_witness, rows| table_witness.fill_sequential(filler, rows),
			rows,
		)
	}

	pub fn fill_table_parallel<T>(&mut self, filler: &T, rows: &[T::Event]) -> Result<(), Error>
	where
		T: TableFiller<P> + Sync,
		T::Event: Sync,
	{
		self.init_and_fill_table(
			filler.id(),
			|table_witness, rows| table_witness.fill_parallel(filler, rows),
			rows,
		)
	}

	fn init_and_fill_table<Event>(
		&mut self,
		table_id: TableId,
		fill: impl FnOnce(&mut TableWitnessIndex<'cs, 'alloc, P>, &[Event]) -> Result<(), Error>,
		rows: &[Event],
	) -> Result<(), Error> {
		match self.tables.get_mut(table_id) {
			Some(entry) => match entry {
				Either::Right(witness) => fill(witness, rows),
				Either::Left(table) => {
					if rows.is_empty() {
						Ok(())
					} else {
						let mut table_witness =
							TableWitnessIndex::new(self.allocator, table, rows.len())?;
						fill(&mut table_witness, rows)?;
						*entry = Either::Right(table_witness);
						Ok(())
					}
				}
			},
			None => Err(Error::MissingTable { table_id }),
		}
	}

	/// Returns the sizes of all tables in the witness, indexed by table ID.
	pub fn table_sizes(&self) -> Vec<usize> {
		self.tables
			.iter()
			.map(|entry| match entry {
				Either::Left(_) => 0,
				Either::Right(index) => index.size(),
			})
			.collect()
	}

	fn mk_column_witness<'a>(
		log_capacity: usize,
		shape: ColumnShape,
		data: &'a [P],
	) -> MultilinearWitness<'a, P>
	where
		P: PackedExtension<B1>
			+ PackedExtension<B8>
			+ PackedExtension<B16>
			+ PackedExtension<B32>
			+ PackedExtension<B64>
			+ PackedExtension<B128>,
	{
		let n_vars = log_capacity + shape.log_values_per_row;
		let underlier_count =
			1 << (n_vars + shape.tower_height).saturating_sub(P::LOG_WIDTH + F::TOWER_LEVEL);
		multilin_poly_from_underlier_data(&data[..underlier_count], n_vars, shape.tower_height)
	}

	/// Converts this witness into binius_core's [`MultilinearExtensionIndex`].
	///
	/// Note that this function must be called only after [`ConstraintSystem::compile`].
	pub fn into_multilinear_extension_index(self) -> MultilinearExtensionIndex<'alloc, P>
	where
		P: PackedExtension<B1>
			+ PackedExtension<B8>
			+ PackedExtension<B16>
			+ PackedExtension<B32>
			+ PackedExtension<B64>
			+ PackedExtension<B128>,
	{
		let oracle_lookup = self.cs.oracle_lookup();

		let mut index = MultilinearExtensionIndex::new();

		for table_witness in self.tables {
			let Either::Right(table_witness) = table_witness else {
				continue;
			};
			let cols = immutable_witness_index_columns(table_witness.cols);

			// Here our objective is to add a witness for every oracle the table has created.
			//
			// There are some tricky parts that are worth keeping in mind:
			//
			// 1. Some oracles share witnesses, e.g. a packed column has the same witness as the
			//    column that it packs. Despite that, they cannot share the underlying witness
			//    polynomial because of the difference in n_vars.
			//
			// 2. Similarly, a constant column creates two oracles: the original constant and the
			//    user-visible repeating column. Instead of making the user fill both witnesses,
			//    we fill the original constant oracle with the truncated version of the repeating
			//    column.

			for col in cols.into_iter() {
				let oracle_mapping = *oracle_lookup.lookup(col.column_id);
				match oracle_mapping {
					OracleMapping::Regular(oracle_id) => index
						.update_multilin_poly([(
							oracle_id,
							Self::mk_column_witness(
								table_witness.log_capacity,
								col.shape,
								col.data,
							),
						)])
						.unwrap(),
					OracleMapping::TransparentCompound {
						original,
						repeating,
					} => {
						// Create a single-row poly witness for the original oracle and the
						// repeating version of that for the repeating oracle.
						let original_witness = Self::mk_column_witness(0, col.shape, col.data);
						let repeating_witness = Self::mk_column_witness(
							table_witness.log_capacity,
							col.shape,
							col.data,
						);
						index
							.update_multilin_poly([
								(original, original_witness),
								(repeating, repeating_witness),
							])
							.unwrap();
					}
				}
			}
		}
		index
	}
}

impl<'cs, 'alloc, P> WitnessIndex<'cs, 'alloc, P>
where
	P: PackedField<Scalar: TowerField>
		+ PackedExtension<B1>
		+ PackedExtension<B8>
		+ PackedExtension<B16>
		+ PackedExtension<B32>
		+ PackedExtension<B64>
		+ PackedExtension<B128>,
{
	/// Automatically populates the witness data for the constant columns of every table that has
	/// an initialized [`TableWitnessIndex`].
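	///
	/// # Example
	///
	/// A minimal sketch (not a compiled doctest); `witness`, `table_id`, and `size` are assumed:
	///
	/// ```ignore
	/// // Constant columns can be filled in bulk once their tables are initialized.
	/// witness.init_table(table_id, size)?;
	/// witness.fill_constant_cols()?;
	/// ```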
	pub fn fill_constant_cols(&mut self) -> Result<(), Error> {
		for table in self.tables.iter_mut() {
			match table.as_mut() {
				Either::Left(_) => (),
				// If we have witness index data, populate the witness
				Either::Right(table_witness_index) => {
					let table = table_witness_index.table();
					let segment = table_witness_index.full_segment();
					for col in table.columns.iter() {
						if let ColumnDef::Constant { data, .. } = &col.col {
							let mut witness_data = segment.get_dyn_mut(col.id)?;
							let len = witness_data.size();
							for (i, scalar) in data.iter().cycle().take(len).enumerate() {
								witness_data.set(i, *scalar)?
							}
						}
					}
				}
			}
		}
		Ok(())
	}
}

fn multilin_poly_from_underlier_data<P>(
	data: &[P],
	n_vars: usize,
	tower_height: usize,
) -> Arc<dyn MultilinearPoly<P> + Send + Sync + '_>
where
	P: PackedExtension<B1>
		+ PackedExtension<B8>
		+ PackedExtension<B16>
		+ PackedExtension<B32>
		+ PackedExtension<B64>
		+ PackedExtension<B128>,
{
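	// Each supported tower height h corresponds to the field with 2^(2^h) elements: 0 → B1,
	// 3 → B8, 4 → B16, 5 → B32, 6 → B64, 7 → B128. Heights 1 and 2 (B2, B4) are not supported
	// as column shapes and hit the panic below.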
	match tower_height {
		0 => MultilinearExtension::new(n_vars, PackedExtension::<B1>::cast_bases(data))
			.unwrap()
			.specialize_arc_dyn(),
		3 => MultilinearExtension::new(n_vars, PackedExtension::<B8>::cast_bases(data))
			.unwrap()
			.specialize_arc_dyn(),
		4 => MultilinearExtension::new(n_vars, PackedExtension::<B16>::cast_bases(data))
			.unwrap()
			.specialize_arc_dyn(),
		5 => MultilinearExtension::new(n_vars, PackedExtension::<B32>::cast_bases(data))
			.unwrap()
			.specialize_arc_dyn(),
		6 => MultilinearExtension::new(n_vars, PackedExtension::<B64>::cast_bases(data))
			.unwrap()
			.specialize_arc_dyn(),
		7 => MultilinearExtension::new(n_vars, PackedExtension::<B128>::cast_bases(data))
			.unwrap()
			.specialize_arc_dyn(),
		_ => {
			panic!("Unsupported tower height: {tower_height}");
		}
	}
}

/// Holds witness column data for a table, indexed by column index.
#[derive(Debug, CopyGetters)]
pub struct TableWitnessIndex<'cs, 'alloc, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	#[get_copy = "pub"]
	table: &'cs Table<P::Scalar>,
	cols: Vec<WitnessIndexColumn<'alloc, P>>,
	/// The number of table events that the index should contain.
	#[get_copy = "pub"]
	size: usize,
	#[get_copy = "pub"]
	log_capacity: usize,
	/// Binary logarithm of the minimum segment size.
	///
	/// This is the minimum number of logical rows that can be put into one segment during
	/// iteration. It is the maximum, over all columns, of the number of logical rows occupied by
	/// a single underlier.
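	///
	/// For example (a sketch under assumed parameters): with 128-bit packed elements, so that
	/// `P::LOG_WIDTH + F::TOWER_LEVEL = 7`, and a smallest column holding one `B1` value per row
	/// (`log_cell_size() = 0`), a single underlier spans `2^7` logical rows, making this value 7.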
	#[get_copy = "pub"]
	min_log_segment_size: usize,
}

#[derive(Debug)]
pub struct WitnessIndexColumn<'a, P: PackedField> {
	shape: ColumnShape,
	data: WitnessDataMut<'a, P>,
	column_id: ColumnId,
}

#[derive(Debug, Clone)]
enum WitnessColumnInfo<T> {
	Owned(T),
	/// This column shares its data with the column stored at `cols[.0]`.
	SameAsIndex(usize),
}

type WitnessDataMut<'a, P> = WitnessColumnInfo<&'a mut [P]>;

impl<'a, P: PackedField> WitnessDataMut<'a, P> {
	pub fn new_owned(allocator: &'a HostBumpAllocator<'a, P>, log_underlier_count: usize) -> Self {
		let slice = allocator
			.alloc(1 << log_underlier_count)
			.expect("failed to allocate witness data slice");

		Self::Owned(slice)
	}
}

type RefCellData<'a, P> = WitnessColumnInfo<RefCell<&'a mut [P]>>;

#[derive(Debug)]
struct ImmutableWitnessIndexColumn<'a, P: PackedField> {
	shape: ColumnShape,
	data: &'a [P],
	column_id: ColumnId,
}

/// Converts the vector of witness columns into immutable references to column data that may be
/// shared.
fn immutable_witness_index_columns<P: PackedField>(
	cols: Vec<WitnessIndexColumn<P>>,
) -> Vec<ImmutableWitnessIndexColumn<P>> {
	let mut result = Vec::<ImmutableWitnessIndexColumn<_>>::with_capacity(cols.len());
	for col in cols {
		result.push(ImmutableWitnessIndexColumn {
			shape: col.shape,
			data: match col.data {
				WitnessDataMut::Owned(data) => data,
				WitnessDataMut::SameAsIndex(index) => result[index].data,
			},
			column_id: col.column_id,
		});
	}
	result
}

impl<'cs, 'alloc, F: TowerField, P: PackedField<Scalar = F>> TableWitnessIndex<'cs, 'alloc, P> {
	pub(crate) fn new(
		allocator: &'alloc HostBumpAllocator<'alloc, P>,
		table: &'cs Table<F>,
		size: usize,
	) -> Result<Self, Error> {
		if size == 0 {
			return Err(Error::EmptyTable {
				table_id: table.id(),
			});
		}

		let log_capacity = table::log_capacity(size);
		let packed_elem_log_bits = P::LOG_WIDTH + F::TOWER_LEVEL;

		let mut cols = Vec::with_capacity(table.columns.len());
		for ColumnInfo { id, col, shape, .. } in &table.columns {
			let data: WitnessDataMut<P> = if let ColumnDef::Packed { col: source, .. } = col {
				// A packed column reuses the witness of the column it is based on.
				WitnessDataMut::SameAsIndex(source.table_index.0)
			} else {
				// Everything else has its own column data.
				WitnessDataMut::new_owned(
					allocator,
					(shape.log_cell_size() + log_capacity).saturating_sub(packed_elem_log_bits),
				)
			};
			cols.push(WitnessIndexColumn {
				shape: *shape,
				data,
				column_id: *id,
			});
		}

		// The minimum segment size is chosen such that the segment of each column is at least one
		// underlier in size.
		let min_log_segment_size = packed_elem_log_bits
			- table
				.columns
				.iter()
				.map(|col| col.shape.log_cell_size())
				.fold(packed_elem_log_bits, |a, b| a.min(b));

		// But, in case the minimum segment size is larger than the capacity, we lower it so the
		// caller can get the full witness index in one segment. This is OK because the extra field
		// elements in the smallest columns are just padding.
		let min_log_segment_size = min_log_segment_size.min(log_capacity);

		Ok(Self {
			table,
			cols,
			size,
			log_capacity,
			min_log_segment_size,
		})
	}

	pub fn table_id(&self) -> TableId {
		self.table.id
	}

	pub fn capacity(&self) -> usize {
		1 << self.log_capacity
	}

	/// Returns a witness index segment covering the entire table.
	pub fn full_segment(&mut self) -> TableWitnessSegment<P> {
		let cols = self
			.cols
			.iter_mut()
			.map(|col| match &mut col.data {
				WitnessDataMut::SameAsIndex(id) => RefCellData::SameAsIndex(*id),
				WitnessDataMut::Owned(data) => RefCellData::Owned(RefCell::new(data)),
			})
			.collect();
		TableWitnessSegment {
			table: self.table,
			cols,
			log_size: self.log_capacity,
			index: 0,
		}
	}

	/// Fill a full table witness index using the given row data.
	///
	/// This function iterates through witness segments sequentially in a single thread.
	pub fn fill_sequential<T: TableFiller<P>>(
		&mut self,
		table: &T,
		rows: &[T::Event],
	) -> Result<(), Error> {
		let log_size = self.optimal_segment_size_heuristic();
		self.fill_sequential_with_segment_size(table, rows, log_size)
	}

	/// Fill a full table witness index using the given row data.
	///
	/// This function iterates through witness segments in parallel in multiple threads.
	pub fn fill_parallel<T>(&mut self, table: &T, rows: &[T::Event]) -> Result<(), Error>
	where
		T: TableFiller<P> + Sync,
		T::Event: Sync,
	{
		let log_size = self.optimal_segment_size_heuristic();
		self.fill_parallel_with_segment_size(table, rows, log_size)
	}

	fn optimal_segment_size_heuristic(&self) -> usize {
		// As a heuristic, choose log_size so that the median column segment size is 4 KiB.
		const TARGET_SEGMENT_LOG_BITS: usize = 12 + 3;
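		// For example (illustrative numbers): if the median column holds one B32 value per row,
		// its log_cell_size() is 5, so the heuristic picks log_size = 15 - 5 = 10, i.e. segments
		// of 2^10 rows whose median column slice occupies 2^15 bits = 4 KiB.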

		let n_cols = self.table.columns.len();
		let median_col_log_bits = self
			.table
			.columns
			.iter()
			.map(|col| col.shape.log_cell_size())
			.sorted()
			.nth(n_cols / 2)
			.unwrap_or_default();

		TARGET_SEGMENT_LOG_BITS.saturating_sub(median_col_log_bits)
	}

	/// Fill a full table witness index using the given row data.
	///
	/// This function iterates through witness segments sequentially in a single thread.
	pub fn fill_sequential_with_segment_size<T: TableFiller<P>>(
		&mut self,
		table: &T,
		rows: &[T::Event],
		log_size: usize,
	) -> Result<(), Error> {
		if rows.len() != self.size {
			return Err(Error::IncorrectNumberOfTableEvents {
				expected: self.size,
				actual: rows.len(),
			});
		}

		let mut segmented_view = TableWitnessSegmentedView::new(self, log_size);

		// Overwrite log_size because the view may have clamped it.
		let log_size = segmented_view.log_segment_size;
		let segment_size = 1 << log_size;

		// rows.len() equals self.size, and self.size is checked to be non-zero in the constructor
		debug_assert_ne!(rows.len(), 0);
		// number of chunks, rounded up
		let n_chunks = (rows.len() - 1) / segment_size + 1;

		let (full_chunk_segments, mut rest_segments) = segmented_view.split_at(n_chunks - 1);

		// Fill segments of the table with full chunks
		full_chunk_segments
			.into_iter()
			// by taking n_chunks - 1, we guarantee that all row chunks are full
			.zip(rows.chunks(segment_size).take(n_chunks - 1))
			.try_for_each(|(mut witness_segment, row_chunk)| {
				table
					.fill(row_chunk, &mut witness_segment)
					.map_err(Error::TableFill)
			})?;

		// Fill the last segment. There may not be enough events to match the size of the segment,
		// which is a pre-condition for TableFiller::fill. In that case, we clone the last event to
		// pad the row_chunk. Since it's a clone, the filled witness should satisfy all row-wise
		// constraints as long as all the given events do.
		let row_chunk = &rows[(n_chunks - 1) * segment_size..];
		let mut padded_row_chunk = Vec::new();
		let row_chunk = if row_chunk.len() != segment_size {
			padded_row_chunk.reserve(segment_size);
			padded_row_chunk.extend_from_slice(row_chunk);
			let last_event = row_chunk
				.last()
				.expect("row_chunk must be non-empty because of how n_chunks is calculated")
				.clone();
			padded_row_chunk.resize(segment_size, last_event);
			&padded_row_chunk
		} else {
			row_chunk
		};

		let (partial_chunk_segments, rest_segments) = rest_segments.split_at(1);
		let mut partial_chunk_segment_iter = partial_chunk_segments.into_iter();
		let mut witness_segment = partial_chunk_segment_iter.next().expect(
			"segmented_view.split_at called with 1 must return a view with exactly one segment",
		);
		table
			.fill(row_chunk, &mut witness_segment)
			.map_err(Error::TableFill)?;
		assert!(partial_chunk_segment_iter.next().is_none());

		// Finally, copy the last filled segment to the remaining segments. This should satisfy all
		// row-wise constraints if the last segment does.
		let last_segment_cols = witness_segment
			.cols
			.iter_mut()
			.map(|col| match col {
				RefCellData::Owned(data) => WitnessColumnInfo::Owned(data.get_mut()),
				RefCellData::SameAsIndex(idx) => WitnessColumnInfo::SameAsIndex(*idx),
			})
			.collect::<Vec<_>>();

		rest_segments.into_iter().for_each(|mut segment| {
			for (dst_col, src_col) in iter::zip(&mut segment.cols, &last_segment_cols) {
				if let (RefCellData::Owned(dst), WitnessColumnInfo::Owned(src)) = (dst_col, src_col)
				{
					dst.get_mut().copy_from_slice(src)
				}
			}
		});

		Ok(())
	}

	/// Fill a full table witness index using the given row data.
	///
	/// This function iterates through witness segments in parallel in multiple threads.
	pub fn fill_parallel_with_segment_size<T>(
		&mut self,
		table: &T,
		rows: &[T::Event],
		log_size: usize,
	) -> Result<(), Error>
	where
		T: TableFiller<P> + Sync,
		T::Event: Sync,
	{
		if rows.len() != self.size {
			return Err(Error::IncorrectNumberOfTableEvents {
				expected: self.size,
				actual: rows.len(),
			});
		}

		// This implementation duplicates a lot of code with `fill_sequential_with_segment_size`.
		// We could either refactor to deduplicate or just remove `fill_sequential` once this
		// method is more battle-tested.

		let mut segmented_view = TableWitnessSegmentedView::new(self, log_size);

		// Overwrite log_size because the view may have clamped it.
		let log_size = segmented_view.log_segment_size;
		let segment_size = 1 << log_size;

		// rows.len() equals self.size, and self.size is checked to be non-zero in the constructor
		debug_assert_ne!(rows.len(), 0);
		// number of chunks, rounded up
		let n_chunks = (rows.len() - 1) / segment_size + 1;

		let (full_chunk_segments, mut rest_segments) = segmented_view.split_at(n_chunks - 1);

		// Fill segments of the table with full chunks
		full_chunk_segments
			.into_par_iter()
			// by taking n_chunks - 1, we guarantee that all row chunks are full
			.zip(rows.par_chunks(segment_size).take(n_chunks - 1))
			.try_for_each(|(mut witness_segment, row_chunk)| {
				table
					.fill(row_chunk, &mut witness_segment)
					.map_err(Error::TableFill)
			})?;

		// Fill the last segment. There may not be enough events to match the size of the segment,
		// which is a pre-condition for TableFiller::fill. In that case, we clone the last event to
		// pad the row_chunk. Since it's a clone, the filled witness should satisfy all row-wise
		// constraints as long as all the given events do.
		let row_chunk = &rows[(n_chunks - 1) * segment_size..];
		let mut padded_row_chunk = Vec::new();
		let row_chunk = if row_chunk.len() != segment_size {
			padded_row_chunk.reserve(segment_size);
			padded_row_chunk.extend_from_slice(row_chunk);
			let last_event = row_chunk
				.last()
				.expect("row_chunk must be non-empty because of how n_chunks is calculated")
				.clone();
			padded_row_chunk.resize(segment_size, last_event);
			&padded_row_chunk
		} else {
			row_chunk
		};

		let (partial_chunk_segments, rest_segments) = rest_segments.split_at(1);
		let mut partial_chunk_segment_iter = partial_chunk_segments.into_iter();
		let mut witness_segment = partial_chunk_segment_iter.next().expect(
			"segmented_view.split_at called with 1 must return a view with exactly one segment",
		);
		table
			.fill(row_chunk, &mut witness_segment)
			.map_err(Error::TableFill)?;
		assert!(partial_chunk_segment_iter.next().is_none());

		// Finally, copy the last filled segment to the remaining segments. This should satisfy all
		// row-wise constraints if the last segment does.
		let last_segment_cols = witness_segment
			.cols
			.iter_mut()
			.map(|col| match col {
				RefCellData::Owned(data) => WitnessColumnInfo::Owned(data.get_mut()),
				RefCellData::SameAsIndex(idx) => WitnessColumnInfo::SameAsIndex(*idx),
			})
			.collect::<Vec<_>>();

		rest_segments.into_par_iter().for_each(|mut segment| {
			for (dst_col, src_col) in iter::zip(&mut segment.cols, &last_segment_cols) {
				if let (RefCellData::Owned(dst), WitnessColumnInfo::Owned(src)) = (dst_col, src_col)
				{
					dst.get_mut().copy_from_slice(src)
				}
			}
		});

		Ok(())
	}

	/// Returns an iterator over segments of witness index rows.
	///
	/// This method clamps the segment size, requested as `log_size`, to a minimum of
	/// `self.min_log_segment_size()` and a maximum of `self.log_capacity()`. The actual segment
	/// size can be queried on the items yielded by the iterator.
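	///
	/// # Example
	///
	/// A minimal sketch (not a compiled doctest); `index` is an initialized
	/// [`TableWitnessIndex`] and `my_col` is a column of its table:
	///
	/// ```ignore
	/// for segment in index.segments(8) {
	///     // Each segment grants runtime-checked access to a slice of every column.
	///     let mut col = segment.get_mut(my_col)?;
	///     // ... populate `col` for `segment.size()` rows ...
	/// }
	/// ```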
	pub fn segments(&mut self, log_size: usize) -> impl Iterator<Item = TableWitnessSegment<P>> {
		TableWitnessSegmentedView::new(self, log_size).into_iter()
	}

	pub fn par_segments(
		&mut self,
		log_size: usize,
	) -> impl IndexedParallelIterator<Item = TableWitnessSegment<'_, P>> {
		TableWitnessSegmentedView::new(self, log_size).into_par_iter()
	}
}

/// A view over a table witness that splits the table into segments.
///
/// The purpose of this struct is to implement the `split_at` method, which safely splits the view
/// of the table witness vertically. This aids in the implementation of `fill_sequential` and
/// `fill_parallel`.
#[derive(Debug)]
struct TableWitnessSegmentedView<'a, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	table: &'a Table<P::Scalar>,
	cols: Vec<WitnessColumnInfo<(&'a mut [P], usize)>>,
	log_segment_size: usize,
	start_index: usize,
	n_segments: usize,
}

impl<'a, F: TowerField, P: PackedField<Scalar = F>> TableWitnessSegmentedView<'a, P> {
	fn new(witness: &'a mut TableWitnessIndex<P>, log_segment_size: usize) -> Self {
		// Clamp the segment size.
		let log_segment_size = log_segment_size
			.min(witness.log_capacity)
			.max(witness.min_log_segment_size);

		let cols = witness
			.cols
			.iter_mut()
			.map(|col| match &mut col.data {
				WitnessColumnInfo::Owned(data) => {
					let chunk_size = (log_segment_size + col.shape.log_cell_size())
						.saturating_sub(P::LOG_WIDTH + F::TOWER_LEVEL);
					WitnessColumnInfo::Owned((&mut **data, 1 << chunk_size))
				}
				WitnessColumnInfo::SameAsIndex(id) => WitnessColumnInfo::SameAsIndex(*id),
			})
			.collect::<Vec<_>>();
		Self {
			table: witness.table,
			cols,
			log_segment_size,
			start_index: 0,
			n_segments: 1 << (witness.log_capacity - log_segment_size),
		}
	}

	fn split_at(
		&mut self,
		index: usize,
	) -> (TableWitnessSegmentedView<P>, TableWitnessSegmentedView<P>) {
		assert!(index <= self.n_segments);
		let (cols_0, cols_1) = self
			.cols
			.iter_mut()
			.map(|col| match col {
				WitnessColumnInfo::Owned((data, chunk_size)) => {
					let (data_0, data_1) = data.split_at_mut(*chunk_size * index);
					(
						WitnessColumnInfo::Owned((data_0, *chunk_size)),
						WitnessColumnInfo::Owned((data_1, *chunk_size)),
					)
				}
				WitnessColumnInfo::SameAsIndex(id) => {
					(WitnessColumnInfo::SameAsIndex(*id), WitnessColumnInfo::SameAsIndex(*id))
				}
			})
			.unzip();
		(
			TableWitnessSegmentedView {
				table: self.table,
				cols: cols_0,
				log_segment_size: self.log_segment_size,
				start_index: self.start_index,
				n_segments: index,
			},
			TableWitnessSegmentedView {
				table: self.table,
				cols: cols_1,
				log_segment_size: self.log_segment_size,
				start_index: self.start_index + index,
				n_segments: self.n_segments - index,
			},
		)
	}

	fn into_iter(self) -> impl Iterator<Item = TableWitnessSegment<'a, P>> {
		let TableWitnessSegmentedView {
			table,
			cols,
			log_segment_size,
			start_index,
			n_segments,
		} = self;

		if n_segments == 0 {
			itertools::Either::Left(iter::empty())
		} else {
			let iter = MultiIterator::new(
				cols.into_iter()
					.map(|col| match col {
						WitnessColumnInfo::Owned((data, chunk_size)) => itertools::Either::Left(
							data.chunks_mut(chunk_size)
								.map(|chunk| RefCellData::Owned(RefCell::new(chunk))),
						),
						WitnessColumnInfo::SameAsIndex(id) => itertools::Either::Right(
							iter::repeat_n(id, n_segments).map(RefCellData::SameAsIndex),
						),
					})
					.collect(),
			)
			.enumerate()
			.map(move |(index, cols)| TableWitnessSegment {
				table,
				cols,
				log_size: log_segment_size,
				index: start_index + index,
			});
			itertools::Either::Right(iter)
		}
	}

	fn into_par_iter(self) -> impl IndexedParallelIterator<Item = TableWitnessSegment<'a, P>> {
		let TableWitnessSegmentedView {
			table,
			cols,
			log_segment_size,
			start_index,
			n_segments,
		} = self;

		// This implementation uses unsafe code to iterate the segments of the view. A fully safe
		// implementation is also possible, which would look similar to that of
		// `rayon::slice::ChunksMut`. That just requires more code and doesn't seem justified.

		// TODO: clippy error (clippy::mut_from_ref): mutable borrow from immutable input(s)
		#[allow(clippy::mut_from_ref)]
		unsafe fn cast_slice_ref_to_mut<T>(slice: &[T]) -> &mut [T] {
			unsafe { slice::from_raw_parts_mut(slice.as_ptr() as *mut T, slice.len()) }
		}

		// Convert cols with mutable references into cols with const refs so that they can be
		// cloned. Within the loop, we unsafely cast back to mut refs.
		let cols = cols
			.into_iter()
			.map(|col| -> WitnessColumnInfo<(&'a [P], usize)> {
				match col {
					WitnessColumnInfo::Owned((data, chunk_size)) => {
						WitnessColumnInfo::Owned((data, chunk_size))
					}
					WitnessColumnInfo::SameAsIndex(id) => WitnessColumnInfo::SameAsIndex(id),
				}
			})
			.collect::<Vec<_>>();

		(0..n_segments).into_par_iter().map(move |i| {
			let col_strides = cols
				.iter()
				.map(|col| match col {
					WitnessColumnInfo::SameAsIndex(id) => RefCellData::SameAsIndex(*id),
					WitnessColumnInfo::Owned((data, chunk_size)) => {
						RefCellData::Owned(RefCell::new(unsafe {
							// Safety: The view was created from a mutable borrow of the witness
							// index, so we have mutable access to all columns and thus none can be
							// borrowed by anyone else. The loop is constructed to borrow disjoint
							// segments of each column -- if this loop were transposed, we would
							// use `chunks_mut`.
							cast_slice_ref_to_mut(&data[i * chunk_size..(i + 1) * chunk_size])
						}))
					}
				})
				.collect();
			TableWitnessSegment {
				table,
				cols: col_strides,
				log_size: log_segment_size,
				index: start_index + i,
			}
		})
	}
}

/// A vertical segment of a table witness index.
///
/// This provides runtime-checked access to slices of the witness columns. This is used separately
/// from [`TableWitnessIndex`] so that witness population can be parallelized over segments.
#[derive(Debug, CopyGetters)]
pub struct TableWitnessSegment<'a, P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	table: &'a Table<P::Scalar>,
	/// Stores the actual data for the witness columns.
	///
	/// The order of the columns corresponds to the same order as defined in the table.
	cols: Vec<RefCellData<'a, P>>,
	#[get_copy = "pub"]
	log_size: usize,
	/// The index of the segment in the segmented table witness.
	#[get_copy = "pub"]
	index: usize,
}

impl<'a, F: TowerField, P: PackedField<Scalar = F>> TableWitnessSegment<'a, P> {
	pub fn get<FSub: TowerField, const V: usize>(
		&self,
		col: Col<FSub, V>,
	) -> Result<Ref<[PackedSubfield<P, FSub>]>, Error>
	where
		P: PackedExtension<FSub>,
	{
		if col.table_id != self.table.id() {
			return Err(Error::TableMismatch {
				column_table_id: col.table_id,
				witness_table_id: self.table.id(),
			});
		}

		let col = self
			.get_col_data(col.id())
			.ok_or_else(|| Error::MissingColumn(col.id()))?;
		let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
		Ok(Ref::map(col_ref, |packed| PackedExtension::cast_bases(packed)))
	}

	pub fn get_mut<FSub: TowerField, const V: usize>(
		&self,
		col: Col<FSub, V>,
	) -> Result<RefMut<[PackedSubfield<P, FSub>]>, Error>
	where
		P: PackedExtension<FSub>,
		F: ExtensionField<FSub>,
	{
		if col.table_id != self.table.id() {
			return Err(Error::TableMismatch {
				column_table_id: col.table_id,
				witness_table_id: self.table.id(),
			});
		}

		let col = self
			.get_col_data(col.id())
			.ok_or_else(|| Error::MissingColumn(col.id()))?;
		let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
		Ok(RefMut::map(col_ref, |packed| PackedExtension::cast_bases_mut(packed)))
	}

	pub fn get_scalars<FSub: TowerField, const V: usize>(
		&self,
		col: Col<FSub, V>,
	) -> Result<Ref<[FSub]>, Error>
	where
		P: PackedExtension<FSub>,
		F: ExtensionField<FSub>,
		PackedSubfield<P, FSub>: PackedFieldIndexable,
	{
		self.get(col)
			.map(|packed| Ref::map(packed, <PackedSubfield<P, FSub>>::unpack_scalars))
	}

	pub fn get_scalars_mut<FSub: TowerField, const V: usize>(
		&self,
		col: Col<FSub, V>,
	) -> Result<RefMut<[FSub]>, Error>
	where
		P: PackedExtension<FSub>,
		F: ExtensionField<FSub>,
		PackedSubfield<P, FSub>: PackedFieldIndexable,
	{
		self.get_mut(col)
			.map(|packed| RefMut::map(packed, <PackedSubfield<P, FSub>>::unpack_scalars_mut))
	}

	pub fn get_as<T: Pod, FSub: TowerField, const V: usize>(
		&self,
		col: Col<FSub, V>,
	) -> Result<Ref<[T]>, Error>
	where
		P: PackedExtension<FSub> + PackedFieldIndexable,
		F: ExtensionField<FSub> + Pod,
	{
		let col = self
			.get_col_data(col.id())
			.ok_or_else(|| Error::MissingColumn(col.id()))?;
		let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
		Ok(Ref::map(col_ref, |col| must_cast_slice(P::unpack_scalars(col))))
	}

	pub fn get_mut_as<T: Pod, FSub: TowerField, const V: usize>(
		&self,
		col: Col<FSub, V>,
	) -> Result<RefMut<[T]>, Error>
	where
		P: PackedExtension<FSub> + PackedFieldIndexable,
		F: ExtensionField<FSub> + Pod,
	{
		if col.table_id != self.table.id() {
			return Err(Error::TableMismatch {
				column_table_id: col.table_id,
				witness_table_id: self.table.id(),
			});
		}

		let col = self
			.get_col_data(col.id())
			.ok_or_else(|| Error::MissingColumn(col.id()))?;
		let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
		Ok(RefMut::map(col_ref, |col| must_cast_slice_mut(P::unpack_scalars_mut(col))))
	}

	/// Evaluate an expression over columns that are assumed to be already populated.
	///
	/// This function evaluates an expression over the columns in the segment and returns an
	/// iterator over the packed elements. This borrows the column segments it reads from, so
	/// they must not be borrowed mutably elsewhere (which is possible due to runtime-checked
	/// column borrowing).
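	///
	/// # Example
	///
	/// A minimal sketch (not a compiled doctest); `segment` is a populated
	/// [`TableWitnessSegment`] and `col0`, `col1`, `col2` are `B8` columns of its table:
	///
	/// ```ignore
	/// // Evaluate col0 * col1 - col2 over every row of the segment.
	/// let evals = segment.eval_expr(&(col0 * col1 - col2))?;
	/// for packed in evals {
	///     // Each item packs several B8 evaluations; unpack with PackedField::into_iter.
	/// }
	/// ```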
	pub fn eval_expr<FSub: TowerField, const V: usize>(
		&self,
		expr: &Expr<FSub, V>,
	) -> Result<impl Iterator<Item = PackedSubfield<P, FSub>> + use<FSub, V, F, P>, Error>
	where
		P: PackedExtension<FSub>,
	{
		let log_vals_per_row = checked_log_2(V);

		let partition =
			self.table
				.partitions
				.get(log_vals_per_row)
				.ok_or_else(|| Error::MissingPartition {
					table_id: self.table.id(),
					log_vals_per_row,
				})?;
		let expr_circuit = ArithCircuit::from(expr.expr());

		let col_refs = partition
			.columns
			.iter()
			.zip(expr_circuit.vars_usage())
			.map(|(col_id, used)| {
				used.then(|| {
					assert_eq!(FSub::TOWER_LEVEL, self.table[*col_id].shape.tower_height);
					let col = self
						.get_col_data(*col_id)
						.ok_or_else(|| Error::MissingColumn(*col_id))?;
					let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
					Ok::<_, Error>(Ref::map(col_ref, |packed| PackedExtension::cast_bases(packed)))
				})
				.transpose()
			})
			.collect::<Result<Vec<_>, _>>()?;

		let log_packed_elems =
			(self.log_size + log_vals_per_row).saturating_sub(<PackedSubfield<P, FSub>>::LOG_WIDTH);

		// Batch evaluate requires value slices even for the indices it will not read.
		let dummy_col = zeroed_vec(1 << log_packed_elems);

		let cols = col_refs
			.iter()
			.map(|col| col.as_ref().map(|col_ref| &**col_ref).unwrap_or(&dummy_col))
			.collect::<Vec<_>>();
		let cols = RowsBatchRef::new(&cols, 1 << log_packed_elems);

		// REVIEW: This could be inefficient with very large segments because batch evaluation
		// allocates more memory, proportional to the size of the segment. Because of how segments
		// get split up in practice, it's not a problem yet. If we see stack overflows, we should
		// split up the evaluation into multiple batches.
		let mut evals = zeroed_vec(1 << log_packed_elems);
		ArithCircuitPoly::new(expr_circuit).batch_evaluate(&cols, &mut evals)?;
		Ok(evals.into_iter())
	}

	pub fn size(&self) -> usize {
		1 << self.log_size
	}

	fn get_col_data(&self, column_id: ColumnId) -> Option<&RefCell<&'a mut [P]>> {
		let column_index = column_id.table_index.0;
		self.get_col_data_by_index(column_index)
	}

	fn get_col_data_by_index(&self, index: usize) -> Option<&RefCell<&'a mut [P]>> {
		match self.cols.get(index) {
			Some(RefCellData::Owned(data)) => Some(data),
			Some(RefCellData::SameAsIndex(index)) => self.get_col_data_by_index(*index),
			None => None,
		}
	}
}

impl<'a, P> TableWitnessSegment<'a, P>
where
	P: PackedField<Scalar: TowerField>
		+ PackedExtension<B1>
		+ PackedExtension<B8>
		+ PackedExtension<B16>
		+ PackedExtension<B32>
		+ PackedExtension<B64>
		+ PackedExtension<B128>,
{
	/// For a given column index within a table, return the immutable upcasted witness data.
	pub fn get_dyn(
		&self,
		col_id: ColumnId,
	) -> Result<Box<dyn WitnessColView<P::Scalar> + '_>, Error> {
		let col = self
			.get_col_data(col_id)
			.ok_or_else(|| Error::MissingColumn(col_id))?;
		let col_ref = col.try_borrow().map_err(Error::WitnessBorrow)?;
		let tower_level = self.table[col_id].shape.tower_height;
		let ret: Box<dyn WitnessColView<_>> = match tower_level {
			0 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
				PackedExtension::<B1>::cast_bases(packed)
			}))),
			3 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
				PackedExtension::<B8>::cast_bases(packed)
			}))),
			4 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
				PackedExtension::<B16>::cast_bases(packed)
			}))),
			5 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
				PackedExtension::<B32>::cast_bases(packed)
			}))),
			6 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
				PackedExtension::<B64>::cast_bases(packed)
			}))),
			7 => Box::new(WitnessColViewImpl(Ref::map(col_ref, |packed| {
				PackedExtension::<B128>::cast_bases(packed)
			}))),
			_ => panic!("tower_level must be in the range [0, 7]"),
		};
		Ok(ret)
	}

	/// For a given column index within a table, return the mutable witness data.
	pub fn get_dyn_mut(
		&self,
		col_id: ColumnId,
	) -> Result<Box<dyn WitnessColViewMut<P::Scalar> + '_>, Error> {
		let col = self
			.get_col_data(col_id)
			.ok_or_else(|| Error::MissingColumn(col_id))?;
		let col_ref = col.try_borrow_mut().map_err(Error::WitnessBorrowMut)?;
		let tower_level = self.table[col_id].shape.tower_height;
		let ret: Box<dyn WitnessColViewMut<_>> = match tower_level {
			0 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
				PackedExtension::<B1>::cast_bases_mut(packed)
			}))),
			3 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
				PackedExtension::<B8>::cast_bases_mut(packed)
			}))),
			4 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
				PackedExtension::<B16>::cast_bases_mut(packed)
			}))),
			5 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
				PackedExtension::<B32>::cast_bases_mut(packed)
			}))),
			6 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
				PackedExtension::<B64>::cast_bases_mut(packed)
			}))),
			7 => Box::new(WitnessColViewImpl(RefMut::map(col_ref, |packed| {
				PackedExtension::<B128>::cast_bases_mut(packed)
			}))),
			_ => panic!("tower_level must be in the range [0, 7]"),
		};
		Ok(ret)
	}
}

/// Type-erased interface for viewing witness columns. Note that `F` will be an extension field of
/// the underlying column's field.
pub trait WitnessColView<F> {
	/// Returns the scalar at a given index.
	fn get(&self, index: usize) -> F;

	/// The number of scalar elements in this column.
	fn size(&self) -> usize;
}

/// Similar to [`WitnessColView`], for mutating witness columns.
pub trait WitnessColViewMut<F>: WitnessColView<F> {
	/// Modifies the upcasted scalar at a given index.
	fn set(&mut self, index: usize, val: F) -> Result<(), Error>;
}

struct WitnessColViewImpl<Data>(Data);

impl<F, P, Data> WitnessColView<F> for WitnessColViewImpl<Data>
where
	F: ExtensionField<P::Scalar>,
	P: PackedField,
	Data: Deref<Target = [P]>,
{
	fn get(&self, index: usize) -> F {
		get_packed_slice(&self.0, index).into()
	}

	fn size(&self) -> usize {
		self.0.len() << P::LOG_WIDTH
	}
}

impl<F, P, Data> WitnessColViewMut<F> for WitnessColViewImpl<Data>
where
	F: ExtensionField<P::Scalar>,
	P: PackedField,
	Data: DerefMut<Target = [P]>,
{
	fn set(&mut self, index: usize, val: F) -> Result<(), Error> {
		let subfield_val = val.try_into().map_err(|_| Error::FieldElementTooBig)?;
		set_packed_slice(&mut self.0, index, subfield_val);
		Ok(())
	}
}

/// A struct that can populate segments of a table witness using row descriptors.
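///
/// # Example
///
/// A minimal sketch (not a compiled doctest); `MyTable`, `MyEvent`, and the `col0` column are
/// hypothetical:
///
/// ```ignore
/// impl TableFiller for MyTable {
///     type Event = MyEvent;
///
///     fn id(&self) -> TableId {
///         self.id
///     }
///
///     fn fill(
///         &self,
///         rows: &[MyEvent],
///         witness: &mut TableWitnessSegment,
///     ) -> anyhow::Result<()> {
///         // Precondition: rows.len() == witness.size().
///         let mut col0 = witness.get_scalars_mut(self.col0)?;
///         for (i, event) in rows.iter().enumerate() {
///             col0[i] = B32::new(event.value);
///         }
///         Ok(())
///     }
/// }
/// ```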
pub trait TableFiller<P = PackedType<OptimalUnderlier, B128>>
where
	P: PackedField,
	P::Scalar: TowerField,
{
	/// A struct that specifies the row contents.
	type Event: Clone;

	/// Returns the table ID.
	fn id(&self) -> TableId;

	/// Fill the table witness with data derived from the given rows.
	///
	/// ## Preconditions
	///
	/// * the number of elements in `rows` must equal `witness.size()`
	fn fill(
		&self,
		rows: &[Self::Event],
		witness: &mut TableWitnessSegment<P>,
	) -> anyhow::Result<()>;
}

#[cfg(test)]
mod tests {
	use std::{array, iter::repeat_with};

	use assert_matches::assert_matches;
	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_core::oracle::{OracleId, SymbolicMultilinearOracleSet};
	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier256b},
		packed::{len_packed_slice, set_packed_slice},
	};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{
		ConstraintSystem, TableBuilder,
		types::{B1, B8, B16, B32},
	};

	#[test]
	fn test_table_witness_borrows() {
		let table_id = 0;
		let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
		let mut table = TableBuilder::new(&mut inner_table);
		let col0 = table.add_committed::<B1, 8>("col0");
		let col1 = table.add_committed::<B1, 32>("col1");
		let col2 = table.add_committed::<B8, 1>("col2");
		let col3 = table.add_committed::<B32, 1>("col3");

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let table_size = 64;
		let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
			&allocator,
			&inner_table,
			table_size,
		)
		.unwrap();
		let segment = index.full_segment();

		{
			let col0_ref0 = segment.get(col0).unwrap();
			let _col0_ref1 = segment.get(col0).unwrap();
			assert_matches!(segment.get_mut(col0), Err(Error::WitnessBorrowMut(_)));
			drop(col0_ref0);

			let col1_ref = segment.get_mut(col1).unwrap();
			assert_matches!(segment.get(col1), Err(Error::WitnessBorrow(_)));
			drop(col1_ref);
		}

		assert_eq!(len_packed_slice(&segment.get_mut(col0).unwrap()), 1 << 9);
		assert_eq!(len_packed_slice(&segment.get_mut(col1).unwrap()), 1 << 11);
		assert_eq!(len_packed_slice(&segment.get_mut(col2).unwrap()), 1 << 6);
		assert_eq!(len_packed_slice(&segment.get_mut(col3).unwrap()), 1 << 6);
	}

	#[test]
	fn test_table_witness_segments() {
		let table_id = 0;
		let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
		let mut table = TableBuilder::new(&mut inner_table);
		let col0 = table.add_committed::<B1, 8>("col0");
		let col1 = table.add_committed::<B1, 32>("col1");
		let col2 = table.add_committed::<B8, 1>("col2");
		let col3 = table.add_committed::<B32, 1>("col3");

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let table_size = 64;
		let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
			&allocator,
			&inner_table,
			table_size,
		)
		.unwrap();

		assert_eq!(index.min_log_segment_size(), 4);
		let mut iter = index.segments(5);
		let seg0 = iter.next().unwrap();
		let seg1 = iter.next().unwrap();
		assert!(iter.next().is_none());

		assert_eq!(len_packed_slice(&seg0.get_mut(col0).unwrap()), 1 << 8);
		assert_eq!(len_packed_slice(&seg0.get_mut(col1).unwrap()), 1 << 10);
		assert_eq!(len_packed_slice(&seg0.get_mut(col2).unwrap()), 1 << 5);
		assert_eq!(len_packed_slice(&seg0.get_mut(col3).unwrap()), 1 << 5);

		assert_eq!(len_packed_slice(&seg1.get_mut(col0).unwrap()), 1 << 8);
		assert_eq!(len_packed_slice(&seg1.get_mut(col1).unwrap()), 1 << 10);
		assert_eq!(len_packed_slice(&seg1.get_mut(col2).unwrap()), 1 << 5);
		assert_eq!(len_packed_slice(&seg1.get_mut(col3).unwrap()), 1 << 5);
	}

	#[test]
	fn test_eval_expr() {
		let table_id = 0;
		let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
		let mut table = TableBuilder::new(&mut inner_table);
		let col0 = table.add_committed::<B8, 2>("col0");
		let col1 = table.add_committed::<B8, 2>("col1");
		let col2 = table.add_committed::<B8, 2>("col2");

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let table_size = 1 << 6;
		let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
			&allocator,
			&inner_table,
			table_size,
		)
		.unwrap();

		let segment = index.full_segment();
		assert_eq!(segment.log_size(), 6);

		// Fill the columns with a deterministic pattern.
		{
			let mut col0 = segment.get_mut(col0).unwrap();
			let mut col1 = segment.get_mut(col1).unwrap();
			let mut col2 = segment.get_mut(col2).unwrap();

			// 3 = 6 (log table size) + 1 (log values per row) - 4 (log packed field width)
			let expected_slice_len = 1 << 3;
			assert_eq!(col0.len(), expected_slice_len);
			assert_eq!(col1.len(), expected_slice_len);
			assert_eq!(col2.len(), expected_slice_len);

			for i in 0..expected_slice_len * <PackedType<OptimalUnderlier128b, B8>>::WIDTH {
				set_packed_slice(&mut col0, i, B8::new(i as u8) + B8::new(0x00));
				set_packed_slice(&mut col1, i, B8::new(i as u8) + B8::new(0x40));
				set_packed_slice(&mut col2, i, B8::new(i as u8) + B8::new(0x80));
			}
		}

		let evals = segment.eval_expr(&(col0 * col1 - col2)).unwrap();
		for (i, eval_i) in evals
			.into_iter()
			.flat_map(PackedField::into_iter)
			.enumerate()
		{
			let col0_val = B8::new(i as u8) + B8::new(0x00);
			let col1_val = B8::new(i as u8) + B8::new(0x40);
			let col2_val = B8::new(i as u8) + B8::new(0x80);
			assert_eq!(eval_i, col0_val * col1_val - col2_val);
		}
	}

	#[test]
	fn test_eval_expr_different_cols() {
		let table_id = 0;
		let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
		let mut table = TableBuilder::new(&mut inner_table);
		let col0 = table.add_committed::<B32, 1>("col1");
		let col1 = table.add_committed::<B8, 8>("col2");
		let col2 = table.add_committed::<B16, 4>("col3");
		let col3 = table.add_committed::<B8, 8>("col4");

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let table_size = 4;
		let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(
			&allocator,
			&inner_table,
			table_size,
		)
		.unwrap();

		let segment = index.full_segment();

		// Fill the columns with a deterministic pattern.
		{
			let mut col0: RefMut<'_, [u32]> = segment.get_mut_as(col0).unwrap();
			let mut col1: RefMut<'_, [[u8; 8]]> = segment.get_mut_as(col1).unwrap();
			let mut col2: RefMut<'_, [[u16; 4]]> = segment.get_mut_as(col2).unwrap();
			let mut col3: RefMut<'_, [[u8; 8]]> = segment.get_mut_as(col3).unwrap();

			col0[0] = 0x40;
			col1[0] = [0x45, 0, 0, 0, 0, 0, 0, 0];
			col2[0] = [0, 0, 0, 0x80];
			col3[0] = [0x85, 0, 0, 0, 0, 0, 0, 0];
		}

		let evals = segment.eval_expr(&(col1 + col3)).unwrap();
		for (i, eval_i) in evals
			.into_iter()
			.flat_map(PackedField::into_iter)
			.enumerate()
		{
			if i == 0 {
				assert_eq!(eval_i, B8::new(0x45) + B8::new(0x85));
			} else {
				assert_eq!(eval_i, B8::new(0x00));
			}
		}
	}

	#[test]
	fn test_small_tables() {
		let table_id = 0;
		let mut inner_table = Table::<B128>::new(table_id, "table".to_string());
		let mut table = TableBuilder::new(&mut inner_table);
		let _col0 = table.add_committed::<B1, 8>("col0");
		let _col1 = table.add_committed::<B1, 32>("col1");
		let _col2 = table.add_committed::<B8, 1>("col2");
		let _col3 = table.add_committed::<B32, 1>("col3");

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let table_size = 7;
		let mut index = TableWitnessIndex::<PackedType<OptimalUnderlier256b, B128>>::new(
			&allocator,
			&inner_table,
			table_size,
		)
		.unwrap();

		assert_eq!(index.log_capacity(), 3);
		assert_eq!(index.min_log_segment_size(), 3);

		let mut iter = index.segments(5);
		// Check that the segment size is clamped to the capacity.
		assert_eq!(iter.next().unwrap().log_size(), 3);
		assert!(iter.next().is_none());
		drop(iter);

		let mut iter = index.segments(2);
		// Check that the segment size is clamped to the minimum segment size.
		assert_eq!(iter.next().unwrap().log_size(), 3);
		assert!(iter.next().is_none());
		drop(iter);
	}

	struct TestTable {
		id: TableId,
		col0: Col<B32>,
		col1: Col<B32>,
	}

	impl TestTable {
		fn new(cs: &mut ConstraintSystem) -> Self {
			let mut table = cs.add_table("test");

			let col0 = table.add_committed("col0");
			let col1 = table.add_computed("col1", col0 * col0 + B32::new(0x03));

			Self {
				id: table.id(),
				col0,
				col1,
			}
		}
	}

	impl TableFiller<PackedType<OptimalUnderlier128b, B128>> for TestTable {
		type Event = u32;

		fn id(&self) -> TableId {
			self.id
		}

		fn fill(
			&self,
			rows: &[Self::Event],
			witness: &mut TableWitnessSegment<PackedType<OptimalUnderlier128b, B128>>,
		) -> anyhow::Result<()> {
			let mut col0 = witness.get_scalars_mut(self.col0)?;
			let mut col1 = witness.get_scalars_mut(self.col1)?;
			for (i, &val) in rows.iter().enumerate() {
				col0[i] = B32::new(val);
				col1[i] = col0[i].pow(2) + B32::new(0x03);
			}
			Ok(())
		}
	}

	#[test]
	fn test_fill_sequential_with_incomplete_events() {
		let mut cs = ConstraintSystem::new();
		let test_table = TestTable::new(&mut cs);

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();

		let table_size = 11;
		let mut index = WitnessIndex::new(&cs, &allocator);
		let table_index = index.init_table(test_table.id(), table_size).unwrap();

		let mut rng = StdRng::seed_from_u64(0);
		let rows = repeat_with(|| rng.random())
			.take(table_size)
			.collect::<Vec<_>>();

		// 2^4 is the next power of two after 11, the table size.
		assert_eq!(table_index.log_capacity(), 4);

		// 2^2 B32 values fit into a 128-bit underlier.
		assert_eq!(table_index.min_log_segment_size(), 2);

		// Assert that fill_sequential validates the number of events.
		assert_matches!(
			table_index.fill_sequential_with_segment_size(&test_table, &rows[1..], 2),
			Err(Error::IncorrectNumberOfTableEvents { .. })
		);

		table_index
			.fill_sequential_with_segment_size(&test_table, &rows, 2)
			.unwrap();

		let segment = table_index.full_segment();
		let col0 = segment.get_scalars(test_table.col0).unwrap();
		for i in 0..11 {
			assert_eq!(col0[i].val(), rows[i]);
		}
		assert_eq!(col0[11].val(), rows[10]);
		assert_eq!(col0[12].val(), rows[8]);
		assert_eq!(col0[13].val(), rows[9]);
		assert_eq!(col0[14].val(), rows[10]);
		assert_eq!(col0[15].val(), rows[10]);
	}

	#[test]
	fn test_fill_parallel_with_incomplete_events() {
		let mut cs = ConstraintSystem::new();
		let test_table = TestTable::new(&mut cs);

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();

		let table_size = 11;
		let mut index = WitnessIndex::new(&cs, &allocator);
		let table_index = index.init_table(test_table.id(), table_size).unwrap();

		let mut rng = StdRng::seed_from_u64(0);
		let rows = repeat_with(|| rng.random())
			.take(table_size)
			.collect::<Vec<_>>();

		// 2^4 is the next power of two after 11, the table size.
		assert_eq!(table_index.log_capacity(), 4);

		// 2^2 B32 values fit into a 128-bit underlier.
		assert_eq!(table_index.min_log_segment_size(), 2);

		// Assert that fill_parallel validates the number of events.
		assert_matches!(
			table_index.fill_parallel_with_segment_size(&test_table, &rows[1..], 2),
			Err(Error::IncorrectNumberOfTableEvents { .. })
		);

		table_index
			.fill_parallel_with_segment_size(&test_table, &rows, 2)
			.unwrap();

		let segment = table_index.full_segment();
		let col0 = segment.get_scalars(test_table.col0).unwrap();
		for i in 0..11 {
			assert_eq!(col0[i].val(), rows[i]);
		}
		assert_eq!(col0[11].val(), rows[10]);
		assert_eq!(col0[12].val(), rows[8]);
		assert_eq!(col0[13].val(), rows[9]);
		assert_eq!(col0[14].val(), rows[10]);
		assert_eq!(col0[15].val(), rows[10]);
	}

	#[test]
	fn test_fill_empty_rows_non_empty_table() {
		let mut cs = ConstraintSystem::new();
		let test_table = TestTable::new(&mut cs);

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();

		let table_size = 11;
		let mut index = WitnessIndex::new(&cs, &allocator);

		index.init_table(test_table.id(), table_size).unwrap();

		assert_matches!(
			index.fill_table_sequential(&test_table, &[]),
			Err(Error::IncorrectNumberOfTableEvents { .. })
		);
	}

	#[test]
	fn test_dyn_witness() {
		let mut cs = ConstraintSystem::new();
		let mut test_table = cs.add_table("test");
		let test_col: Col<B32, 4> = test_table.add_committed("col");
		let table_id = test_table.id();

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();

		let table_size = 11;
		let mut index = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);

		let table_index = index.init_table(table_id, table_size).unwrap();
		let segment = table_index.full_segment();
		let mut rng = StdRng::seed_from_u64(0);
		let row = repeat_with(|| B32::random(&mut rng))
			.take(table_size * 4)
			.collect::<Vec<_>>();
		{
			let mut data: Box<dyn WitnessColViewMut<_>> =
				segment.get_dyn_mut(test_col.id()).unwrap();
			row.iter().enumerate().for_each(|(i, val)| {
				data.set(i, (*val).into()).unwrap();
			})
		}
		let data = segment.get_dyn(test_col.id()).unwrap();
		row.iter().enumerate().for_each(|(i, val)| {
			let down_cast: B32 = data.get(i).try_into().unwrap();
			assert_eq!(down_cast, *val)
		})
	}

	fn find_oracle_id_with_name(
		oracles: &SymbolicMultilinearOracleSet<B128>,
		name: &str,
	) -> Option<OracleId> {
		oracles
			.iter()
			.find(|(_, oracle)| oracle.name.as_deref() == Some(name))
			.map(|(id, _)| id)
	}

	#[test]
	fn test_constant_filling() {
		let mut cs = ConstraintSystem::new();

		let mut test_table = cs.add_table("test");
		let mut rng = StdRng::seed_from_u64(0);
		let unpack_value = B16::random(&mut rng);
		let pack_const_arr: [B32; 4] = array::from_fn(|_| B32::random(&mut rng));

		let _ = test_table.add_constant("unpacked_col", [unpack_value]);
		let _ = test_table.add_constant("packed_col", pack_const_arr);
		let table_id = test_table.id();

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();

		let table_size = 123;
		let ccs = cs.compile().unwrap();
		let mut index = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);

		{
			let _ = index.init_table(table_id, table_size).unwrap();
			index.fill_constant_cols().unwrap();
		}

		let witness = index.into_multilinear_extension_index();
		let non_packed_col_id = find_oracle_id_with_name(&ccs.oracles, "unpacked_col").unwrap();

		// Query MultilinearExtensionIndex to see if the constants are correct.
		let non_pack_witness = witness.get_multilin_poly(non_packed_col_id).unwrap();
		for index in 0..non_pack_witness.size() {
			let got = non_pack_witness.evaluate_on_hypercube(index).unwrap();
			assert_eq!(got, unpack_value.into());
		}
		let packed_col_id = find_oracle_id_with_name(&ccs.oracles, "packed_col").unwrap();

		let pack_witness = witness.get_multilin_poly(packed_col_id).unwrap();
		for index in 0..pack_witness.size() {
			let got = pack_witness.evaluate_on_hypercube(index).unwrap();

			assert_eq!(got, pack_const_arr[index % 4].into());
		}
	}
}
1754}