1use std::{cell::RefCell, collections::HashMap, rc::Rc};
4
5use anyhow::{anyhow, ensure};
6use binius_core::{
7 constraint_system::{
8 channel::{ChannelId, Flush, FlushDirection},
9 ConstraintSystem,
10 },
11 oracle::{
12 ConstraintSetBuilder, Error as OracleError, MultilinearOracleSet, OracleId,
13 ProjectionVariant, ShiftVariant,
14 },
15 polynomial::MultivariatePoly,
16 transparent::step_down::StepDown,
17 witness::MultilinearExtensionIndex,
18};
19use binius_field::{as_packed_field::PackScalar, BinaryField1b};
20use binius_math::ArithExpr;
21use binius_utils::bail;
22
23use crate::builder::{
24 types::{F, U},
25 witness,
26};
27
28#[derive(Default)]
29pub struct ConstraintSystemBuilder<'arena> {
30 oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
31 constraints: ConstraintSetBuilder<F>,
32 non_zero_oracle_ids: Vec<OracleId>,
33 flushes: Vec<Flush>,
34 step_down_dedup: HashMap<(usize, usize), OracleId>,
35 witness: Option<witness::Builder<'arena>>,
36 next_channel_id: ChannelId,
37 namespace_path: Vec<String>,
38}
39
40impl<'arena> ConstraintSystemBuilder<'arena> {
41 pub fn new() -> Self {
42 Self::default()
43 }
44
45 pub fn new_with_witness(allocator: &'arena bumpalo::Bump) -> Self {
46 let oracles = Rc::new(RefCell::new(MultilinearOracleSet::new()));
47 Self {
48 witness: Some(witness::Builder::new(allocator, oracles.clone())),
49 oracles,
50 ..Default::default()
51 }
52 }
53
54 #[allow(clippy::type_complexity)]
55 pub fn build(self) -> Result<ConstraintSystem<F>, anyhow::Error> {
56 let table_constraints = self.constraints.build(&self.oracles.borrow())?;
57 Ok(ConstraintSystem {
58 max_channel_id: self
59 .flushes
60 .iter()
61 .map(|flush| flush.channel_id)
62 .max()
63 .unwrap_or(0),
64 table_constraints,
65 non_zero_oracle_ids: self.non_zero_oracle_ids,
66 oracles: Rc::into_inner(self.oracles)
67 .ok_or_else(|| {
68 anyhow!("Failed to build ConstraintSystem: references still exist to oracles")
69 })?
70 .into_inner(),
71 flushes: self.flushes,
72 })
73 }
74
75 pub const fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> {
76 self.witness.as_mut()
77 }
78
79 pub fn take_witness(
80 &mut self,
81 ) -> Result<MultilinearExtensionIndex<'arena, U, F>, anyhow::Error> {
82 Option::take(&mut self.witness)
83 .ok_or_else(|| {
84 anyhow!("Witness is missing. Are you in verifier mode, or have you already extraced the witness?")
85 })?
86 .build()
87 }
88
89 pub fn flush(
90 &mut self,
91 direction: FlushDirection,
92 channel_id: ChannelId,
93 count: usize,
94 oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
95 ) -> anyhow::Result<()>
96 where
97 U: PackScalar<BinaryField1b>,
98 {
99 self.flush_with_multiplicity(direction, channel_id, count, oracle_ids, 1)
100 }
101
102 pub fn flush_with_multiplicity(
103 &mut self,
104 direction: FlushDirection,
105 channel_id: ChannelId,
106 count: usize,
107 oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
108 multiplicity: u64,
109 ) -> anyhow::Result<()>
110 where
111 U: PackScalar<BinaryField1b>,
112 {
113 let n_vars = self.log_rows(oracle_ids.clone())?;
114
115 let selector = if let Some(&selector) = self.step_down_dedup.get(&(n_vars, count)) {
116 selector
117 } else {
118 let step_down = StepDown::new(n_vars, count)?;
119 let selector = self.add_transparent(
120 format!("internal step_down {count}-{n_vars}"),
121 step_down.clone(),
122 )?;
123
124 if let Some(witness) = self.witness() {
125 step_down.populate(witness.new_column::<BinaryField1b>(selector).packed());
126 }
127
128 self.step_down_dedup.insert((n_vars, count), selector);
129 selector
130 };
131
132 self.flush_custom(direction, channel_id, selector, oracle_ids, multiplicity)
133 }
134
135 pub fn flush_custom(
136 &mut self,
137 direction: FlushDirection,
138 channel_id: ChannelId,
139 selector: OracleId,
140 oracle_ids: impl IntoIterator<Item = OracleId>,
141 multiplicity: u64,
142 ) -> anyhow::Result<()> {
143 let oracles = oracle_ids.into_iter().collect::<Vec<_>>();
144 let log_rows = self.log_rows(oracles.iter().copied())?;
145 ensure!(
146 log_rows == self.log_rows([selector])?,
147 "Selector {} n_vars does not match flush {:?}",
148 selector,
149 oracles
150 );
151
152 self.flushes.push(Flush {
153 channel_id,
154 direction,
155 selector,
156 oracles,
157 multiplicity,
158 });
159
160 Ok(())
161 }
162
163 pub fn send(
164 &mut self,
165 channel_id: ChannelId,
166 count: usize,
167 oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
168 ) -> anyhow::Result<()>
169 where
170 U: PackScalar<BinaryField1b>,
171 {
172 self.flush(FlushDirection::Push, channel_id, count, oracle_ids)
173 }
174
175 pub fn receive(
176 &mut self,
177 channel_id: ChannelId,
178 count: usize,
179 oracle_ids: impl IntoIterator<Item = OracleId> + Clone,
180 ) -> anyhow::Result<()>
181 where
182 U: PackScalar<BinaryField1b>,
183 {
184 self.flush(FlushDirection::Pull, channel_id, count, oracle_ids)
185 }
186
187 pub fn assert_zero(
188 &mut self,
189 name: impl ToString,
190 oracle_ids: impl IntoIterator<Item = OracleId>,
191 composition: ArithExpr<F>,
192 ) {
193 self.constraints
194 .add_zerocheck(name, oracle_ids, composition);
195 }
196
197 pub fn assert_not_zero(&mut self, oracle_id: OracleId) {
198 self.non_zero_oracle_ids.push(oracle_id);
199 }
200
201 pub const fn add_channel(&mut self) -> ChannelId {
202 let channel_id = self.next_channel_id;
203 self.next_channel_id += 1;
204 channel_id
205 }
206
207 pub fn add_committed(
208 &mut self,
209 name: impl ToString,
210 n_vars: usize,
211 tower_level: usize,
212 ) -> OracleId {
213 self.oracles
214 .borrow_mut()
215 .add_named(self.scoped_name(name))
216 .committed(n_vars, tower_level)
217 }
218
219 pub fn add_committed_multiple<const N: usize>(
220 &mut self,
221 name: impl ToString,
222 n_vars: usize,
223 tower_level: usize,
224 ) -> [OracleId; N] {
225 self.oracles
226 .borrow_mut()
227 .add_named(self.scoped_name(name))
228 .committed_multiple(n_vars, tower_level)
229 }
230
231 pub fn add_linear_combination(
232 &mut self,
233 name: impl ToString,
234 n_vars: usize,
235 inner: impl IntoIterator<Item = (OracleId, F)>,
236 ) -> Result<OracleId, OracleError> {
237 self.oracles
238 .borrow_mut()
239 .add_named(self.scoped_name(name))
240 .linear_combination(n_vars, inner)
241 }
242
243 pub fn add_linear_combination_with_offset(
244 &mut self,
245 name: impl ToString,
246 n_vars: usize,
247 offset: F,
248 inner: impl IntoIterator<Item = (OracleId, F)>,
249 ) -> Result<OracleId, OracleError> {
250 self.oracles
251 .borrow_mut()
252 .add_named(self.scoped_name(name))
253 .linear_combination_with_offset(n_vars, offset, inner)
254 }
255
256 pub fn add_composite_mle(
257 &mut self,
258 name: impl ToString,
259 n_vars: usize,
260 inner: impl IntoIterator<Item = OracleId>,
261 comp: ArithExpr<F>,
262 ) -> Result<OracleId, OracleError> {
263 self.oracles
264 .borrow_mut()
265 .add_named(self.scoped_name(name))
266 .composite_mle(n_vars, inner, comp)
267 }
268
269 pub fn add_packed(
270 &mut self,
271 name: impl ToString,
272 id: OracleId,
273 log_degree: usize,
274 ) -> Result<OracleId, OracleError> {
275 self.oracles
276 .borrow_mut()
277 .add_named(self.scoped_name(name))
278 .packed(id, log_degree)
279 }
280
281 pub fn add_projected(
282 &mut self,
283 name: impl ToString,
284 id: OracleId,
285 values: Vec<F>,
286 variant: ProjectionVariant,
287 ) -> Result<usize, OracleError> {
288 self.oracles
289 .borrow_mut()
290 .add_named(self.scoped_name(name))
291 .projected(id, values, variant)
292 }
293
294 pub fn add_repeating(
295 &mut self,
296 name: impl ToString,
297 id: OracleId,
298 log_count: usize,
299 ) -> Result<OracleId, OracleError> {
300 self.oracles
301 .borrow_mut()
302 .add_named(self.scoped_name(name))
303 .repeating(id, log_count)
304 }
305
306 pub fn add_shifted(
307 &mut self,
308 name: impl ToString,
309 id: OracleId,
310 offset: usize,
311 block_bits: usize,
312 variant: ShiftVariant,
313 ) -> Result<OracleId, OracleError> {
314 self.oracles
315 .borrow_mut()
316 .add_named(self.scoped_name(name))
317 .shifted(id, offset, block_bits, variant)
318 }
319
320 pub fn add_transparent(
321 &mut self,
322 name: impl ToString,
323 poly: impl MultivariatePoly<F> + 'static,
324 ) -> Result<OracleId, OracleError> {
325 self.oracles
326 .borrow_mut()
327 .add_named(self.scoped_name(name))
328 .transparent(poly)
329 }
330
331 pub fn add_zero_padded(
332 &mut self,
333 name: impl ToString,
334 id: OracleId,
335 n_vars: usize,
336 ) -> Result<OracleId, OracleError> {
337 self.oracles
338 .borrow_mut()
339 .add_named(self.scoped_name(name))
340 .zero_padded(id, n_vars)
341 }
342
343 fn scoped_name(&self, name: impl ToString) -> String {
344 let name = name.to_string();
345 if self.namespace_path.is_empty() {
346 name
347 } else {
348 format!("{}::{name}", self.namespace_path.join("::"))
349 }
350 }
351
352 pub fn push_namespace(&mut self, name: impl ToString) {
378 self.namespace_path.push(name.to_string());
379 }
380
381 pub fn pop_namespace(&mut self) {
382 self.namespace_path.pop();
383 }
384
385 pub fn log_rows(
391 &self,
392 oracle_ids: impl IntoIterator<Item = OracleId>,
393 ) -> anyhow::Result<usize> {
394 let mut oracle_ids = oracle_ids.into_iter();
395 let oracles = self.oracles.borrow();
396 let Some(first_id) = oracle_ids.next() else {
397 bail!(anyhow!("log_rows: You need to specify at least one column"));
398 };
399 let log_rows = oracles.n_vars(first_id);
400 if oracle_ids.any(|id| oracles.n_vars(id) != log_rows) {
401 bail!(anyhow!("log_rows: All columns must have the same number of rows"))
402 }
403 Ok(log_rows)
404 }
405}