1pub use binius_core::constraint_system::channel::{
4 Boundary, Flush as CompiledFlush, FlushDirection,
5};
6use binius_core::{
7 constraint_system::{
8 channel::{ChannelId, OracleOrConst},
9 ConstraintSystem as CompiledConstraintSystem,
10 },
11 oracle::{
12 Constraint, ConstraintPredicate, ConstraintSet, MultilinearOracleSet, OracleId,
13 ProjectionVariant,
14 },
15 transparent::step_down::StepDown,
16};
17use binius_field::{PackedField, TowerField};
18use binius_math::LinearNormalForm;
19use binius_utils::checked_arithmetics::log2_strict_usize;
20use bumpalo::Bump;
21
22use super::{
23 channel::{Channel, Flush},
24 column::{ColumnDef, ColumnInfo},
25 error::Error,
26 statement::Statement,
27 table::TablePartition,
28 types::B128,
29 witness::{TableWitnessIndex, WitnessIndex},
30 Table, TableBuilder,
31};
32use crate::builder::expr::ArithExprNamedVars;
33
/// A table-based constraint system under construction.
///
/// Tables are added with [`ConstraintSystem::add_table`] and channels with
/// [`ConstraintSystem::add_channel`]. The system is lowered to a
/// [`CompiledConstraintSystem`] for a concrete set of table sizes with
/// [`ConstraintSystem::compile`].
#[derive(Debug, Default)]
pub struct ConstraintSystem<F: TowerField = B128> {
    /// Tables in the order they were added; a table's ID is its index in this vector.
    pub tables: Vec<Table<F>>,
    /// Channels in the order they were added; a channel's ID is its index in this vector.
    pub channels: Vec<Channel>,
}
40
41impl<F: TowerField> std::fmt::Display for ConstraintSystem<F> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 writeln!(f, "ConstraintSystem {{")?;
44
45 for channel in self.channels.iter() {
46 writeln!(f, " CHANNEL {}", channel.name)?;
47 }
48
49 let mut oracle_id = 0;
50
51 for table in self.tables.iter() {
52 writeln!(f, " TABLE {} {{", table.name)?;
53
54 for partition in table.partitions.values() {
55 for flush in partition.flushes.iter() {
56 let channel = self.channels[flush.channel_id].name.clone();
57 let columns = flush
58 .column_indices
59 .iter()
60 .map(|i| table.columns[partition.columns[*i]].name.clone())
61 .collect::<Vec<_>>()
62 .join(", ");
63 match flush.direction {
64 FlushDirection::Push => {
65 writeln!(f, " PUSH ({columns}) to {channel}")?
66 }
67 FlushDirection::Pull => {
68 writeln!(f, " PULL ({columns}) from {channel}")?
69 }
70 };
71 }
72
73 let names = partition
74 .columns
75 .iter()
76 .map(|&index| table.columns[index].name.clone())
77 .collect::<Vec<_>>();
78
79 for constraint in partition.zero_constraints.iter() {
80 let name = constraint.name.clone();
81 let expr = ArithExprNamedVars(&constraint.expr, &names);
82 writeln!(f, " ZERO {name}: {expr}")?;
83 }
84 }
85
86 for col in table.columns.iter() {
87 if matches!(col.col, ColumnDef::Constant { .. }) {
88 oracle_id += 1;
89 }
90 }
91
92 for col in table.columns.iter() {
93 let name = col.name.clone();
94 let log_values_per_row = col.shape.log_values_per_row;
95 let field = match col.shape.tower_height {
96 0 => "B1",
97 1 => "B2",
98 2 => "B4",
99 3 => "B8",
100 4 => "B16",
101 5 => "B32",
102 6 => "B64",
103 _ => "B128",
104 };
105 let type_str = if log_values_per_row > 0 {
106 let values_per_row = 1 << log_values_per_row;
107 format!("{field}x{values_per_row}")
108 } else {
109 field.to_string()
110 };
111 writeln!(f, " {oracle_id:04} {type_str} {name}")?;
112 oracle_id += 1;
113 }
114
115 for log_values_per_row in table.partitions.keys() {
117 let values_per_row = 1 << log_values_per_row;
118 let selector_type_str = if values_per_row > 1 {
119 format!("B1x{}", values_per_row)
120 } else {
121 "B1".to_string()
122 };
123 writeln!(f, " {oracle_id:04} {selector_type_str} (ROW_SELECTOR)")?;
124 oracle_id += 1;
125 }
126
127 writeln!(f, " }}")?;
128 }
129 writeln!(f, "}}")
130 }
131}
132
133impl<F: TowerField> ConstraintSystem<F> {
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 pub fn add_table(&mut self, name: impl ToString) -> TableBuilder<'_, F> {
139 let id = self.tables.len();
140 self.tables.push(Table::new(id, name.to_string()));
141 TableBuilder::new(self.tables.last_mut().expect("table was just pushed"))
142 }
143
144 pub fn add_channel(&mut self, name: impl ToString) -> ChannelId {
145 let id = self.channels.len();
146 self.channels.push(Channel {
147 name: name.to_string(),
148 });
149 id
150 }
151
152 pub fn build_witness<'cs, 'alloc, P: PackedField<Scalar = F>>(
158 &'cs self,
159 allocator: &'alloc Bump,
160 statement: &Statement,
161 ) -> Result<WitnessIndex<'cs, 'alloc, P>, Error> {
162 Ok(WitnessIndex {
163 tables: self
164 .tables
165 .iter()
166 .zip(&statement.table_sizes)
167 .map(|(table, &table_size)| {
168 let witness = if table_size > 0 {
169 Some(TableWitnessIndex::new(allocator, table, table_size))
170 } else {
171 None
172 };
173 witness.transpose()
174 })
175 .collect::<Result<_, _>>()?,
176 })
177 }
178
179 pub fn compile(&self, statement: &Statement<F>) -> Result<CompiledConstraintSystem<F>, Error> {
186 if statement.table_sizes.len() != self.tables.len() {
187 return Err(Error::StatementMissingTableSize {
188 expected: self.tables.len(),
189 actual: statement.table_sizes.len(),
190 });
191 }
192
193 let mut oracles = MultilinearOracleSet::new();
195 let mut table_constraints = Vec::new();
196 let mut compiled_flushes = Vec::new();
197 let mut non_zero_oracle_ids = Vec::new();
198
199 for (table, &count) in std::iter::zip(&self.tables, &statement.table_sizes) {
200 if count == 0 {
201 continue;
202 }
203 let mut oracle_lookup = Vec::new();
204
205 let mut transparent_single = vec![None; table.columns.len()];
206 for (table_index, info) in table.columns.iter().enumerate() {
207 if let ColumnDef::Constant { poly } = &info.col {
208 let oracle_id = oracles
209 .add_named(format!("{}_single", info.name))
210 .transparent(poly.clone())?;
211 transparent_single[table_index] = Some(oracle_id);
212 }
213 }
214
215 let log_capacity = table.log_capacity(count);
217 for column_info in table.columns.iter() {
218 let n_vars = log_capacity + column_info.shape.log_values_per_row;
219 let oracle_id = add_oracle_for_column(
220 &mut oracles,
221 &oracle_lookup,
222 &transparent_single,
223 column_info,
224 n_vars,
225 )?;
226 oracle_lookup.push(oracle_id);
227 if column_info.is_nonzero {
228 non_zero_oracle_ids.push(oracle_id);
229 }
230 }
231
232 for partition in table.partitions.values() {
233 let TablePartition {
234 columns,
235 flushes,
236 zero_constraints,
237 values_per_row,
238 ..
239 } = partition;
240
241 let n_vars = log_capacity + log2_strict_usize(*values_per_row);
242
243 let partition_oracle_ids = columns
244 .iter()
245 .map(|&index| oracle_lookup[index])
246 .collect::<Vec<_>>();
247
248 let step_down =
250 oracles.add_transparent(StepDown::new(n_vars, count * values_per_row)?)?;
251
252 for Flush {
254 column_indices,
255 channel_id,
256 direction,
257 multiplicity,
258 selector,
259 } in flushes
260 {
261 let flush_oracles = column_indices
262 .iter()
263 .map(|&column_index| OracleOrConst::Oracle(oracle_lookup[column_index]))
264 .collect::<Vec<_>>();
265 compiled_flushes.push(CompiledFlush {
266 oracles: flush_oracles,
267 channel_id: *channel_id,
268 direction: *direction,
269 selector: selector.unwrap_or(step_down),
270 multiplicity: *multiplicity as u64,
271 });
272 }
273
274 if !zero_constraints.is_empty() {
275 let compiled_constraints = zero_constraints
277 .iter()
278 .map(|zero_constraint| Constraint {
279 name: zero_constraint.name.clone(),
280 composition: zero_constraint.expr.clone(),
281 predicate: ConstraintPredicate::Zero,
282 })
283 .collect::<Vec<_>>();
284
285 table_constraints.push(ConstraintSet {
286 n_vars,
287 oracle_ids: partition_oracle_ids,
288 constraints: compiled_constraints,
289 });
290 }
291 }
292 }
293
294 Ok(CompiledConstraintSystem {
295 oracles,
296 table_constraints,
297 flushes: compiled_flushes,
298 non_zero_oracle_ids,
299 max_channel_id: self.channels.len().saturating_sub(1),
300 exponents: Vec::new(),
301 })
302 }
303}
304
305fn add_oracle_for_column<F: TowerField>(
314 oracles: &mut MultilinearOracleSet<F>,
315 oracle_lookup: &[OracleId],
316 transparent_single: &[Option<OracleId>],
317 column_info: &ColumnInfo<F>,
318 n_vars: usize,
319) -> Result<OracleId, Error> {
320 let ColumnInfo {
321 id,
322 col,
323 name,
324 shape,
325 ..
326 } = column_info;
327 let addition = oracles.add_named(name);
328 let oracle_id = match col {
329 ColumnDef::Committed { tower_level } => addition.committed(n_vars, *tower_level),
330 ColumnDef::Selected {
331 col,
332 index,
333 index_bits,
334 } => {
335 let index_values = (0..*index_bits)
336 .map(|i| {
337 if (index >> i) & 1 == 0 {
338 F::ZERO
339 } else {
340 F::ONE
341 }
342 })
343 .collect();
344 addition.projected(
345 oracle_lookup[col.table_index],
346 index_values,
347 ProjectionVariant::FirstVars,
348 )?
349 }
350 ColumnDef::Shifted {
351 col,
352 offset,
353 log_block_size,
354 variant,
355 } => {
356 addition.shifted(oracle_lookup[col.table_index], *offset, *log_block_size, *variant)?
358 }
359 ColumnDef::Packed { col, log_degree } => {
360 addition.packed(oracle_lookup[col.table_index], *log_degree)?
362 }
363 ColumnDef::Computed { cols, expr } => {
364 if let Ok(LinearNormalForm {
365 constant: offset,
366 var_coeffs,
367 }) = expr.linear_normal_form()
368 {
369 let col_scalars = cols
370 .iter()
371 .zip(var_coeffs)
372 .map(|(&col_index, coeff)| (oracle_lookup[col_index], coeff))
373 .collect::<Vec<_>>();
374 addition.linear_combination_with_offset(n_vars, offset, col_scalars)?
375 } else {
376 let inner_oracles = cols
377 .iter()
378 .map(|&col_index| oracle_lookup[col_index])
379 .collect::<Vec<_>>();
380 addition.composite_mle(n_vars, inner_oracles, expr.clone())?
381 }
382 }
383 ColumnDef::Constant { .. } => addition.repeating(
384 transparent_single[id.table_index].unwrap(),
385 n_vars - shape.log_values_per_row,
386 )?,
387 };
388 Ok(oracle_id)
389}