binius_m3/builder/
constraint_system.rs

1// Copyright 2025 Irreducible Inc.
2
3pub use binius_core::constraint_system::channel::{
4	Boundary, Flush as CompiledFlush, FlushDirection,
5};
6use binius_core::{
7	constraint_system::{
8		channel::{ChannelId, OracleOrConst},
9		ConstraintSystem as CompiledConstraintSystem,
10	},
11	oracle::{
12		Constraint, ConstraintPredicate, ConstraintSet, MultilinearOracleSet, OracleId,
13		ProjectionVariant,
14	},
15	transparent::step_down::StepDown,
16};
17use binius_field::{PackedField, TowerField};
18use binius_math::LinearNormalForm;
19use binius_utils::checked_arithmetics::log2_strict_usize;
20use bumpalo::Bump;
21
22use super::{
23	channel::{Channel, Flush},
24	column::{ColumnDef, ColumnInfo},
25	error::Error,
26	statement::Statement,
27	table::TablePartition,
28	types::B128,
29	witness::{TableWitnessIndex, WitnessIndex},
30	Table, TableBuilder,
31};
32use crate::builder::expr::ArithExprNamedVars;
33
34/// An M3 constraint system, independent of the table sizes.
35#[derive(Debug, Default)]
36pub struct ConstraintSystem<F: TowerField = B128> {
37	pub tables: Vec<Table<F>>,
38	pub channels: Vec<Channel>,
39}
40
41impl<F: TowerField> std::fmt::Display for ConstraintSystem<F> {
42	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43		writeln!(f, "ConstraintSystem {{")?;
44
45		for channel in self.channels.iter() {
46			writeln!(f, "    CHANNEL {}", channel.name)?;
47		}
48
49		let mut oracle_id = 0;
50
51		for table in self.tables.iter() {
52			writeln!(f, "    TABLE {} {{", table.name)?;
53
54			for partition in table.partitions.values() {
55				for flush in partition.flushes.iter() {
56					let channel = self.channels[flush.channel_id].name.clone();
57					let columns = flush
58						.column_indices
59						.iter()
60						.map(|i| table.columns[partition.columns[*i]].name.clone())
61						.collect::<Vec<_>>()
62						.join(", ");
63					match flush.direction {
64						FlushDirection::Push => {
65							writeln!(f, "        PUSH ({columns}) to {channel}")?
66						}
67						FlushDirection::Pull => {
68							writeln!(f, "        PULL ({columns}) from {channel}")?
69						}
70					};
71				}
72
73				let names = partition
74					.columns
75					.iter()
76					.map(|&index| table.columns[index].name.clone())
77					.collect::<Vec<_>>();
78
79				for constraint in partition.zero_constraints.iter() {
80					let name = constraint.name.clone();
81					let expr = ArithExprNamedVars(&constraint.expr, &names);
82					writeln!(f, "        ZERO {name}: {expr}")?;
83				}
84			}
85
86			for col in table.columns.iter() {
87				if matches!(col.col, ColumnDef::Constant { .. }) {
88					oracle_id += 1;
89				}
90			}
91
92			for col in table.columns.iter() {
93				let name = col.name.clone();
94				let log_values_per_row = col.shape.log_values_per_row;
95				let field = match col.shape.tower_height {
96					0 => "B1",
97					1 => "B2",
98					2 => "B4",
99					3 => "B8",
100					4 => "B16",
101					5 => "B32",
102					6 => "B64",
103					_ => "B128",
104				};
105				let type_str = if log_values_per_row > 0 {
106					let values_per_row = 1 << log_values_per_row;
107					format!("{field}x{values_per_row}")
108				} else {
109					field.to_string()
110				};
111				writeln!(f, "        {oracle_id:04} {type_str} {name}")?;
112				oracle_id += 1;
113			}
114
115			// step_down selectors for the table
116			for log_values_per_row in table.partitions.keys() {
117				let values_per_row = 1 << log_values_per_row;
118				let selector_type_str = if values_per_row > 1 {
119					format!("B1x{}", values_per_row)
120				} else {
121					"B1".to_string()
122				};
123				writeln!(f, "        {oracle_id:04} {selector_type_str} (ROW_SELECTOR)")?;
124				oracle_id += 1;
125			}
126
127			writeln!(f, "    }}")?;
128		}
129		writeln!(f, "}}")
130	}
131}
132
133impl<F: TowerField> ConstraintSystem<F> {
134	pub fn new() -> Self {
135		Self::default()
136	}
137
138	pub fn add_table(&mut self, name: impl ToString) -> TableBuilder<'_, F> {
139		let id = self.tables.len();
140		self.tables.push(Table::new(id, name.to_string()));
141		TableBuilder::new(self.tables.last_mut().expect("table was just pushed"))
142	}
143
144	pub fn add_channel(&mut self, name: impl ToString) -> ChannelId {
145		let id = self.channels.len();
146		self.channels.push(Channel {
147			name: name.to_string(),
148		});
149		id
150	}
151
152	/// Creates and allocates the witness index for a statement.
153	///
154	/// The statement includes information about the tables sizes, which this requires in order to
155	/// allocate the column data correctly. The created witness index needs to be populated before
156	/// proving.
157	pub fn build_witness<'cs, 'alloc, P: PackedField<Scalar = F>>(
158		&'cs self,
159		allocator: &'alloc Bump,
160		statement: &Statement,
161	) -> Result<WitnessIndex<'cs, 'alloc, P>, Error> {
162		Ok(WitnessIndex {
163			tables: self
164				.tables
165				.iter()
166				.zip(&statement.table_sizes)
167				.map(|(table, &table_size)| {
168					let witness = if table_size > 0 {
169						Some(TableWitnessIndex::new(allocator, table, table_size))
170					} else {
171						None
172					};
173					witness.transpose()
174				})
175				.collect::<Result<_, _>>()?,
176		})
177	}
178
179	/// Compiles a [`CompiledConstraintSystem`] for a particular statement.
180	///
181	/// The most important transformation that takes place in this step is creating multilinear
182	/// oracles for all columns. The main difference between column definitions and oracle
183	/// definitions is that multilinear oracle definitions have a number of variables, whereas the
184	/// column definitions contained in a [`ConstraintSystem`] do not have size information.
185	pub fn compile(&self, statement: &Statement<F>) -> Result<CompiledConstraintSystem<F>, Error> {
186		if statement.table_sizes.len() != self.tables.len() {
187			return Err(Error::StatementMissingTableSize {
188				expected: self.tables.len(),
189				actual: statement.table_sizes.len(),
190			});
191		}
192
193		// TODO: new -> with_capacity
194		let mut oracles = MultilinearOracleSet::new();
195		let mut table_constraints = Vec::new();
196		let mut compiled_flushes = Vec::new();
197		let mut non_zero_oracle_ids = Vec::new();
198
199		for (table, &count) in std::iter::zip(&self.tables, &statement.table_sizes) {
200			if count == 0 {
201				continue;
202			}
203			let mut oracle_lookup = Vec::new();
204
205			let mut transparent_single = vec![None; table.columns.len()];
206			for (table_index, info) in table.columns.iter().enumerate() {
207				if let ColumnDef::Constant { poly } = &info.col {
208					let oracle_id = oracles
209						.add_named(format!("{}_single", info.name))
210						.transparent(poly.clone())?;
211					transparent_single[table_index] = Some(oracle_id);
212				}
213			}
214
215			// Add multilinear oracles for all table columns.
216			let log_capacity = table.log_capacity(count);
217			for column_info in table.columns.iter() {
218				let n_vars = log_capacity + column_info.shape.log_values_per_row;
219				let oracle_id = add_oracle_for_column(
220					&mut oracles,
221					&oracle_lookup,
222					&transparent_single,
223					column_info,
224					n_vars,
225				)?;
226				oracle_lookup.push(oracle_id);
227				if column_info.is_nonzero {
228					non_zero_oracle_ids.push(oracle_id);
229				}
230			}
231
232			for partition in table.partitions.values() {
233				let TablePartition {
234					columns,
235					flushes,
236					zero_constraints,
237					values_per_row,
238					..
239				} = partition;
240
241				let n_vars = log_capacity + log2_strict_usize(*values_per_row);
242
243				let partition_oracle_ids = columns
244					.iter()
245					.map(|&index| oracle_lookup[index])
246					.collect::<Vec<_>>();
247
248				// StepDown witness data is populated in WitnessIndex::into_multilinear_extension_index
249				let step_down =
250					oracles.add_transparent(StepDown::new(n_vars, count * values_per_row)?)?;
251
252				// Translate flushes for the compiled constraint system.
253				for Flush {
254					column_indices,
255					channel_id,
256					direction,
257					multiplicity,
258					selector,
259				} in flushes
260				{
261					let flush_oracles = column_indices
262						.iter()
263						.map(|&column_index| OracleOrConst::Oracle(oracle_lookup[column_index]))
264						.collect::<Vec<_>>();
265					compiled_flushes.push(CompiledFlush {
266						oracles: flush_oracles,
267						channel_id: *channel_id,
268						direction: *direction,
269						selector: selector.unwrap_or(step_down),
270						multiplicity: *multiplicity as u64,
271					});
272				}
273
274				if !zero_constraints.is_empty() {
275					// Translate zero constraints for the compiled constraint system.
276					let compiled_constraints = zero_constraints
277						.iter()
278						.map(|zero_constraint| Constraint {
279							name: zero_constraint.name.clone(),
280							composition: zero_constraint.expr.clone(),
281							predicate: ConstraintPredicate::Zero,
282						})
283						.collect::<Vec<_>>();
284
285					table_constraints.push(ConstraintSet {
286						n_vars,
287						oracle_ids: partition_oracle_ids,
288						constraints: compiled_constraints,
289					});
290				}
291			}
292		}
293
294		Ok(CompiledConstraintSystem {
295			oracles,
296			table_constraints,
297			flushes: compiled_flushes,
298			non_zero_oracle_ids,
299			max_channel_id: self.channels.len().saturating_sub(1),
300			exponents: Vec::new(),
301		})
302	}
303}
304
305/// Add a table column to the multilinear oracle set with a specified number of variables.
306///
307/// ## Arguments
308///
309/// * `oracles` - The set of multilinear oracles to add to.
310/// * `oracle_lookup` - mapping of column indices in the table to oracle IDs in the oracle set
311/// * `column_info` - information about the column to be added
312/// * `n_vars` - number of variables of the multilinear oracle
313fn add_oracle_for_column<F: TowerField>(
314	oracles: &mut MultilinearOracleSet<F>,
315	oracle_lookup: &[OracleId],
316	transparent_single: &[Option<OracleId>],
317	column_info: &ColumnInfo<F>,
318	n_vars: usize,
319) -> Result<OracleId, Error> {
320	let ColumnInfo {
321		id,
322		col,
323		name,
324		shape,
325		..
326	} = column_info;
327	let addition = oracles.add_named(name);
328	let oracle_id = match col {
329		ColumnDef::Committed { tower_level } => addition.committed(n_vars, *tower_level),
330		ColumnDef::Selected {
331			col,
332			index,
333			index_bits,
334		} => {
335			let index_values = (0..*index_bits)
336				.map(|i| {
337					if (index >> i) & 1 == 0 {
338						F::ZERO
339					} else {
340						F::ONE
341					}
342				})
343				.collect();
344			addition.projected(
345				oracle_lookup[col.table_index],
346				index_values,
347				ProjectionVariant::FirstVars,
348			)?
349		}
350		ColumnDef::Shifted {
351			col,
352			offset,
353			log_block_size,
354			variant,
355		} => {
356			// TODO: debug assert column at col.table_index has the same values_per_row as col.id
357			addition.shifted(oracle_lookup[col.table_index], *offset, *log_block_size, *variant)?
358		}
359		ColumnDef::Packed { col, log_degree } => {
360			// TODO: debug assert column at col.table_index has the same values_per_row as col.id
361			addition.packed(oracle_lookup[col.table_index], *log_degree)?
362		}
363		ColumnDef::Computed { cols, expr } => {
364			if let Ok(LinearNormalForm {
365				constant: offset,
366				var_coeffs,
367			}) = expr.linear_normal_form()
368			{
369				let col_scalars = cols
370					.iter()
371					.zip(var_coeffs)
372					.map(|(&col_index, coeff)| (oracle_lookup[col_index], coeff))
373					.collect::<Vec<_>>();
374				addition.linear_combination_with_offset(n_vars, offset, col_scalars)?
375			} else {
376				let inner_oracles = cols
377					.iter()
378					.map(|&col_index| oracle_lookup[col_index])
379					.collect::<Vec<_>>();
380				addition.composite_mle(n_vars, inner_oracles, expr.clone())?
381			}
382		}
383		ColumnDef::Constant { .. } => addition.repeating(
384			transparent_single[id.table_index].unwrap(),
385			n_vars - shape.log_values_per_row,
386		)?,
387	};
388	Ok(oracle_id)
389}