binius_core/oracle/
constraint.rs1use core::iter::IntoIterator;
4use std::sync::Arc;
5
6use binius_field::{Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::{ArithCircuit, CompositionPoly};
9use binius_utils::bail;
10use itertools::Itertools;
11
12use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId};
13
14pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
17
18#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
21pub struct Constraint<F: Field> {
22 pub name: String,
23 pub composition: ArithCircuit<F>,
24 pub predicate: ConstraintPredicate<F>,
25}
26
27#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)]
30pub enum ConstraintPredicate<F: Field> {
31 Sum(F),
32 Zero,
33}
34
35#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
38pub struct ConstraintSet<F: Field> {
39 pub n_vars: usize,
40 pub oracle_ids: Vec<OracleId>,
41 pub constraints: Vec<Constraint<F>>,
42}
43
44#[allow(clippy::type_complexity)]
47struct UngroupedConstraint<F: Field> {
48 name: String,
49 oracle_ids: Vec<OracleId>,
50 composition: ArithCircuit<F>,
51 predicate: ConstraintPredicate<F>,
52}
53
54#[derive(Default)]
58pub struct ConstraintSetBuilder<F: Field> {
59 constraints: Vec<UngroupedConstraint<F>>,
60}
61
62impl<F: Field> ConstraintSetBuilder<F> {
63 pub const fn new() -> Self {
64 Self {
65 constraints: Vec::new(),
66 }
67 }
68
69 pub fn add_sumcheck(
70 &mut self,
71 oracle_ids: impl IntoIterator<Item = OracleId>,
72 composition: ArithCircuit<F>,
73 sum: F,
74 ) {
75 self.constraints.push(UngroupedConstraint {
76 name: "sumcheck".into(),
77 oracle_ids: oracle_ids.into_iter().collect(),
78 composition,
79 predicate: ConstraintPredicate::Sum(sum),
80 });
81 }
82
83 pub fn add_zerocheck(
84 &mut self,
85 name: impl ToString,
86 oracle_ids: impl IntoIterator<Item = OracleId>,
87 composition: ArithCircuit<F>,
88 ) {
89 self.constraints.push(UngroupedConstraint {
90 name: name.to_string(),
91 oracle_ids: oracle_ids.into_iter().collect(),
92 composition,
93 predicate: ConstraintPredicate::Zero,
94 });
95 }
96
97 pub fn build_one(
99 self,
100 oracles: &MultilinearOracleSet<impl TowerField>,
101 ) -> Result<ConstraintSet<F>, Error> {
102 let mut oracle_ids = self
103 .constraints
104 .iter()
105 .flat_map(|constraint| constraint.oracle_ids.clone())
106 .collect::<Vec<_>>();
107 if oracle_ids.is_empty() {
108 return Err(Error::EmptyConstraintSet);
110 }
111 for id in &oracle_ids {
112 if !oracles.is_valid_oracle_id(*id) {
113 bail!(Error::InvalidOracleId(*id));
114 }
115 }
116 oracle_ids.sort();
117 oracle_ids.dedup();
118
119 let n_vars = oracle_ids
120 .first()
121 .map(|id| oracles.n_vars(*id))
122 .unwrap_or_default();
123
124 for id in &oracle_ids {
125 if oracles.n_vars(*id) != n_vars {
126 bail!(Error::ConstraintSetNvarsMismatch {
127 expected: n_vars,
128 got: oracles.n_vars(*id)
129 });
130 }
131 }
132
133 let constraints =
136 self.constraints
137 .into_iter()
138 .map(|constraint| Constraint {
139 name: constraint.name,
140 composition: constraint
141 .composition
142 .remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
143 "precondition: oracle_ids is a superset of constraint.oracle_ids",
144 ))
145 .expect("Infallible by ConstraintSetBuilder invariants."),
146 predicate: constraint.predicate,
147 })
148 .collect();
149
150 Ok(ConstraintSet {
151 n_vars,
152 oracle_ids,
153 constraints,
154 })
155 }
156
157 pub fn build(
161 self,
162 oracles: &MultilinearOracleSet<impl TowerField>,
163 ) -> Result<Vec<ConstraintSet<F>>, Error> {
164 let connected_oracle_chunks = self
165 .constraints
166 .iter()
167 .map(|constraint| constraint.oracle_ids.clone())
168 .chain(oracles.polys().filter_map(|oracle| match &oracle.variant {
169 MultilinearPolyVariant::Shifted(shifted) => Some(vec![oracle.id(), shifted.id()]),
170 MultilinearPolyVariant::LinearCombination(linear_combination) => {
171 Some(linear_combination.polys().chain([oracle.id()]).collect())
172 }
173 _ => None,
174 }))
175 .collect::<Vec<_>>();
176
177 let connected_oracle_chunks = connected_oracle_chunks
178 .iter()
179 .map(|x| x.iter().map(|y| y.index()).collect::<Vec<usize>>())
180 .collect::<Vec<Vec<usize>>>();
181
182 let groups = binius_utils::graph::connected_components(&connected_oracle_chunks);
183
184 let n_vars_and_constraints = self
185 .constraints
186 .into_iter()
187 .map(|constraint| {
188 if constraint.oracle_ids.is_empty() {
189 bail!(Error::EmptyConstraintSet);
190 }
191 for id in &constraint.oracle_ids {
192 if !oracles.is_valid_oracle_id(*id) {
193 bail!(Error::InvalidOracleId(*id));
194 }
195 }
196 let n_vars = constraint
197 .oracle_ids
198 .first()
199 .map(|id| oracles.n_vars(*id))
200 .unwrap();
201
202 for id in &constraint.oracle_ids {
203 if oracles.n_vars(*id) != n_vars {
204 bail!(Error::ConstraintSetNvarsMismatch {
205 expected: n_vars,
206 got: oracles.n_vars(*id)
207 });
208 }
209 }
210 Ok::<_, Error>((n_vars, constraint))
211 })
212 .collect::<Result<Vec<_>, _>>()?;
213
214 let grouped_constraints = n_vars_and_constraints
215 .into_iter()
216 .sorted_by_key(|(_, constraint)| groups[constraint.oracle_ids[0].index()])
217 .chunk_by(|(_, constraint)| groups[constraint.oracle_ids[0].index()]);
218
219 let constraint_sets = grouped_constraints
220 .into_iter()
221 .map(|(_, grouped_constraints)| {
222 let mut constraints = vec![];
223 let mut oracle_ids = vec![];
224
225 let grouped_constraints = grouped_constraints.into_iter().collect::<Vec<_>>();
226 let (n_vars, _) = grouped_constraints[0];
227
228 for (_, constraint) in grouped_constraints {
229 oracle_ids.extend(&constraint.oracle_ids);
230 constraints.push(constraint);
231 }
232 oracle_ids.sort();
233 oracle_ids.dedup();
234
235 let constraints = constraints
236 .into_iter()
237 .map(|constraint| Constraint {
238 name: constraint.name,
239 composition: constraint
240 .composition
241 .remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
242 "precondition: oracle_ids is a superset of constraint.oracle_ids",
243 ))
244 .expect("Infallible by ConstraintSetBuilder invariants."),
245 predicate: constraint.predicate,
246 })
247 .collect();
248
249 ConstraintSet {
250 constraints,
251 oracle_ids,
252 n_vars,
253 }
254 })
255 .collect();
256
257 Ok(constraint_sets)
258 }
259}
260
/// Finds, for every element of `subset`, the index of its first occurrence in `superset`.
///
/// The result has one index per element of `subset`, in the same order. Returns `None` as soon
/// as an element of `subset` is not present in `superset`.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
    let mut indices = Vec::with_capacity(subset.len());
    for wanted in subset {
        let index = superset.iter().position(|candidate| candidate == wanted)?;
        indices.push(index);
    }
    Some(indices)
}