binius_core/oracle/
constraint.rs1use core::iter::IntoIterator;
4use std::sync::Arc;
5
6use binius_field::{Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::{ArithExpr, CompositionPoly};
9use binius_utils::bail;
10use itertools::Itertools;
11
12use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId};
13
14pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
17
18#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
20pub struct Constraint<F: Field> {
21 pub name: String,
22 pub composition: ArithExpr<F>,
23 pub predicate: ConstraintPredicate<F>,
24}
25
26#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)]
29pub enum ConstraintPredicate<F: Field> {
30 Sum(F),
31 Zero,
32}
33
34#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
36pub struct ConstraintSet<F: Field> {
37 pub n_vars: usize,
38 pub oracle_ids: Vec<OracleId>,
39 pub constraints: Vec<Constraint<F>>,
40}
41
42#[allow(clippy::type_complexity)]
44struct UngroupedConstraint<F: Field> {
45 name: String,
46 oracle_ids: Vec<OracleId>,
47 composition: ArithExpr<F>,
48 predicate: ConstraintPredicate<F>,
49}
50
51#[derive(Default)]
54pub struct ConstraintSetBuilder<F: Field> {
55 constraints: Vec<UngroupedConstraint<F>>,
56}
57
58impl<F: Field> ConstraintSetBuilder<F> {
59 pub const fn new() -> Self {
60 Self {
61 constraints: Vec::new(),
62 }
63 }
64
65 pub fn add_sumcheck(
66 &mut self,
67 oracle_ids: impl IntoIterator<Item = OracleId>,
68 composition: ArithExpr<F>,
69 sum: F,
70 ) {
71 self.constraints.push(UngroupedConstraint {
72 name: "sumcheck".into(),
73 oracle_ids: oracle_ids.into_iter().collect(),
74 composition,
75 predicate: ConstraintPredicate::Sum(sum),
76 });
77 }
78
79 pub fn add_zerocheck(
80 &mut self,
81 name: impl ToString,
82 oracle_ids: impl IntoIterator<Item = OracleId>,
83 composition: ArithExpr<F>,
84 ) {
85 self.constraints.push(UngroupedConstraint {
86 name: name.to_string(),
87 oracle_ids: oracle_ids.into_iter().collect(),
88 composition,
89 predicate: ConstraintPredicate::Zero,
90 });
91 }
92
93 pub fn build_one(
95 self,
96 oracles: &MultilinearOracleSet<impl TowerField>,
97 ) -> Result<ConstraintSet<F>, Error> {
98 let mut oracle_ids = self
99 .constraints
100 .iter()
101 .flat_map(|constraint| constraint.oracle_ids.clone())
102 .collect::<Vec<_>>();
103 if oracle_ids.is_empty() {
104 return Err(Error::EmptyConstraintSet);
106 }
107 for id in &oracle_ids {
108 if !oracles.is_valid_oracle_id(*id) {
109 bail!(Error::InvalidOracleId(*id));
110 }
111 }
112 oracle_ids.sort();
113 oracle_ids.dedup();
114
115 let n_vars = oracle_ids
116 .first()
117 .map(|id| oracles.n_vars(*id))
118 .unwrap_or_default();
119
120 for id in &oracle_ids {
121 if oracles.n_vars(*id) != n_vars {
122 bail!(Error::ConstraintSetNvarsMismatch {
123 expected: n_vars,
124 got: oracles.n_vars(*id)
125 });
126 }
127 }
128
129 let constraints =
132 self.constraints
133 .into_iter()
134 .map(|constraint| Constraint {
135 name: constraint.name,
136 composition: constraint
137 .composition
138 .remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
139 "precondition: oracle_ids is a superset of constraint.oracle_ids",
140 ))
141 .expect("Infallible by ConstraintSetBuilder invariants."),
142 predicate: constraint.predicate,
143 })
144 .collect();
145
146 Ok(ConstraintSet {
147 n_vars,
148 oracle_ids,
149 constraints,
150 })
151 }
152
153 pub fn build(
157 self,
158 oracles: &MultilinearOracleSet<impl TowerField>,
159 ) -> Result<Vec<ConstraintSet<F>>, Error> {
160 let connected_oracle_chunks = self
161 .constraints
162 .iter()
163 .map(|constraint| constraint.oracle_ids.clone())
164 .chain(oracles.iter().filter_map(|oracle| match oracle.variant {
165 MultilinearPolyVariant::Shifted(ref shifted) => {
166 Some(vec![oracle.id(), shifted.id()])
167 }
168 MultilinearPolyVariant::LinearCombination(ref linear_combination) => {
169 Some(linear_combination.polys().chain([oracle.id()]).collect())
170 }
171 _ => None,
172 }))
173 .collect::<Vec<_>>();
174
175 let groups = binius_utils::graph::connected_components(
176 &connected_oracle_chunks
177 .iter()
178 .map(|x| x.as_slice())
179 .collect::<Vec<_>>(),
180 );
181
182 let n_vars_and_constraints = self
183 .constraints
184 .into_iter()
185 .map(|constraint| {
186 if constraint.oracle_ids.is_empty() {
187 bail!(Error::EmptyConstraintSet);
188 }
189 for id in &constraint.oracle_ids {
190 if !oracles.is_valid_oracle_id(*id) {
191 bail!(Error::InvalidOracleId(*id));
192 }
193 }
194 let n_vars = constraint
195 .oracle_ids
196 .first()
197 .map(|id| oracles.n_vars(*id))
198 .unwrap();
199
200 for id in &constraint.oracle_ids {
201 if oracles.n_vars(*id) != n_vars {
202 bail!(Error::ConstraintSetNvarsMismatch {
203 expected: n_vars,
204 got: oracles.n_vars(*id)
205 });
206 }
207 }
208 Ok::<_, Error>((n_vars, constraint))
209 })
210 .collect::<Result<Vec<_>, _>>()?;
211
212 let grouped_constraints = n_vars_and_constraints
213 .into_iter()
214 .sorted_by_key(|(_, constraint)| groups[constraint.oracle_ids[0]])
215 .chunk_by(|(_, constraint)| groups[constraint.oracle_ids[0]]);
216
217 let constraint_sets = grouped_constraints
218 .into_iter()
219 .map(|(_, grouped_constraints)| {
220 let mut constraints = vec![];
221 let mut oracle_ids = vec![];
222
223 let grouped_constraints = grouped_constraints.into_iter().collect::<Vec<_>>();
224 let (n_vars, _) = grouped_constraints[0];
225
226 for (_, constraint) in grouped_constraints {
227 oracle_ids.extend(&constraint.oracle_ids);
228 constraints.push(constraint);
229 }
230 oracle_ids.sort();
231 oracle_ids.dedup();
232
233 let constraints = constraints
234 .into_iter()
235 .map(|constraint| Constraint {
236 name: constraint.name,
237 composition: constraint
238 .composition
239 .remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
240 "precondition: oracle_ids is a superset of constraint.oracle_ids",
241 ))
242 .expect("Infallible by ConstraintSetBuilder invariants."),
243 predicate: constraint.predicate,
244 })
245 .collect();
246
247 ConstraintSet {
248 constraints,
249 oracle_ids,
250 n_vars,
251 }
252 })
253 .collect();
254
255 Ok(constraint_sets)
256 }
257}
258
/// Maps each element of `subset` to the index of its first occurrence in `superset`.
///
/// Returns `None` as soon as an element of `subset` is missing from `superset`. Each lookup is
/// a linear scan, so this runs in `O(subset.len() * superset.len())` time.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
    let mut indices = Vec::with_capacity(subset.len());
    for needle in subset {
        let index = superset.iter().position(|candidate| candidate == needle)?;
        indices.push(index);
    }
    Some(indices)
}