binius_m3/builder/constraint_system.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{cell, collections::BTreeMap, ops::Index};
4
5use binius_compute::alloc::HostBumpAllocator;
6pub use binius_core::constraint_system::channel::{
7	Boundary, Flush as CompiledFlush, FlushDirection,
8};
9use binius_core::{
10	constraint_system::{
11		ConstraintSystem as CompiledConstraintSystem,
12		channel::{ChannelId, OracleOrConst},
13		exp::Exp,
14	},
15	oracle::{
16		Constraint, ConstraintPredicate, ConstraintSet, OracleId, SymbolicMultilinearOracleSet,
17	},
18};
19use binius_field::{PackedField, TowerField};
20use binius_math::{ArithCircuit, LinearNormalForm};
21use binius_utils::checked_arithmetics::log2_strict_usize;
22
23use super::{
24	ColumnId, Table, TableBuilder, TableId, ZeroConstraint,
25	channel::{Channel, Flush},
26	column::{ColumnDef, ColumnInfo},
27	error::Error,
28	table::TablePartition,
29	types::B128,
30	witness::WitnessIndex,
31};
32use crate::builder::expr::ArithExprNamedVars;
33
34/// An M3 constraint system, independent of the table sizes.
35#[derive(Debug, Default)]
36pub struct ConstraintSystem<F: TowerField = B128> {
37	pub tables: Vec<Table<F>>,
38	pub channels: Vec<Channel>,
39
40	// This is assigned as part of `ConstraintSystem::compile`.
41	oracle_lookup: cell::RefCell<Option<OracleLookup>>,
42}
43
44impl<F: TowerField> std::fmt::Display for ConstraintSystem<F> {
45	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46		writeln!(f, "ConstraintSystem {{")?;
47
48		for channel in self.channels.iter() {
49			writeln!(f, "    CHANNEL {}", channel.name)?;
50		}
51
52		let mut oracle_id = 0;
53
54		for table in self.tables.iter() {
55			writeln!(f, "    TABLE {} {{", table.name)?;
56
57			for partition in table.partitions.values() {
58				for flush in partition.flushes.iter() {
59					let channel = self.channels[flush.channel_id].name.clone();
60					let columns = flush
61						.columns
62						.iter()
63						.map(|i| table[*i].name.clone())
64						.collect::<Vec<_>>()
65						.join(", ");
66					match flush.direction {
67						FlushDirection::Push => {
68							writeln!(f, "        PUSH ({columns}) to {channel}")?
69						}
70						FlushDirection::Pull => {
71							writeln!(f, "        PULL ({columns}) from {channel}")?
72						}
73					};
74				}
75
76				let names = partition
77					.columns
78					.iter()
79					.map(|&index| table[index].name.clone())
80					.collect::<Vec<_>>();
81
82				for constraint in partition.zero_constraints.iter() {
83					let name = constraint.name.clone();
84					let expr = ArithExprNamedVars(&constraint.expr, &names);
85					writeln!(f, "        ZERO {name}: {expr}")?;
86				}
87			}
88
89			for col in table.columns.iter() {
90				if matches!(col.col, ColumnDef::Constant { .. }) {
91					oracle_id += 1;
92				}
93			}
94
95			for col in table.columns.iter() {
96				let name = col.name.clone();
97				let log_values_per_row = col.shape.log_values_per_row;
98				let field = match col.shape.tower_height {
99					0 => "B1",
100					1 => "B2",
101					2 => "B4",
102					3 => "B8",
103					4 => "B16",
104					5 => "B32",
105					6 => "B64",
106					_ => "B128",
107				};
108				let type_str = if log_values_per_row > 0 {
109					let values_per_row = 1 << log_values_per_row;
110					format!("{field}x{values_per_row}")
111				} else {
112					field.to_string()
113				};
114				writeln!(f, "        {oracle_id:04} {type_str} {name}")?;
115				oracle_id += 1;
116			}
117
118			// step_down selectors for the table
119			for log_values_per_row in table.partitions.keys() {
120				let values_per_row = 1 << log_values_per_row;
121				let selector_type_str = if values_per_row > 1 {
122					format!("B1x{values_per_row}")
123				} else {
124					"B1".to_string()
125				};
126				writeln!(f, "        {oracle_id:04} {selector_type_str} (ROW_SELECTOR)")?;
127				oracle_id += 1;
128			}
129
130			writeln!(f, "    }}")?;
131		}
132		writeln!(f, "}}")
133	}
134}
135
136impl<F: TowerField> ConstraintSystem<F> {
137	pub fn new() -> Self {
138		Self::default()
139	}
140
141	pub fn add_table(&mut self, name: impl ToString) -> TableBuilder<'_, F> {
142		let id = self.tables.len();
143		self.tables.push(Table::new(id, name.to_string()));
144		TableBuilder::new(self.tables.last_mut().expect("table was just pushed"))
145	}
146
147	pub fn add_channel(&mut self, name: impl ToString) -> ChannelId {
148		let id = self.channels.len();
149		self.channels.push(Channel {
150			name: name.to_string(),
151		});
152		id
153	}
154
155	/// Creates and allocates the witness index.
156	///
157	/// **Deprecated**: This is a thin wrapper over [`WitnessIndex::new`] now, which is preferred.
158	#[deprecated]
159	pub fn build_witness<'cs, 'alloc, P: PackedField<Scalar = F>>(
160		&'cs self,
161		allocator: &'alloc HostBumpAllocator<'alloc, P>,
162	) -> WitnessIndex<'cs, 'alloc, P> {
163		WitnessIndex::new(self, allocator)
164	}
165
166	/// Returns the oracle lookup for this constraint system.
167	///
168	/// Note that this function returns the struct as of the last call to
169	/// [`ConstraintSystem::lookup`].
170	#[track_caller]
171	pub(crate) fn oracle_lookup<'a>(&'a self) -> cell::Ref<'a, OracleLookup> {
172		const MESSAGE: &str = "oracle_lookup was requested but constraint system was not compiled";
173		cell::Ref::map(self.oracle_lookup.borrow(), |o| o.as_ref().expect(MESSAGE))
174	}
175
176	/// Compiles a [`CompiledConstraintSystem`] for a particular statement.
177	///
178	/// The most important transformation that takes place in this step is creating multilinear
179	/// oracles for all columns. The main difference between column definitions and oracle
180	/// definitions is that multilinear oracle definitions have a number of variables, whereas the
181	/// column definitions contained in a [`ConstraintSystem`] do not have size information.
182	pub fn compile(&self) -> Result<CompiledConstraintSystem<F>, Error> {
183		let mut oracles = SymbolicMultilinearOracleSet::new();
184		let mut table_constraints = Vec::new();
185		let mut compiled_flushes = Vec::new();
186		let mut non_zero_oracle_ids = Vec::new();
187		let mut exponents = Vec::new();
188		let mut table_size_specs = Vec::new();
189
190		let mut oracle_lookup = OracleLookup::new();
191
192		for table in &self.tables {
193			table_size_specs.push(table.size_spec());
194
195			// Add multilinear oracles for all table columns.
196			add_oracles_for_columns(
197				&mut oracle_lookup,
198				&mut oracles,
199				table,
200				&mut non_zero_oracle_ids,
201			)?;
202
203			for partition in table.partitions.values() {
204				let TablePartition {
205					columns,
206					flushes,
207					zero_constraints,
208					values_per_row,
209					..
210				} = partition;
211
212				let partition_oracle_ids = columns
213					.iter()
214					.map(|&index| oracle_lookup[index])
215					.collect::<Vec<_>>();
216
217				// Add Exponents with the same pack factor for the compiled constraint system.
218				columns.iter().for_each(|index| {
219					let col = &table[*index];
220					let col_info = &table[*index].col;
221					match col_info {
222						ColumnDef::StaticExp {
223							bit_cols,
224							base,
225							base_tower_level,
226						} => {
227							let bits_ids = bit_cols
228								.iter()
229								.map(|&column_id| oracle_lookup[column_id])
230								.collect();
231							exponents.push(Exp {
232								base: OracleOrConst::Const {
233									base: *base,
234									tower_level: *base_tower_level,
235								},
236								bits_ids,
237								exp_result_id: oracle_lookup[col.id],
238							});
239						}
240						ColumnDef::DynamicExp { bit_cols, base, .. } => {
241							let bits_ids = bit_cols
242								.iter()
243								.map(|&col_idx| oracle_lookup[col_idx])
244								.collect();
245							exponents.push(Exp {
246								base: OracleOrConst::Oracle(oracle_lookup[*base]),
247								bits_ids,
248								exp_result_id: oracle_lookup[col.id],
249							})
250						}
251						_ => (),
252					}
253				});
254
255				// Translate flushes for the compiled constraint system.
256				for Flush {
257					columns: flush_columns,
258					channel_id,
259					direction,
260					multiplicity,
261					selectors,
262				} in flushes
263				{
264					let flush_oracles = flush_columns
265						.iter()
266						.map(|&column_id| OracleOrConst::Oracle(oracle_lookup[column_id]))
267						.collect::<Vec<_>>();
268					let selectors = selectors
269						.iter()
270						.map(|column_idx| oracle_lookup[*column_idx])
271						.collect::<Vec<_>>();
272
273					compiled_flushes.push(CompiledFlush {
274						table_id: table.id(),
275						log_values_per_row: log2_strict_usize(*values_per_row),
276						oracles: flush_oracles,
277						channel_id: *channel_id,
278						direction: *direction,
279						selectors,
280						multiplicity: *multiplicity as u64,
281					});
282				}
283
284				if !zero_constraints.is_empty() {
285					let constraint_set = translate_constraint_set(
286						table.id(),
287						log2_strict_usize(*values_per_row),
288						zero_constraints,
289						partition_oracle_ids,
290					);
291					table_constraints.push(constraint_set);
292				}
293			}
294		}
295
296		*self.oracle_lookup.borrow_mut() = Some(oracle_lookup);
297
298		Ok(CompiledConstraintSystem {
299			oracles,
300			table_constraints,
301			flushes: compiled_flushes,
302			non_zero_oracle_ids,
303			channel_count: self.channels.len(),
304			exponents,
305			table_size_specs,
306		})
307	}
308}
309
310#[derive(Debug, Copy, Clone)]
311pub(crate) enum OracleMapping {
312	Regular(OracleId),
313	/// This is used for constant columns.
314	///
315	/// A constant columns are backed by a transparent oracle. That oracle is a single row and
316	/// is not repeating which is not really expected. So in order to reduce the factor of surprise
317	/// for the user, the original transparent is wrapped into a repeating virtual oracle.
318	TransparentCompound {
319		original: OracleId,
320		repeating: OracleId,
321	},
322}
323
324/// This structure holds metadata about every oracle ID in a constraint system.
325///
326/// This structure maintains mapping between the [`OracleId`] and the related [`ColumnId`].
327#[derive(Debug, Default)]
328pub(crate) struct OracleLookup {
329	column_to_oracle: BTreeMap<ColumnId, OracleMapping>,
330}
331
332impl OracleLookup {
333	/// Creates a new empty Oracle Registry.
334	pub(crate) fn new() -> Self {
335		Self {
336			column_to_oracle: BTreeMap::new(),
337		}
338	}
339
340	/// Looks up the [`OracleMapping`] for a given column ID.
341	///
342	/// # Preconditions
343	///
344	/// The column ID must exist in the registry, otherwise this function will panic.
345	pub fn lookup(&self, column_id: ColumnId) -> &OracleMapping {
346		&self.column_to_oracle[&column_id]
347	}
348
349	/// Adds a mapping from a column ID to an oracle mapping.
350	///
351	/// # Preconditions
352	///
353	/// The column ID must not already be registered in the registry, otherwise this function
354	/// will panic.
355	fn register_regular(&mut self, column_id: ColumnId, oracle_id: OracleId) {
356		let prev = self
357			.column_to_oracle
358			.insert(column_id, OracleMapping::Regular(oracle_id));
359		assert!(prev.is_none());
360	}
361
362	/// Registers a transparent oracle mapping for a column.
363	///
364	/// This creates a compound mapping from a column to both an original and a repeating oracle.
365	/// This is specifically used for constant columns that need repeating behavior.
366	///
367	/// # Preconditions
368	///
369	/// The column ID must not already be registered in the registry, otherwise this function
370	/// will panic.
371	fn register_transparent(
372		&mut self,
373		column_id: ColumnId,
374		original: OracleId,
375		repeating: OracleId,
376	) {
377		let prev = self.column_to_oracle.insert(
378			column_id,
379			OracleMapping::TransparentCompound {
380				original,
381				repeating,
382			},
383		);
384		assert!(prev.is_none());
385	}
386}
387
388/// Indexing for [`OracleLookup`]. For transparents this returns the repeating column.
389impl Index<ColumnId> for OracleLookup {
390	type Output = OracleId;
391
392	fn index(&self, id: ColumnId) -> &Self::Output {
393		match &self.column_to_oracle[&id] {
394			OracleMapping::Regular(oracle_id) => oracle_id,
395			OracleMapping::TransparentCompound { repeating, .. } => repeating,
396		}
397	}
398}
399
400/// Add all columns within the given table into the given `oracle_lookup`. Also, fills out
401/// the `non_zero_oracle_ids`.
402fn add_oracles_for_columns<F: TowerField>(
403	oracle_lookup: &mut OracleLookup,
404	oracle_set: &mut SymbolicMultilinearOracleSet<F>,
405	table: &Table<F>,
406	non_zero_oracle_ids: &mut Vec<OracleId>,
407) -> Result<(), Error> {
408	for column_info in table.columns.iter() {
409		add_oracle_for_column(oracle_set, oracle_lookup, column_info, table.id())?;
410		if column_info.is_nonzero {
411			non_zero_oracle_ids.push(oracle_lookup[column_info.id]);
412		}
413	}
414	Ok(())
415}
416
417/// Add a table column to the multilinear oracle set with a specified number of variables.
418///
419/// ## Arguments
420///
421/// * `oracles` - The set of multilinear oracles to add to.
422/// * `oracle_lookup` - mapping of column indices in the table to oracle IDs in the oracle set
423/// * `column_info` - information about the column to be added
424/// * `n_vars` - number of variables of the multilinear oracle
425fn add_oracle_for_column<F: TowerField>(
426	oracles: &mut SymbolicMultilinearOracleSet<F>,
427	oracle_lookup: &mut OracleLookup,
428	column_info: &ColumnInfo<F>,
429	table_id: TableId,
430) -> Result<(), Error> {
431	let ColumnInfo {
432		id: column_id,
433		col,
434		name,
435		shape,
436		..
437	} = column_info;
438	match col {
439		ColumnDef::Committed { tower_level } => {
440			let oracle_id = oracles
441				.add_oracle(table_id, shape.log_values_per_row, name)
442				.committed(*tower_level);
443			oracle_lookup.register_regular(*column_id, oracle_id);
444		}
445		ColumnDef::Selected {
446			col,
447			index,
448			index_bits,
449		} => {
450			let index_values = (0..*index_bits)
451				.map(|i| {
452					if (index >> i) & 1 == 0 {
453						F::ZERO
454					} else {
455						F::ONE
456					}
457				})
458				.collect();
459			let oracle_id = oracles
460				.add_oracle(table_id, shape.log_values_per_row, name)
461				.projected(oracle_lookup[*col], index_values, 0)?;
462			oracle_lookup.register_regular(*column_id, oracle_id);
463		}
464		ColumnDef::Projected {
465			col,
466			start_index,
467			query_size,
468			query_bits,
469		} => {
470			let query_values = (0..*query_size)
471				.map(|i| -> F {
472					if (query_bits >> i) & 1 == 0 {
473						F::ZERO
474					} else {
475						F::ONE
476					}
477				})
478				.collect();
479			let oracle_id = oracles
480				.add_oracle(table_id, shape.log_values_per_row, name)
481				.projected(oracle_lookup[*col], query_values, *start_index)?;
482			oracle_lookup.register_regular(*column_id, oracle_id);
483		}
484		ColumnDef::ZeroPadded {
485			col,
486			n_pad_vars,
487			start_index,
488			nonzero_index,
489		} => {
490			let oracle_id = oracles
491				.add_oracle(table_id, shape.log_values_per_row, name)
492				.zero_padded(oracle_lookup[*col], *n_pad_vars, *nonzero_index, *start_index)?;
493			oracle_lookup.register_regular(*column_id, oracle_id);
494		}
495		ColumnDef::Shifted {
496			col,
497			offset,
498			log_block_size,
499			variant,
500		} => {
501			// TODO: debug assert column at col.table_index has the same values_per_row as col.id
502			let oracle_id = oracles
503				.add_oracle(table_id, shape.log_values_per_row, name)
504				.shifted(oracle_lookup[*col], *offset, *log_block_size, *variant)?;
505			oracle_lookup.register_regular(*column_id, oracle_id);
506		}
507		ColumnDef::Packed { col, log_degree } => {
508			// TODO: debug assert column at col.table_index has the same values_per_row as col.id
509			let source = oracle_lookup[*col];
510			let oracle_id = oracles
511				.add_oracle(table_id, shape.log_values_per_row, name)
512				.packed(source, *log_degree)?;
513			oracle_lookup.register_regular(*column_id, oracle_id);
514		}
515		ColumnDef::Computed { cols, expr } => {
516			if let Ok(LinearNormalForm {
517				constant: offset,
518				var_coeffs,
519			}) = expr.linear_normal_form()
520			{
521				let col_scalars = cols
522					.iter()
523					.zip(var_coeffs)
524					.map(|(&col_id, coeff)| (oracle_lookup[col_id], coeff))
525					.collect::<Vec<_>>();
526				let oracle_id = oracles
527					.add_oracle(table_id, shape.log_values_per_row, name)
528					.linear_combination_with_offset(offset, col_scalars)?;
529				oracle_lookup.register_regular(*column_id, oracle_id);
530			} else {
531				let inner_oracles = cols
532					.iter()
533					.map(|&col_index| oracle_lookup[col_index])
534					.collect::<Vec<_>>();
535				let oracle_id = oracles
536					.add_oracle(table_id, shape.log_values_per_row, name)
537					.composite_mle(inner_oracles, expr.clone())?;
538				oracle_lookup.register_regular(*column_id, oracle_id);
539			};
540		}
541		ColumnDef::Constant { poly, .. } => {
542			let oracle_id_original = oracles
543				.add_oracle(table_id, shape.log_values_per_row, format!("{name}_single"))
544				.transparent(poly.clone())?;
545			let oracle_id_repeating = oracles
546				.add_oracle(table_id, shape.log_values_per_row, name)
547				.repeating(oracle_id_original)?;
548			oracle_lookup.register_transparent(*column_id, oracle_id_original, oracle_id_repeating);
549		}
550		ColumnDef::StructuredDynSize(structured) => {
551			let expr = structured.expr()?;
552			let oracle_id = oracles
553				.add_oracle(table_id, shape.log_values_per_row, name)
554				.structured(ArithCircuit::from(&expr))?;
555			oracle_lookup.register_regular(*column_id, oracle_id);
556		}
557		ColumnDef::StructuredFixedSize { expr } => {
558			let oracle_id = oracles
559				.add_oracle(table_id, shape.log_values_per_row, name)
560				.transparent(expr.clone())?;
561			oracle_lookup.register_regular(*column_id, oracle_id);
562		}
563		ColumnDef::StaticExp {
564			base_tower_level, ..
565		} => {
566			let oracle_id = oracles
567				.add_oracle(table_id, shape.log_values_per_row, name)
568				.committed(*base_tower_level);
569			oracle_lookup.register_regular(*column_id, oracle_id);
570		}
571		ColumnDef::DynamicExp {
572			base_tower_level, ..
573		} => {
574			let oracle_id = oracles
575				.add_oracle(table_id, shape.log_values_per_row, name)
576				.committed(*base_tower_level);
577			oracle_lookup.register_regular(*column_id, oracle_id);
578		}
579	};
580	Ok(())
581}
582
583/// Translates a set of zero constraints from a particular table partition into a constraint set.
584///
585/// The resulting constraint set will only contain oracles that were actually referenced from any
586/// of the constraint expressions.
587fn translate_constraint_set<F: TowerField>(
588	table_id: TableId,
589	log_values_per_row: usize,
590	zero_constraints: &[ZeroConstraint<F>],
591	partition_oracle_ids: Vec<OracleId>,
592) -> ConstraintSet<F> {
593	// We need to figure out which oracle ids from the entire set of the partition oracles is
594	// actually referenced in every zero constraint expressions.
595	let mut oracle_appears_in_expr = vec![false; partition_oracle_ids.len()];
596	let mut n_used_oracles = 0usize;
597	for zero_constraint in zero_constraints {
598		let vars_usage = zero_constraint.expr.vars_usage();
599		for (oracle_index, used) in vars_usage.iter().enumerate() {
600			if *used && !oracle_appears_in_expr[oracle_index] {
601				oracle_appears_in_expr[oracle_index] = true;
602				n_used_oracles += 1;
603			}
604		}
605	}
606
607	// Now that we've got the set of oracle ids that appear in the expr we are going to create
608	// a new list of oracle ids each of which is used. Along the way we create a new mapping table
609	// that maps the original oracle index to the new index in the dense list.
610	const INVALID_SENTINEL: usize = usize::MAX;
611	let mut remap_indices_table = vec![INVALID_SENTINEL; partition_oracle_ids.len()];
612	let mut dense_oracle_ids = Vec::with_capacity(n_used_oracles);
613	for (i, &used) in oracle_appears_in_expr.iter().enumerate() {
614		if !used {
615			continue;
616		}
617		let dense_index = dense_oracle_ids.len();
618		dense_oracle_ids.push(partition_oracle_ids[i]);
619		remap_indices_table[i] = dense_index;
620	}
621
622	// Translate zero constraints for the compiled constraint system.
623	let compiled_constraints = zero_constraints
624		.iter()
625		.map(|zero_constraint| {
626			let expr = zero_constraint
627				.expr
628				.clone()
629				.remap_vars(&remap_indices_table)
630				.expect(
631					"the expr must have the same length as partition_oracle_ids which is the\
632				 same length of remap_indices_table",
633				);
634			Constraint {
635				name: zero_constraint.name.clone(),
636				composition: expr,
637				predicate: ConstraintPredicate::Zero,
638			}
639		})
640		.collect::<Vec<_>>();
641
642	ConstraintSet {
643		table_id,
644		log_values_per_row,
645		oracle_ids: dense_oracle_ids,
646		constraints: compiled_constraints,
647	}
648}