binius_circuits/builder/
constraint_system.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cell::RefCell, collections::HashMap, rc::Rc};
4
5use anyhow::{anyhow, ensure};
6use binius_core::{
7	constraint_system::{
8		channel::{ChannelId, Flush, FlushDirection},
9		ConstraintSystem,
10	},
11	oracle::{
12		ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId,
13		ProjectionVariant, ShiftVariant,
14	},
15	polynomial::MultivariatePoly,
16	transparent::step_down::StepDown,
17	witness::MultilinearExtensionIndex,
18};
19use binius_field::{as_packed_field::PackScalar, BinaryField1b};
20use binius_math::ArithExpr;
21use binius_utils::bail;
22
23use crate::builder::{
24	types::{F, U},
25	witness,
26};
27
28#[derive(Default)]
29pub struct ConstraintSystemBuilder<'arena> {
30	oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
31	constraints: ConstraintSetBuilder<F>,
32	non_zero_oracle_ids: Vec<OracleId>,
33	flushes: Vec<Flush>,
34	step_down_dedup: HashMap<(usize, usize), OracleId>,
35	witness: Option<witness::Builder<'arena>>,
36	next_channel_id: ChannelId,
37	namespace_path: Vec<String>,
38}
39
40impl<'arena> ConstraintSystemBuilder<'arena> {
41	pub fn new() -> Self {
42		Self::default()
43	}
44
45	pub fn new_with_witness(allocator: &'arena bumpalo::Bump) -> Self {
46		let oracles = Rc::new(RefCell::new(MultilinearOracleSet::new()));
47		Self {
48			witness: Some(witness::Builder::new(allocator, oracles.clone())),
49			oracles,
50			..Default::default()
51		}
52	}
53
54	#[allow(clippy::type_complexity)]
55	pub fn build(self) -> Result<ConstraintSystem<F>, anyhow::Error> {
56		let table_constraints = self.constraints.build(&self.oracles.borrow())?;
57		Ok(ConstraintSystem {
58			max_channel_id: self
59				.flushes
60				.iter()
61				.map(|flush| flush.channel_id)
62				.max()
63				.unwrap_or(0),
64			table_constraints,
65			non_zero_oracle_ids: self.non_zero_oracle_ids,
66			oracles: Rc::into_inner(self.oracles)
67				.ok_or_else(|| {
68					anyhow!("Failed to build ConstraintSystem: references still exist to oracles")
69				})?
70				.into_inner(),
71			flushes: self.flushes,
72		})
73	}
74
75	pub const fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> {
76		self.witness.as_mut()
77	}
78
79	pub fn take_witness(
80		&mut self,
81	) -> Result<MultilinearExtensionIndex<'arena, U, F>, anyhow::Error> {
82		Option::take(&mut self.witness)
83			.ok_or_else(|| {
84				anyhow!("Witness is missing. Are you in verifier mode, or have you already extraced the witness?")
85			})?
86			.build()
87	}
88
89	pub fn flush(
90		&mut self,
91		direction: FlushDirection,
92		channel_id: ChannelId,
93		count: usize,
94		oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
95	) -> anyhow::Result<()>
96	where
97		U: PackScalar<BinaryField1b>,
98	{
99		self.flush_with_multiplicity(direction, channel_id, count, oracle_ids, 1)
100	}
101
102	pub fn flush_with_multiplicity(
103		&mut self,
104		direction: FlushDirection,
105		channel_id: ChannelId,
106		count: usize,
107		oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
108		multiplicity: u64,
109	) -> anyhow::Result<()>
110	where
111		U: PackScalar<BinaryField1b>,
112	{
113		let n_vars = self.log_rows(oracle_ids.clone())?;
114
115		let selector = if let Some(&selector) = self.step_down_dedup.get(&(n_vars, count)) {
116			selector
117		} else {
118			let step_down = StepDown::new(n_vars, count)?;
119			let selector = self.add_transparent(
120				format!("internal step_down {count}-{n_vars}"),
121				step_down.clone(),
122			)?;
123
124			if let Some(witness) = self.witness() {
125				step_down.populate(witness.new_column::<BinaryField1b>(selector).packed());
126			}
127
128			self.step_down_dedup.insert((n_vars, count), selector);
129			selector
130		};
131
132		self.flush_custom(direction, channel_id, selector, oracle_ids, multiplicity)
133	}
134
135	pub fn flush_custom(
136		&mut self,
137		direction: FlushDirection,
138		channel_id: ChannelId,
139		selector: OracleId,
140		oracle_ids: impl IntoIterator<Item = OracleId>,
141		multiplicity: u64,
142	) -> anyhow::Result<()> {
143		let oracles = oracle_ids.into_iter().collect::<Vec<_>>();
144		let log_rows = self.log_rows(oracles.iter().copied())?;
145		ensure!(
146			log_rows == self.log_rows([selector])?,
147			"Selector {} n_vars does not match flush {:?}",
148			selector,
149			oracles
150		);
151
152		self.flushes.push(Flush {
153			channel_id,
154			direction,
155			selector,
156			oracles,
157			multiplicity,
158		});
159
160		Ok(())
161	}
162
163	pub fn send(
164		&mut self,
165		channel_id: ChannelId,
166		count: usize,
167		oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
168	) -> anyhow::Result<()>
169	where
170		U: PackScalar<BinaryField1b>,
171	{
172		self.flush(FlushDirection::Push, channel_id, count, oracle_ids)
173	}
174
175	pub fn receive(
176		&mut self,
177		channel_id: ChannelId,
178		count: usize,
179		oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
180	) -> anyhow::Result<()>
181	where
182		U: PackScalar<BinaryField1b>,
183	{
184		self.flush(FlushDirection::Pull, channel_id, count, oracle_ids)
185	}
186
187	pub fn assert_zero(
188		&mut self,
189		name: impl ToString,
190		oracle_ids: impl IntoIterator<Item = OracleId>,
191		composition: ArithExpr<F>,
192	) {
193		self.constraints
194			.add_zerocheck(name, oracle_ids, composition);
195	}
196
197	pub fn assert_not_zero(&mut self, oracle_id: OracleId) {
198		self.non_zero_oracle_ids.push(oracle_id);
199	}
200
201	pub const fn add_channel(&mut self) -> ChannelId {
202		let channel_id = self.next_channel_id;
203		self.next_channel_id += 1;
204		channel_id
205	}
206
207	pub fn add_committed(
208		&mut self,
209		name: impl ToString,
210		n_vars: usize,
211		tower_level: usize,
212	) -> OracleId {
213		self.oracles
214			.borrow_mut()
215			.add_named(self.scoped_name(name))
216			.committed(n_vars, tower_level)
217	}
218
219	pub fn add_committed_multiple<const N: usize>(
220		&mut self,
221		name: impl ToString,
222		n_vars: usize,
223		tower_level: usize,
224	) -> [OracleId; N] {
225		self.oracles
226			.borrow_mut()
227			.add_named(self.scoped_name(name))
228			.committed_multiple(n_vars, tower_level)
229	}
230
231	pub fn add_linear_combination(
232		&mut self,
233		name: impl ToString,
234		n_vars: usize,
235		inner: impl IntoIterator<Item = (OracleId, F)>,
236	) -> Result<OracleId, OracleError> {
237		self.oracles
238			.borrow_mut()
239			.add_named(self.scoped_name(name))
240			.linear_combination(n_vars, inner)
241	}
242
243	pub fn add_linear_combination_with_offset(
244		&mut self,
245		name: impl ToString,
246		n_vars: usize,
247		offset: F,
248		inner: impl IntoIterator<Item = (OracleId, F)>,
249	) -> Result<OracleId, OracleError> {
250		self.oracles
251			.borrow_mut()
252			.add_named(self.scoped_name(name))
253			.linear_combination_with_offset(n_vars, offset, inner)
254	}
255
256	pub fn add_composite_mle(
257		&mut self,
258		name: impl ToString,
259		n_vars: usize,
260		inner: impl IntoIterator<Item = OracleId>,
261		comp: ArithExpr<F>,
262	) -> Result<OracleId, OracleError> {
263		self.oracles
264			.borrow_mut()
265			.add_named(self.scoped_name(name))
266			.composite_mle(n_vars, inner, comp)
267	}
268
269	pub fn add_packed(
270		&mut self,
271		name: impl ToString,
272		id: OracleId,
273		log_degree: usize,
274	) -> Result<OracleId, OracleError> {
275		self.oracles
276			.borrow_mut()
277			.add_named(self.scoped_name(name))
278			.packed(id, log_degree)
279	}
280
281	pub fn add_projected(
282		&mut self,
283		name: impl ToString,
284		id: OracleId,
285		values: Vec<F>,
286		variant: ProjectionVariant,
287	) -> Result<usize, OracleError> {
288		self.oracles
289			.borrow_mut()
290			.add_named(self.scoped_name(name))
291			.projected(id, values, variant)
292	}
293
294	pub fn add_repeating(
295		&mut self,
296		name: impl ToString,
297		id: OracleId,
298		log_count: usize,
299	) -> Result<OracleId, OracleError> {
300		self.oracles
301			.borrow_mut()
302			.add_named(self.scoped_name(name))
303			.repeating(id, log_count)
304	}
305
306	pub fn add_shifted(
307		&mut self,
308		name: impl ToString,
309		id: OracleId,
310		offset: usize,
311		block_bits: usize,
312		variant: ShiftVariant,
313	) -> Result<OracleId, OracleError> {
314		self.oracles
315			.borrow_mut()
316			.add_named(self.scoped_name(name))
317			.shifted(id, offset, block_bits, variant)
318	}
319
320	pub fn add_transparent(
321		&mut self,
322		name: impl ToString,
323		poly: impl MultivariatePoly<F> + 'static,
324	) -> Result<OracleId, OracleError> {
325		self.oracles
326			.borrow_mut()
327			.add_named(self.scoped_name(name))
328			.transparent(poly)
329	}
330
331	pub fn add_zero_padded(
332		&mut self,
333		name: impl ToString,
334		id: OracleId,
335		n_vars: usize,
336	) -> Result<OracleId, OracleError> {
337		self.oracles
338			.borrow_mut()
339			.add_named(self.scoped_name(name))
340			.zero_padded(id, n_vars)
341	}
342
343	fn scoped_name(&self, name: impl ToString) -> String {
344		let name = name.to_string();
345		if self.namespace_path.is_empty() {
346			name
347		} else {
348			format!("{}::{name}", self.namespace_path.join("::"))
349		}
350	}
351
352	/// Anything pushed to the namespace will become part of oracle name, which is useful for debugging.
353	///
354	/// Use `pop_namespace(&mut self)` to remove the latest name.
355	///
356	/// Example
357	/// ```
358	/// use binius_circuits::builder::ConstraintSystemBuilder;
359	/// use binius_field::{TowerField, BinaryField128b, BinaryField1b, arch::OptimalUnderlier};
360	///
361	/// let log_size = 14;
362	///
363	/// let mut builder = ConstraintSystemBuilder::new();
364	/// builder.push_namespace("a");
365	/// let x = builder.add_committed("x", log_size, BinaryField1b::TOWER_LEVEL);
366	/// builder.push_namespace("b");
367	/// let y = builder.add_committed("y", log_size, BinaryField1b::TOWER_LEVEL);
368	/// builder.pop_namespace();
369	/// builder.pop_namespace();
370	/// let z = builder.add_committed("z", log_size, BinaryField1b::TOWER_LEVEL);
371	///
372	/// let system = builder.build().unwrap();
373	/// assert_eq!(system.oracles.oracle(x).name().unwrap(), "a::x");
374	/// assert_eq!(system.oracles.oracle(y).name().unwrap(), "a::b::y");
375	/// assert_eq!(system.oracles.oracle(z).name().unwrap(), "z");
376	/// ```
377	pub fn push_namespace(&mut self, name: impl ToString) {
378		self.namespace_path.push(name.to_string());
379	}
380
381	pub fn pop_namespace(&mut self) {
382		self.namespace_path.pop();
383	}
384
385	/// Returns the number of rows shared by a set of columns.
386	///
387	/// Fails if no columns are provided, or not all columns have the same number of rows.
388	///
389	/// This is useful for writing circuits with internal columns that depend on the height of input columns.
390	pub fn log_rows(
391		&self,
392		oracle_ids: impl IntoIterator<Item = OracleId>,
393	) -> anyhow::Result<usize> {
394		let mut oracle_ids = oracle_ids.into_iter();
395		let oracles = self.oracles.borrow();
396		let Some(first_id) = oracle_ids.next() else {
397			bail!(anyhow!("log_rows: You need to specify at least one column"));
398		};
399		let log_rows = oracles.n_vars(first_id);
400		if oracle_ids.any(|id| oracles.n_vars(id) != log_rows) {
401			bail!(anyhow!("log_rows: All columns must have the same number of rows"))
402		}
403		Ok(log_rows)
404	}
405}