1use std::{cell::RefCell, collections::HashMap, rc::Rc};
4
5use anyhow::{anyhow, ensure};
6use binius_core::{
7 constraint_system::{
8 channel::{ChannelId, Flush, FlushDirection, OracleOrConst},
9 exp::Exp,
10 ConstraintSystem,
11 },
12 oracle::{
13 ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId, ShiftVariant,
14 },
15 polynomial::MultivariatePoly,
16 transparent::step_down::StepDown,
17 witness::MultilinearExtensionIndex,
18};
19use binius_field::{
20 as_packed_field::{PackScalar, PackedType},
21 BinaryField1b,
22};
23use binius_math::ArithExpr;
24use binius_utils::bail;
25
26use crate::builder::{
27 types::{F, U},
28 witness,
29};
30
31#[derive(Default)]
32pub struct ConstraintSystemBuilder<'arena> {
33 oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
34 constraints: ConstraintSetBuilder<F>,
35 non_zero_oracle_ids: Vec<OracleId>,
36 flushes: Vec<Flush<F>>,
37 exponents: Vec<Exp<F>>,
38 step_down_dedup: HashMap<(usize, usize), OracleId>,
39 witness: Option<witness::Builder<'arena>>,
40 next_channel_id: ChannelId,
41 namespace_path: Vec<String>,
42}
43
44impl<'arena> ConstraintSystemBuilder<'arena> {
45 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn new_with_witness(allocator: &'arena bumpalo::Bump) -> Self {
50 let oracles = Rc::new(RefCell::new(MultilinearOracleSet::new()));
51 Self {
52 witness: Some(witness::Builder::new(allocator, oracles.clone())),
53 oracles,
54 ..Default::default()
55 }
56 }
57
58 #[allow(clippy::type_complexity)]
59 pub fn build(self) -> Result<ConstraintSystem<F>, anyhow::Error> {
60 let table_constraints = self.constraints.build(&self.oracles.borrow())?;
61 Ok(ConstraintSystem {
62 max_channel_id: self
63 .flushes
64 .iter()
65 .map(|flush| flush.channel_id)
66 .max()
67 .unwrap_or(0),
68 table_constraints,
69 non_zero_oracle_ids: self.non_zero_oracle_ids,
70 oracles: Rc::into_inner(self.oracles)
71 .ok_or_else(|| {
72 anyhow!("Failed to build ConstraintSystem: references still exist to oracles")
73 })?
74 .into_inner(),
75 flushes: self.flushes,
76 exponents: self.exponents,
77 })
78 }
79
80 pub const fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> {
81 self.witness.as_mut()
82 }
83
84 pub fn take_witness(
85 &mut self,
86 ) -> Result<MultilinearExtensionIndex<'arena, PackedType<U, F>>, anyhow::Error> {
87 Option::take(&mut self.witness)
88 .ok_or_else(|| {
89 anyhow!("Witness is missing. Are you in verifier mode, or have you already extraced the witness?")
90 })?
91 .build()
92 }
93
94 pub fn flush(
95 &mut self,
96 direction: FlushDirection,
97 channel_id: ChannelId,
98 count: usize,
99 oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
100 ) -> anyhow::Result<()>
101 where
102 U: PackScalar<BinaryField1b>,
103 {
104 self.flush_with_multiplicity(direction, channel_id, count, oracle_ids, 1)
105 }
106
107 pub fn flush_with_multiplicity(
108 &mut self,
109 direction: FlushDirection,
110 channel_id: ChannelId,
111 count: usize,
112 oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
113 multiplicity: u64,
114 ) -> anyhow::Result<()>
115 where
116 U: PackScalar<BinaryField1b>,
117 {
118 let non_const_oracles = oracle_ids
120 .clone()
121 .into_iter()
122 .filter_map(|id| match id {
123 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
124 _ => None,
125 })
126 .collect::<Vec<_>>();
127
128 let n_vars = self.log_rows(non_const_oracles)?;
129
130 let selector = if let Some(&selector) = self.step_down_dedup.get(&(n_vars, count)) {
131 selector
132 } else {
133 let step_down = StepDown::new(n_vars, count)?;
134 let selector = self.add_transparent(
135 format!("internal step_down {count}-{n_vars}"),
136 step_down.clone(),
137 )?;
138
139 if let Some(witness) = self.witness() {
140 step_down.populate(witness.new_column::<BinaryField1b>(selector).packed());
141 }
142
143 self.step_down_dedup.insert((n_vars, count), selector);
144 selector
145 };
146
147 self.flush_custom(direction, channel_id, selector, oracle_ids, multiplicity)
148 }
149
150 pub fn flush_custom(
151 &mut self,
152 direction: FlushDirection,
153 channel_id: ChannelId,
154 selector: OracleId,
155 oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
156 multiplicity: u64,
157 ) -> anyhow::Result<()> {
158 let non_const_oracles = oracle_ids
160 .clone()
161 .into_iter()
162 .filter_map(|id| match id {
163 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
164 _ => None,
165 })
166 .collect::<Vec<_>>();
167
168 let log_rows = self.log_rows(non_const_oracles.iter().copied())?;
169 ensure!(
170 log_rows == self.log_rows([selector])?,
171 "Selector {} n_vars does not match flush {:?}",
172 selector,
173 non_const_oracles
174 );
175
176 let oracles = oracle_ids.into_iter().collect();
177 self.flushes.push(Flush {
178 channel_id,
179 direction,
180 selector: Some(selector),
181 oracles,
182 multiplicity,
183 });
184
185 Ok(())
186 }
187
188 pub fn send(
189 &mut self,
190 channel_id: ChannelId,
191 count: usize,
192 oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
193 ) -> anyhow::Result<()>
194 where
195 U: PackScalar<BinaryField1b>,
196 {
197 self.flush(FlushDirection::Push, channel_id, count, oracle_ids)
198 }
199
200 pub fn receive(
201 &mut self,
202 channel_id: ChannelId,
203 count: usize,
204 oracle_ids: impl IntoIterator<Item = OracleOrConst<F>> + Clone,
205 ) -> anyhow::Result<()>
206 where
207 U: PackScalar<BinaryField1b>,
208 {
209 self.flush(FlushDirection::Pull, channel_id, count, oracle_ids)
210 }
211
212 pub fn assert_zero(
213 &mut self,
214 name: impl ToString,
215 oracle_ids: impl IntoIterator<Item = OracleId>,
216 composition: ArithExpr<F>,
217 ) {
218 self.constraints
219 .add_zerocheck(name, oracle_ids, composition);
220 }
221
222 pub fn assert_not_zero(&mut self, oracle_id: OracleId) {
223 self.non_zero_oracle_ids.push(oracle_id);
224 }
225
226 pub const fn add_channel(&mut self) -> ChannelId {
227 let channel_id = self.next_channel_id;
228 self.next_channel_id += 1;
229 channel_id
230 }
231
232 pub fn add_committed(
233 &mut self,
234 name: impl ToString,
235 n_vars: usize,
236 tower_level: usize,
237 ) -> OracleId {
238 self.oracles
239 .borrow_mut()
240 .add_named(self.scoped_name(name))
241 .committed(n_vars, tower_level)
242 }
243
244 pub fn add_static_exp(
252 &mut self,
253 bits_ids: Vec<OracleId>,
254 exp_result_id: OracleId,
255 base: F,
256 base_tower_level: usize,
257 ) {
258 self.exponents.push(Exp {
259 bits_ids,
260 exp_result_id,
261 base: OracleOrConst::Const {
262 base,
263 tower_level: base_tower_level,
264 },
265 });
266 }
267
268 pub fn add_dynamic_exp(
275 &mut self,
276 bits_ids: Vec<OracleId>,
277 exp_result_id: OracleId,
278 base: OracleId,
279 ) {
280 self.exponents.push(Exp {
281 bits_ids,
282 exp_result_id,
283 base: OracleOrConst::Oracle(base),
284 });
285 }
286
287 pub fn add_committed_multiple<const N: usize>(
288 &mut self,
289 name: impl ToString,
290 n_vars: usize,
291 tower_level: usize,
292 ) -> [OracleId; N] {
293 self.oracles
294 .borrow_mut()
295 .add_named(self.scoped_name(name))
296 .committed_multiple(n_vars, tower_level)
297 }
298
299 pub fn add_linear_combination(
300 &mut self,
301 name: impl ToString,
302 n_vars: usize,
303 inner: impl IntoIterator<Item = (OracleId, F)>,
304 ) -> Result<OracleId, OracleError> {
305 self.oracles
306 .borrow_mut()
307 .add_named(self.scoped_name(name))
308 .linear_combination(n_vars, inner)
309 }
310
311 pub fn add_linear_combination_with_offset(
312 &mut self,
313 name: impl ToString,
314 n_vars: usize,
315 offset: F,
316 inner: impl IntoIterator<Item = (OracleId, F)>,
317 ) -> Result<OracleId, OracleError> {
318 self.oracles
319 .borrow_mut()
320 .add_named(self.scoped_name(name))
321 .linear_combination_with_offset(n_vars, offset, inner)
322 }
323
324 pub fn add_composite_mle(
325 &mut self,
326 name: impl ToString,
327 n_vars: usize,
328 inner: impl IntoIterator<Item = OracleId>,
329 comp: ArithExpr<F>,
330 ) -> Result<OracleId, OracleError> {
331 self.oracles
332 .borrow_mut()
333 .add_named(self.scoped_name(name))
334 .composite_mle(n_vars, inner, comp)
335 }
336
337 pub fn add_packed(
338 &mut self,
339 name: impl ToString,
340 id: OracleId,
341 log_degree: usize,
342 ) -> Result<OracleId, OracleError> {
343 self.oracles
344 .borrow_mut()
345 .add_named(self.scoped_name(name))
346 .packed(id, log_degree)
347 }
348
349 pub fn add_projected(
351 &mut self,
352 name: impl ToString,
353 id: OracleId,
354 values: Vec<F>,
355 start_index: usize,
356 ) -> Result<usize, OracleError> {
357 self.oracles
358 .borrow_mut()
359 .add_named(self.scoped_name(name))
360 .projected(id, values, start_index)
361 }
362
363 pub fn add_projected_last_vars(
365 &mut self,
366 name: impl ToString,
367 id: OracleId,
368 values: Vec<F>,
369 ) -> Result<usize, OracleError> {
370 self.oracles
371 .borrow_mut()
372 .add_named(self.scoped_name(name))
373 .projected_last_vars(id, values)
374 }
375
376 pub fn add_repeating(
377 &mut self,
378 name: impl ToString,
379 id: OracleId,
380 log_count: usize,
381 ) -> Result<OracleId, OracleError> {
382 self.oracles
383 .borrow_mut()
384 .add_named(self.scoped_name(name))
385 .repeating(id, log_count)
386 }
387
388 pub fn add_shifted(
389 &mut self,
390 name: impl ToString,
391 id: OracleId,
392 offset: usize,
393 block_bits: usize,
394 variant: ShiftVariant,
395 ) -> Result<OracleId, OracleError> {
396 self.oracles
397 .borrow_mut()
398 .add_named(self.scoped_name(name))
399 .shifted(id, offset, block_bits, variant)
400 }
401
402 pub fn add_transparent(
403 &mut self,
404 name: impl ToString,
405 poly: impl MultivariatePoly<F> + 'static,
406 ) -> Result<OracleId, OracleError> {
407 self.oracles
408 .borrow_mut()
409 .add_named(self.scoped_name(name))
410 .transparent(poly)
411 }
412
413 pub fn add_zero_padded(
414 &mut self,
415 name: impl ToString,
416 id: OracleId,
417 n_vars: usize,
418 ) -> Result<OracleId, OracleError> {
419 self.oracles
420 .borrow_mut()
421 .add_named(self.scoped_name(name))
422 .zero_padded(id, n_vars)
423 }
424
425 fn scoped_name(&self, name: impl ToString) -> String {
426 let name = name.to_string();
427 if self.namespace_path.is_empty() {
428 name
429 } else {
430 format!("{}::{name}", self.namespace_path.join("::"))
431 }
432 }
433
434 pub fn push_namespace(&mut self, name: impl ToString) {
460 self.namespace_path.push(name.to_string());
461 }
462
463 pub fn pop_namespace(&mut self) {
464 self.namespace_path.pop();
465 }
466
467 pub fn log_rows(
473 &self,
474 oracle_ids: impl IntoIterator<Item = OracleId>,
475 ) -> anyhow::Result<usize> {
476 let mut oracle_ids = oracle_ids.into_iter();
477 let oracles = self.oracles.borrow();
478 let Some(first_id) = oracle_ids.next() else {
479 bail!(anyhow!("log_rows: You need to specify at least one column"));
480 };
481 let log_rows = oracles.n_vars(first_id);
482 if oracle_ids.any(|id| oracles.n_vars(id) != log_rows) {
483 bail!(anyhow!("log_rows: All columns must have the same number of rows"))
484 }
485 Ok(log_rows)
486 }
487}