1use std::{cell, collections::BTreeMap, ops::Index};
4
5use binius_compute::alloc::HostBumpAllocator;
6pub use binius_core::constraint_system::channel::{
7 Boundary, Flush as CompiledFlush, FlushDirection,
8};
9use binius_core::{
10 constraint_system::{
11 ConstraintSystem as CompiledConstraintSystem,
12 channel::{ChannelId, OracleOrConst},
13 exp::Exp,
14 },
15 oracle::{
16 Constraint, ConstraintPredicate, ConstraintSet, OracleId, SymbolicMultilinearOracleSet,
17 },
18};
19use binius_field::{PackedField, TowerField};
20use binius_math::{ArithCircuit, LinearNormalForm};
21use binius_utils::checked_arithmetics::log2_strict_usize;
22
23use super::{
24 ColumnId, Table, TableBuilder, TableId, ZeroConstraint,
25 channel::{Channel, Flush},
26 column::{ColumnDef, ColumnInfo},
27 error::Error,
28 table::TablePartition,
29 types::B128,
30 witness::WitnessIndex,
31};
32use crate::builder::expr::ArithExprNamedVars;
33
/// Builder-side description of a constraint system: the tables and channels registered so far,
/// plus the column-to-oracle lookup produced by the latest compilation.
#[derive(Debug, Default)]
pub struct ConstraintSystem<F: TowerField = B128> {
    /// Registered tables. `add_table` uses a table's position in this vector as its ID.
    pub tables: Vec<Table<F>>,
    /// Registered channels. `add_channel` returns a channel's position in this vector as its ID.
    pub channels: Vec<Channel>,

