binius_circuits/builder/constraint_system.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cell::RefCell, collections::HashMap, rc::Rc};
4
5use anyhow::{anyhow, ensure};
6use binius_core::{
7	constraint_system::{
8		channel::{ChannelId, Flush, FlushDirection, OracleOrConst},
9		exp::Exp,
10		ConstraintSystem,
11	},
12	oracle::{
13		ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId, ShiftVariant,
14	},
15	polynomial::MultivariatePoly,
16	transparent::step_down::StepDown,
17	witness::MultilinearExtensionIndex,
18};
19use binius_field::{
20	as_packed_field::{PackScalar, PackedType},
21	BinaryField1b,
22};
23use binius_math::ArithExpr;
24use binius_utils::bail;
25
26use crate::builder::{
27	types::{F, U},
28	witness,
29};
30
31#[derive(Default)]
32pub struct ConstraintSystemBuilder<'arena> {
33	oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
34	constraints: ConstraintSetBuilder<F>,
35	non_zero_oracle_ids: Vec<OracleId>,
36	flushes: Vec<Flush<F>>,
37	exponents: Vec<Exp<F>>,
38	step_down_dedup: HashMap<(usize, usize), OracleId>,
39	witness: Option<witness::Builder<'arena>>,
40	next_channel_id: ChannelId,
41	namespace_path: Vec<String>,
42}
43
44impl<'arena> ConstraintSystemBuilder<'arena> {
45	pub fn new() -> Self {
46		Self::default()
47	}
48
49	pub fn new_with_witness(allocator: &'arena bumpalo::Bump) -> Self {
50		let oracles = Rc::new(RefCell::new(MultilinearOracleSet::new()));
51		Self {
52			witness: Some(witness::Builder::new(allocator, oracles.clone())),
53			oracles,
54			..Default::default()
55		}
56	}
57
58	#[allow(clippy::type_complexity)]
59	pub fn build(self) -> Result<ConstraintSystem<F>, anyhow::Error> {
60		let table_constraints = self.constraints.build(&self.oracles.borrow())?;
61		Ok(ConstraintSystem {
62			max_channel_id: self
63				.flushes
64				.iter()
65				.map(|flush| flush.channel_id)
66				.max()
67				.unwrap_or(0),
68			table_constraints,
69			non_zero_oracle_ids: self.non_zero_oracle_ids,
70			oracles: Rc::into_inner(self.oracles)
71				.ok_or_else(|| {
72					anyhow!("Failed to build ConstraintSystem: references still exist to oracles")
73				})?
74				.into_inner(),
75			flushes: self.flushes,
76			exponents: self.exponents,
77		})
78	}
79
80	pub const fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> {
81		self.witness.as_mut()
82	}
83
84	pub fn take_witness(
85		&mut self,
86	) -> Result<MultilinearExtensionIndex<'arena, PackedType<U, F>>, anyhow::Error> {
87		Option::take(&mut self.witness)
88			.ok_or_else(|| {
89				anyhow!("Witness is missing. Are you in verifier mode, or have you already extraced the witness?")
90			})?
91			.build()
92	}
93
94	pub fn flush(
95		&mut self,
96		direction: FlushDirection,
97		channel_id: ChannelId,
98		count: usize,
99		oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
100	) -> anyhow::Result<()>
101	where
102		U: PackScalar<BinaryField1b>,
103	{
104		self.flush_with_multiplicity(direction, channel_id, count, oracle_ids, 1)
105	}
106
107	pub fn flush_with_multiplicity(
108		&mut self,
109		direction: FlushDirection,
110		channel_id: ChannelId,
111		count: usize,
112		oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
113		multiplicity: u64,
114	) -> anyhow::Result<()>
115	where
116		U: PackScalar<BinaryField1b>,
117	{
118		//We assume there is at least one non constant in the collection of oracle ids.
119		let non_const_oracles = oracle_ids
120			.clone()
121			.into_iter()
122			.filter_map(|id| match id {
123				OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
124				_ => None,
125			})
126			.collect::<Vec<_>>();
127
128		let n_vars = self.log_rows(non_const_oracles)?;
129
130		let selector = if let Some(&selector) = self.step_down_dedup.get(&(n_vars, count)) {
131			selector
132		} else {
133			let step_down = StepDown::new(n_vars, count)?;
134			let selector = self.add_transparent(
135				format!("internal step_down {count}-{n_vars}"),
136				step_down.clone(),
137			)?;
138
139			if let Some(witness) = self.witness() {
140				step_down.populate(witness.new_column::<BinaryField1b>(selector).packed());
141			}
142
143			self.step_down_dedup.insert((n_vars, count), selector);
144			selector
145		};
146
147		self.flush_custom(direction, channel_id, selector, oracle_ids, multiplicity)
148	}
149
150	pub fn flush_custom(
151		&mut self,
152		direction: FlushDirection,
153		channel_id: ChannelId,
154		selector: OracleId,
155		oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
156		multiplicity: u64,
157	) -> anyhow::Result<()> {
158		//We assume there is atleast one non constant in the collection of oracle ids.
159		let non_const_oracles = oracle_ids
160			.clone()
161			.into_iter()
162			.filter_map(|id| match id {
163				OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
164				_ => None,
165			})
166			.collect::<Vec<_>>();
167
168		let log_rows = self.log_rows(non_const_oracles.iter().copied())?;
169		ensure!(
170			log_rows == self.log_rows([selector])?,
171			"Selector {} n_vars does not match flush {:?}",
172			selector,
173			non_const_oracles
174		);
175
176		let oracles = oracle_ids.into_iter().collect();
177		self.flushes.push(Flush {
178			channel_id,
179			direction,
180			selector: Some(selector),
181			oracles,
182			multiplicity,
183		});
184
185		Ok(())
186	}
187
188	pub fn send(
189		&mut self,
190		channel_id: ChannelId,
191		count: usize,
192		oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
193	) -> anyhow::Result<()>
194	where
195		U: PackScalar<BinaryField1b>,
196	{
197		self.flush(FlushDirection::Push, channel_id, count, oracle_ids)
198	}
199
200	pub fn receive(
201		&mut self,
202		channel_id: ChannelId,
203		count: usize,
204		oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
205	) -> anyhow::Result<()>
206	where
207		U: PackScalar<BinaryField1b>,
208	{
209		self.flush(FlushDirection::Pull, channel_id, count, oracle_ids)
210	}
211
212	pub fn assert_zero(
213		&mut self,
214		name: impl ToString,
215		oracle_ids: impl IntoIterator<Item = OracleId>,
216		composition: ArithExpr<F>,
217	) {
218		self.constraints
219			.add_zerocheck(name, oracle_ids, composition);
220	}
221
222	pub fn assert_not_zero(&mut self, oracle_id: OracleId) {
223		self.non_zero_oracle_ids.push(oracle_id);
224	}
225
226	pub const fn add_channel(&mut self) -> ChannelId {
227		let channel_id = self.next_channel_id;
228		self.next_channel_id += 1;
229		channel_id
230	}
231
232	pub fn add_committed(
233		&mut self,
234		name: impl ToString,
235		n_vars: usize,
236		tower_level: usize,
237	) -> OracleId {
238		self.oracles
239			.borrow_mut()
240			.add_named(self.scoped_name(name))
241			.committed(n_vars, tower_level)
242	}
243
244	/// Adds an exponentiation operation to the constraint system.
245	///
246	/// # Parameters
247	/// - `bits_ids`: A vector of `OracleId` representing the exponent in little-endian bit order.
248	/// - `exp_result_id`: The `OracleId` that holds the result of the exponentiation..
249	/// - `base`: The static base value.
250	/// - `base_tower_level`: Specifies the field level in the tower where `base` is defined
251	pub fn add_static_exp(
252		&mut self,
253		bits_ids: Vec<OracleId>,
254		exp_result_id: OracleId,
255		base: F,
256		base_tower_level: usize,
257	) {
258		self.exponents.push(Exp {
259			bits_ids,
260			exp_result_id,
261			base: OracleOrConst::Const {
262				base,
263				tower_level: base_tower_level,
264			},
265		});
266	}
267
268	/// Adds an exponentiation operation to the constraint system.
269	///
270	/// # Parameters
271	/// - `bits_ids`: A vector of `OracleId` representing the exponent in little-endian bit order.
272	/// - `exp_result_id`: The `OracleId` that holds the result of the exponentiation..
273	/// - `base`: The dynamic base value.
274	pub fn add_dynamic_exp(
275		&mut self,
276		bits_ids: Vec<OracleId>,
277		exp_result_id: OracleId,
278		base: OracleId,
279	) {
280		self.exponents.push(Exp {
281			bits_ids,
282			exp_result_id,
283			base: OracleOrConst::Oracle(base),
284		});
285	}
286
287	pub fn add_committed_multiple<const N: usize>(
288		&mut self,
289		name: impl ToString,
290		n_vars: usize,
291		tower_level: usize,
292	) -> [OracleId; N] {
293		self.oracles
294			.borrow_mut()
295			.add_named(self.scoped_name(name))
296			.committed_multiple(n_vars, tower_level)
297	}
298
299	pub fn add_linear_combination(
300		&mut self,
301		name: impl ToString,
302		n_vars: usize,
303		inner: impl IntoIterator<Item = (OracleId, F)>,
304	) -> Result<OracleId, OracleError> {
305		self.oracles
306			.borrow_mut()
307			.add_named(self.scoped_name(name))
308			.linear_combination(n_vars, inner)
309	}
310
311	pub fn add_linear_combination_with_offset(
312		&mut self,
313		name: impl ToString,
314		n_vars: usize,
315		offset: F,
316		inner: impl IntoIterator<Item = (OracleId, F)>,
317	) -> Result<OracleId, OracleError> {
318		self.oracles
319			.borrow_mut()
320			.add_named(self.scoped_name(name))
321			.linear_combination_with_offset(n_vars, offset, inner)
322	}
323
324	pub fn add_composite_mle(
325		&mut self,
326		name: impl ToString,
327		n_vars: usize,
328		inner: impl IntoIterator<Item = OracleId>,
329		comp: ArithExpr<F>,
330	) -> Result<OracleId, OracleError> {
331		self.oracles
332			.borrow_mut()
333			.add_named(self.scoped_name(name))
334			.composite_mle(n_vars, inner, comp)
335	}
336
337	pub fn add_packed(
338		&mut self,
339		name: impl ToString,
340		id: OracleId,
341		log_degree: usize,
342	) -> Result<OracleId, OracleError> {
343		self.oracles
344			.borrow_mut()
345			.add_named(self.scoped_name(name))
346			.packed(id, log_degree)
347	}
348
349	/// Adds a projection to the variables starting at `start_index`.
350	pub fn add_projected(
351		&mut self,
352		name: impl ToString,
353		id: OracleId,
354		values: Vec<F>,
355		start_index: usize,
356	) -> Result<usize, OracleError> {
357		self.oracles
358			.borrow_mut()
359			.add_named(self.scoped_name(name))
360			.projected(id, values, start_index)
361	}
362
363	/// Adds a projection to the last variables.
364	pub fn add_projected_last_vars(
365		&mut self,
366		name: impl ToString,
367		id: OracleId,
368		values: Vec<F>,
369	) -> Result<usize, OracleError> {
370		self.oracles
371			.borrow_mut()
372			.add_named(self.scoped_name(name))
373			.projected_last_vars(id, values)
374	}
375
376	pub fn add_repeating(
377		&mut self,
378		name: impl ToString,
379		id: OracleId,
380		log_count: usize,
381	) -> Result<OracleId, OracleError> {
382		self.oracles
383			.borrow_mut()
384			.add_named(self.scoped_name(name))
385			.repeating(id, log_count)
386	}
387
388	pub fn add_shifted(
389		&mut self,
390		name: impl ToString,
391		id: OracleId,
392		offset: usize,
393		block_bits: usize,
394		variant: ShiftVariant,
395	) -> Result<OracleId, OracleError> {
396		self.oracles
397			.borrow_mut()
398			.add_named(self.scoped_name(name))
399			.shifted(id, offset, block_bits, variant)
400	}
401
402	pub fn add_transparent(
403		&mut self,
404		name: impl ToString,
405		poly: impl MultivariatePoly<F> + 'static,
406	) -> Result<OracleId, OracleError> {
407		self.oracles
408			.borrow_mut()
409			.add_named(self.scoped_name(name))
410			.transparent(poly)
411	}
412
413	pub fn add_zero_padded(
414		&mut self,
415		name: impl ToString,
416		id: OracleId,
417		n_vars: usize,
418	) -> Result<OracleId, OracleError> {
419		self.oracles
420			.borrow_mut()
421			.add_named(self.scoped_name(name))
422			.zero_padded(id, n_vars)
423	}
424
425	fn scoped_name(&self, name: impl ToString) -> String {
426		let name = name.to_string();
427		if self.namespace_path.is_empty() {
428			name
429		} else {
430			format!("{}::{name}", self.namespace_path.join("::"))
431		}
432	}
433
434	/// Anything pushed to the namespace will become part of oracle name, which is useful for debugging.
435	///
436	/// Use `pop_namespace(&mut self)` to remove the latest name.
437	///
438	/// Example
439	/// ```
440	/// use binius_circuits::builder::ConstraintSystemBuilder;
441	/// use binius_field::{TowerField, BinaryField128b, BinaryField1b, arch::OptimalUnderlier};
442	///
443	/// let log_size = 14;
444	///
445	/// let mut builder = ConstraintSystemBuilder::new();
446	/// builder.push_namespace("a");
447	/// let x = builder.add_committed("x", log_size, BinaryField1b::TOWER_LEVEL);
448	/// builder.push_namespace("b");
449	/// let y = builder.add_committed("y", log_size, BinaryField1b::TOWER_LEVEL);
450	/// builder.pop_namespace();
451	/// builder.pop_namespace();
452	/// let z = builder.add_committed("z", log_size, BinaryField1b::TOWER_LEVEL);
453	///
454	/// let system = builder.build().unwrap();
455	/// assert_eq!(system.oracles.oracle(x).name().unwrap(), "a::x");
456	/// assert_eq!(system.oracles.oracle(y).name().unwrap(), "a::b::y");
457	/// assert_eq!(system.oracles.oracle(z).name().unwrap(), "z");
458	/// ```
459	pub fn push_namespace(&mut self, name: impl ToString) {
460		self.namespace_path.push(name.to_string());
461	}
462
463	pub fn pop_namespace(&mut self) {
464		self.namespace_path.pop();
465	}
466
467	/// Returns the number of rows shared by a set of columns.
468	///
469	/// Fails if no columns are provided, or not all columns have the same number of rows.
470	///
471	/// This is useful for writing circuits with internal columns that depend on the height of input columns.
472	pub fn log_rows(
473		&self,
474		oracle_ids: impl IntoIterator<Item = OracleId>,
475	) -> anyhow::Result<usize> {
476		let mut oracle_ids = oracle_ids.into_iter();
477		let oracles = self.oracles.borrow();
478		let Some(first_id) = oracle_ids.next() else {
479			bail!(anyhow!("log_rows: You need to specify at least one column"));
480		};
481		let log_rows = oracles.n_vars(first_id);
482		if oracle_ids.any(|id| oracles.n_vars(id) != log_rows) {
483			bail!(anyhow!("log_rows: All columns must have the same number of rows"))
484		}
485		Ok(log_rows)
486	}
487}