    /// Column-to-oracle lookup stored by `compile`; `None` until the system has been compiled.
    oracle_lookup: cell::RefCell<Option<OracleLookup>>,
}
43
44impl<F: TowerField> std::fmt::Display for ConstraintSystem<F> {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 writeln!(f, "ConstraintSystem {{")?;
47
48 for channel in self.channels.iter() {
49 writeln!(f, " CHANNEL {}", channel.name)?;
50 }
51
52 let mut oracle_id = 0;
53
54 for table in self.tables.iter() {
55 writeln!(f, " TABLE {} {{", table.name)?;
56
57 for partition in table.partitions.values() {
58 for flush in partition.flushes.iter() {
59 let channel = self.channels[flush.channel_id].name.clone();
60 let columns = flush
61 .columns
62 .iter()
63 .map(|i| table[*i].name.clone())
64 .collect::<Vec<_>>()
65 .join(", ");
66 match flush.direction {
67 FlushDirection::Push => {
68 writeln!(f, " PUSH ({columns}) to {channel}")?
69 }
70 FlushDirection::Pull => {
71 writeln!(f, " PULL ({columns}) from {channel}")?
72 }
73 };
74 }
75
76 let names = partition
77 .columns
78 .iter()
79 .map(|&index| table[index].name.clone())
80 .collect::<Vec<_>>();
81
82 for constraint in partition.zero_constraints.iter() {
83 let name = constraint.name.clone();
84 let expr = ArithExprNamedVars(&constraint.expr, &names);
85 writeln!(f, " ZERO {name}: {expr}")?;
86 }
87 }
88
89 for col in table.columns.iter() {
90 if matches!(col.col, ColumnDef::Constant { .. }) {
91 oracle_id += 1;
92 }
93 }
94
95 for col in table.columns.iter() {
96 let name = col.name.clone();
97 let log_values_per_row = col.shape.log_values_per_row;
98 let field = match col.shape.tower_height {
99 0 => "B1",
100 1 => "B2",
101 2 => "B4",
102 3 => "B8",
103 4 => "B16",
104 5 => "B32",
105 6 => "B64",
106 _ => "B128",
107 };
108 let type_str = if log_values_per_row > 0 {
109 let values_per_row = 1 << log_values_per_row;
110 format!("{field}x{values_per_row}")
111 } else {
112 field.to_string()
113 };
114 writeln!(f, " {oracle_id:04} {type_str} {name}")?;
115 oracle_id += 1;
116 }
117
118 for log_values_per_row in table.partitions.keys() {
120 let values_per_row = 1 << log_values_per_row;
121 let selector_type_str = if values_per_row > 1 {
122 format!("B1x{values_per_row}")
123 } else {
124 "B1".to_string()
125 };
126 writeln!(f, " {oracle_id:04} {selector_type_str} (ROW_SELECTOR)")?;
127 oracle_id += 1;
128 }
129
130 writeln!(f, " }}")?;
131 }
132 writeln!(f, "}}")
133 }
134}
135
136impl<F: TowerField> ConstraintSystem<F> {
137 pub fn new() -> Self {
138 Self::default()
139 }
140
141 pub fn add_table(&mut self, name: impl ToString) -> TableBuilder<'_, F> {
142 let id = self.tables.len();
143 self.tables.push(Table::new(id, name.to_string()));
144 TableBuilder::new(self.tables.last_mut().expect("table was just pushed"))
145 }
146
147 pub fn add_channel(&mut self, name: impl ToString) -> ChannelId {
148 let id = self.channels.len();
149 self.channels.push(Channel {
150 name: name.to_string(),
151 });
152 id
153 }
154
155 #[deprecated]
159 pub fn build_witness<'cs, 'alloc, P: PackedField<Scalar = F>>(
160 &'cs self,
161 allocator: &'alloc HostBumpAllocator<'alloc, P>,
162 ) -> WitnessIndex<'cs, 'alloc, P> {
163 WitnessIndex::new(self, allocator)
164 }
165
166 #[track_caller]
171 pub(crate) fn oracle_lookup<'a>(&'a self) -> cell::Ref<'a, OracleLookup> {
172 const MESSAGE: &str = "oracle_lookup was requested but constraint system was not compiled";
173 cell::Ref::map(self.oracle_lookup.borrow(), |o| o.as_ref().expect(MESSAGE))
174 }
175
176 pub fn compile(&self) -> Result<CompiledConstraintSystem<F>, Error> {
183 let mut oracles = SymbolicMultilinearOracleSet::new();
184 let mut table_constraints = Vec::new();
185 let mut compiled_flushes = Vec::new();
186 let mut non_zero_oracle_ids = Vec::new();
187 let mut exponents = Vec::new();
188 let mut table_size_specs = Vec::new();
189
190 let mut oracle_lookup = OracleLookup::new();
191
192 for table in &self.tables {
193 table_size_specs.push(table.size_spec());
194
195 add_oracles_for_columns(
197 &mut oracle_lookup,
198 &mut oracles,
199 table,
200 &mut non_zero_oracle_ids,
201 )?;
202
203 for partition in table.partitions.values() {
204 let TablePartition {
205 columns,
206 flushes,
207 zero_constraints,
208 values_per_row,
209 ..
210 } = partition;
211
212 let partition_oracle_ids = columns
213 .iter()
214 .map(|&index| oracle_lookup[index])
215 .collect::<Vec<_>>();
216
217 columns.iter().for_each(|index| {
219 let col = &table[*index];
220 let col_info = &table[*index].col;
221 match col_info {
222 ColumnDef::StaticExp {
223 bit_cols,
224 base,
225 base_tower_level,
226 } => {
227 let bits_ids = bit_cols
228 .iter()
229 .map(|&column_id| oracle_lookup[column_id])
230 .collect();
231 exponents.push(Exp {
232 base: OracleOrConst::Const {
233 base: *base,
234 tower_level: *base_tower_level,
235 },
236 bits_ids,
237 exp_result_id: oracle_lookup[col.id],
238 });
239 }
240 ColumnDef::DynamicExp { bit_cols, base, .. } => {
241 let bits_ids = bit_cols
242 .iter()
243 .map(|&col_idx| oracle_lookup[col_idx])
244 .collect();
245 exponents.push(Exp {
246 base: OracleOrConst::Oracle(oracle_lookup[*base]),
247 bits_ids,
248 exp_result_id: oracle_lookup[col.id],
249 })
250 }
251 _ => (),
252 }
253 });
254
255 for Flush {
257 columns: flush_columns,
258 channel_id,
259 direction,
260 multiplicity,
261 selectors,
262 } in flushes
263 {
264 let flush_oracles = flush_columns
265 .iter()
266 .map(|&column_id| OracleOrConst::Oracle(oracle_lookup[column_id]))
267 .collect::<Vec<_>>();
268 let selectors = selectors
269 .iter()
270 .map(|column_idx| oracle_lookup[*column_idx])
271 .collect::<Vec<_>>();
272
273 compiled_flushes.push(CompiledFlush {
274 table_id: table.id(),
275 log_values_per_row: log2_strict_usize(*values_per_row),
276 oracles: flush_oracles,
277 channel_id: *channel_id,
278 direction: *direction,
279 selectors,
280 multiplicity: *multiplicity as u64,
281 });
282 }
283
284 if !zero_constraints.is_empty() {
285 let constraint_set = translate_constraint_set(
286 table.id(),
287 log2_strict_usize(*values_per_row),
288 zero_constraints,
289 partition_oracle_ids,
290 );
291 table_constraints.push(constraint_set);
292 }
293 }
294 }
295
296 *self.oracle_lookup.borrow_mut() = Some(oracle_lookup);
297
298 Ok(CompiledConstraintSystem {
299 oracles,
300 table_constraints,
301 flushes: compiled_flushes,
302 non_zero_oracle_ids,
303 channel_count: self.channels.len(),
304 exponents,
305 table_size_specs,
306 })
307 }
308}
309
/// Describes which oracle(s) a table column was mapped to.
#[derive(Debug, Copy, Clone)]
pub(crate) enum OracleMapping {
    /// The column is backed by a single oracle.
    Regular(OracleId),
    /// The column is backed by a pair of oracles, as registered for `ColumnDef::Constant`
    /// columns.
    TransparentCompound {
        /// The transparent oracle created from the constant's polynomial.
        original: OracleId,
        /// The repeating oracle built on top of `original`. This is the ID returned when an
        /// `OracleLookup` is indexed by the column.
        repeating: OracleId,
    },
}
323
/// Lookup table from table columns to the oracles registered for them.
#[derive(Debug, Default)]
pub(crate) struct OracleLookup {
    /// Mapping of every registered column; each column may be registered at most once.
    column_to_oracle: BTreeMap<ColumnId, OracleMapping>,
}
331
332impl OracleLookup {
333 pub(crate) fn new() -> Self {
335 Self {
336 column_to_oracle: BTreeMap::new(),
337 }
338 }
339
340 pub fn lookup(&self, column_id: ColumnId) -> &OracleMapping {
346 &self.column_to_oracle[&column_id]
347 }
348
349 fn register_regular(&mut self, column_id: ColumnId, oracle_id: OracleId) {
356 let prev = self
357 .column_to_oracle
358 .insert(column_id, OracleMapping::Regular(oracle_id));
359 assert!(prev.is_none());
360 }
361
362 fn register_transparent(
372 &mut self,
373 column_id: ColumnId,
374 original: OracleId,
375 repeating: OracleId,
376 ) {
377 let prev = self.column_to_oracle.insert(
378 column_id,
379 OracleMapping::TransparentCompound {
380 original,
381 repeating,
382 },
383 );
384 assert!(prev.is_none());
385 }
386}
387
388impl Index<ColumnId> for OracleLookup {
390 type Output = OracleId;
391
392 fn index(&self, id: ColumnId) -> &Self::Output {
393 match &self.column_to_oracle[&id] {
394 OracleMapping::Regular(oracle_id) => oracle_id,
395 OracleMapping::TransparentCompound { repeating, .. } => repeating,
396 }
397 }
398}
399
400fn add_oracles_for_columns<F: TowerField>(
403 oracle_lookup: &mut OracleLookup,
404 oracle_set: &mut SymbolicMultilinearOracleSet<F>,
405 table: &Table<F>,
406 non_zero_oracle_ids: &mut Vec<OracleId>,
407) -> Result<(), Error> {
408 for column_info in table.columns.iter() {
409 add_oracle_for_column(oracle_set, oracle_lookup, column_info, table.id())?;
410 if column_info.is_nonzero {
411 non_zero_oracle_ids.push(oracle_lookup[column_info.id]);
412 }
413 }
414 Ok(())
415}
416
417fn add_oracle_for_column<F: TowerField>(
426 oracles: &mut SymbolicMultilinearOracleSet<F>,
427 oracle_lookup: &mut OracleLookup,
428 column_info: &ColumnInfo<F>,
429 table_id: TableId,
430) -> Result<(), Error> {
431 let ColumnInfo {
432 id: column_id,
433 col,
434 name,
435 shape,
436 ..
437 } = column_info;
438 match col {
439 ColumnDef::Committed { tower_level } => {
440 let oracle_id = oracles
441 .add_oracle(table_id, shape.log_values_per_row, name)
442 .committed(*tower_level);
443 oracle_lookup.register_regular(*column_id, oracle_id);
444 }
445 ColumnDef::Selected {
446 col,
447 index,
448 index_bits,
449 } => {
450 let index_values = (0..*index_bits)
451 .map(|i| {
452 if (index >> i) & 1 == 0 {
453 F::ZERO
454 } else {
455 F::ONE
456 }
457 })
458 .collect();
459 let oracle_id = oracles
460 .add_oracle(table_id, shape.log_values_per_row, name)
461 .projected(oracle_lookup[*col], index_values, 0)?;
462 oracle_lookup.register_regular(*column_id, oracle_id);
463 }
464 ColumnDef::Projected {
465 col,
466 start_index,
467 query_size,
468 query_bits,
469 } => {
470 let query_values = (0..*query_size)
471 .map(|i| -> F {
472 if (query_bits >> i) & 1 == 0 {
473 F::ZERO
474 } else {
475 F::ONE
476 }
477 })
478 .collect();
479 let oracle_id = oracles
480 .add_oracle(table_id, shape.log_values_per_row, name)
481 .projected(oracle_lookup[*col], query_values, *start_index)?;
482 oracle_lookup.register_regular(*column_id, oracle_id);
483 }
484 ColumnDef::ZeroPadded {
485 col,
486 n_pad_vars,
487 start_index,
488 nonzero_index,
489 } => {
490 let oracle_id = oracles
491 .add_oracle(table_id, shape.log_values_per_row, name)
492 .zero_padded(oracle_lookup[*col], *n_pad_vars, *nonzero_index, *start_index)?;
493 oracle_lookup.register_regular(*column_id, oracle_id);
494 }
495 ColumnDef::Shifted {
496 col,
497 offset,
498 log_block_size,
499 variant,
500 } => {
501 let oracle_id = oracles
503 .add_oracle(table_id, shape.log_values_per_row, name)
504 .shifted(oracle_lookup[*col], *offset, *log_block_size, *variant)?;
505 oracle_lookup.register_regular(*column_id, oracle_id);
506 }
507 ColumnDef::Packed { col, log_degree } => {
508 let source = oracle_lookup[*col];
510 let oracle_id = oracles
511 .add_oracle(table_id, shape.log_values_per_row, name)
512 .packed(source, *log_degree)?;
513 oracle_lookup.register_regular(*column_id, oracle_id);
514 }
515 ColumnDef::Computed { cols, expr } => {
516 if let Ok(LinearNormalForm {
517 constant: offset,
518 var_coeffs,
519 }) = expr.linear_normal_form()
520 {
521 let col_scalars = cols
522 .iter()
523 .zip(var_coeffs)
524 .map(|(&col_id, coeff)| (oracle_lookup[col_id], coeff))
525 .collect::<Vec<_>>();
526 let oracle_id = oracles
527 .add_oracle(table_id, shape.log_values_per_row, name)
528 .linear_combination_with_offset(offset, col_scalars)?;
529 oracle_lookup.register_regular(*column_id, oracle_id);
530 } else {
531 let inner_oracles = cols
532 .iter()
533 .map(|&col_index| oracle_lookup[col_index])
534 .collect::<Vec<_>>();
535 let oracle_id = oracles
536 .add_oracle(table_id, shape.log_values_per_row, name)
537 .composite_mle(inner_oracles, expr.clone())?;
538 oracle_lookup.register_regular(*column_id, oracle_id);
539 };
540 }
541 ColumnDef::Constant { poly, .. } => {
542 let oracle_id_original = oracles
543 .add_oracle(table_id, shape.log_values_per_row, format!("{name}_single"))
544 .transparent(poly.clone())?;
545 let oracle_id_repeating = oracles
546 .add_oracle(table_id, shape.log_values_per_row, name)
547 .repeating(oracle_id_original)?;
548 oracle_lookup.register_transparent(*column_id, oracle_id_original, oracle_id_repeating);
549 }
550 ColumnDef::StructuredDynSize(structured) => {
551 let expr = structured.expr()?;
552 let oracle_id = oracles
553 .add_oracle(table_id, shape.log_values_per_row, name)
554 .structured(ArithCircuit::from(&expr))?;
555 oracle_lookup.register_regular(*column_id, oracle_id);
556 }
557 ColumnDef::StructuredFixedSize { expr } => {
558 let oracle_id = oracles
559 .add_oracle(table_id, shape.log_values_per_row, name)
560 .transparent(expr.clone())?;
561 oracle_lookup.register_regular(*column_id, oracle_id);
562 }
563 ColumnDef::StaticExp {
564 base_tower_level, ..
565 } => {
566 let oracle_id = oracles
567 .add_oracle(table_id, shape.log_values_per_row, name)
568 .committed(*base_tower_level);
569 oracle_lookup.register_regular(*column_id, oracle_id);
570 }
571 ColumnDef::DynamicExp {
572 base_tower_level, ..
573 } => {
574 let oracle_id = oracles
575 .add_oracle(table_id, shape.log_values_per_row, name)
576 .committed(*base_tower_level);
577 oracle_lookup.register_regular(*column_id, oracle_id);
578 }
579 };
580 Ok(())
581}
582
583fn translate_constraint_set<F: TowerField>(
588 table_id: TableId,
589 log_values_per_row: usize,
590 zero_constraints: &[ZeroConstraint<F>],
591 partition_oracle_ids: Vec<OracleId>,
592) -> ConstraintSet<F> {
593 let mut oracle_appears_in_expr = vec![false; partition_oracle_ids.len()];
596 let mut n_used_oracles = 0usize;
597 for zero_constraint in zero_constraints {
598 let vars_usage = zero_constraint.expr.vars_usage();
599 for (oracle_index, used) in vars_usage.iter().enumerate() {
600 if *used && !oracle_appears_in_expr[oracle_index] {
601 oracle_appears_in_expr[oracle_index] = true;
602 n_used_oracles += 1;
603 }
604 }
605 }
606
607 const INVALID_SENTINEL: usize = usize::MAX;
611 let mut remap_indices_table = vec![INVALID_SENTINEL; partition_oracle_ids.len()];
612 let mut dense_oracle_ids = Vec::with_capacity(n_used_oracles);
613 for (i, &used) in oracle_appears_in_expr.iter().enumerate() {
614 if !used {
615 continue;
616 }
617 let dense_index = dense_oracle_ids.len();
618 dense_oracle_ids.push(partition_oracle_ids[i]);
619 remap_indices_table[i] = dense_index;
620 }
621
622 let compiled_constraints = zero_constraints
624 .iter()
625 .map(|zero_constraint| {
626 let expr = zero_constraint
627 .expr
628 .clone()
629 .remap_vars(&remap_indices_table)
630 .expect(
631 "the expr must have the same length as partition_oracle_ids which is the\
632 same length of remap_indices_table",
633 );
634 Constraint {
635 name: zero_constraint.name.clone(),
636 composition: expr,
637 predicate: ConstraintPredicate::Zero,
638 }
639 })
640 .collect::<Vec<_>>();
641
642 ConstraintSet {
643 table_id,
644 log_values_per_row,
645 oracle_ids: dense_oracle_ids,
646 constraints: compiled_constraints,
647 }
648}