binius_core/oracle/
constraint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use core::iter::IntoIterator;
4use std::sync::Arc;
5
6use binius_field::{Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::{ArithExpr, CompositionPoly};
9use binius_utils::bail;
10use itertools::Itertools;
11
12use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId};
13
14/// Composition trait object that can be used to create lists of compositions of differing
15/// concrete types.
16pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
17
18/// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube
19#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
20pub struct Constraint<F: Field> {
21	pub name: String,
22	pub composition: ArithExpr<F>,
23	pub predicate: ConstraintPredicate<F>,
24}
25
26/// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality to zero
27/// on the hypercube (zerocheck)
28#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)]
29pub enum ConstraintPredicate<F: Field> {
30	Sum(F),
31	Zero,
32}
33
34/// Constraint set is a group of constraints that operate over the same set of oracle-identified multilinears
35#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
36pub struct ConstraintSet<F: Field> {
37	pub n_vars: usize,
38	pub oracle_ids: Vec<OracleId>,
39	pub constraints: Vec<Constraint<F>>,
40}
41
42// A deferred constraint constructor that instantiates index composition after the superset of oracles is known
43#[allow(clippy::type_complexity)]
44struct UngroupedConstraint<F: Field> {
45	name: String,
46	oracle_ids: Vec<OracleId>,
47	composition: ArithExpr<F>,
48	predicate: ConstraintPredicate<F>,
49}
50
51/// A builder struct that turns individual compositions over oraclized multilinears into a set of
52/// type erased `IndexComposition` instances operating over a superset of oracles of all constraints.
53#[derive(Default)]
54pub struct ConstraintSetBuilder<F: Field> {
55	constraints: Vec<UngroupedConstraint<F>>,
56}
57
58impl<F: Field> ConstraintSetBuilder<F> {
59	pub const fn new() -> Self {
60		Self {
61			constraints: Vec::new(),
62		}
63	}
64
65	pub fn add_sumcheck(
66		&mut self,
67		oracle_ids: impl IntoIterator<Item = OracleId>,
68		composition: ArithExpr<F>,
69		sum: F,
70	) {
71		self.constraints.push(UngroupedConstraint {
72			name: "sumcheck".into(),
73			oracle_ids: oracle_ids.into_iter().collect(),
74			composition,
75			predicate: ConstraintPredicate::Sum(sum),
76		});
77	}
78
79	pub fn add_zerocheck(
80		&mut self,
81		name: impl ToString,
82		oracle_ids: impl IntoIterator<Item = OracleId>,
83		composition: ArithExpr<F>,
84	) {
85		self.constraints.push(UngroupedConstraint {
86			name: name.to_string(),
87			oracle_ids: oracle_ids.into_iter().collect(),
88			composition,
89			predicate: ConstraintPredicate::Zero,
90		});
91	}
92
93	/// Build a single constraint set, requiring that all included oracle n_vars are the same
94	pub fn build_one(
95		self,
96		oracles: &MultilinearOracleSet<impl TowerField>,
97	) -> Result<ConstraintSet<F>, Error> {
98		let mut oracle_ids = self
99			.constraints
100			.iter()
101			.flat_map(|constraint| constraint.oracle_ids.clone())
102			.collect::<Vec<_>>();
103		if oracle_ids.is_empty() {
104			// Do not bail!, this error is handled in evalcheck.
105			return Err(Error::EmptyConstraintSet);
106		}
107		for id in &oracle_ids {
108			if !oracles.is_valid_oracle_id(*id) {
109				bail!(Error::InvalidOracleId(*id));
110			}
111		}
112		oracle_ids.sort();
113		oracle_ids.dedup();
114
115		let n_vars = oracle_ids
116			.first()
117			.map(|id| oracles.n_vars(*id))
118			.unwrap_or_default();
119
120		for id in &oracle_ids {
121			if oracles.n_vars(*id) != n_vars {
122				bail!(Error::ConstraintSetNvarsMismatch {
123					expected: n_vars,
124					got: oracles.n_vars(*id)
125				});
126			}
127		}
128
129		// at this point the superset of oracles is known and index compositions
130		// may be finally instantiated
131		let constraints =
132			self.constraints
133				.into_iter()
134				.map(|constraint| Constraint {
135					name: constraint.name,
136					composition: constraint
137						.composition
138						.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
139							"precondition: oracle_ids is a superset of constraint.oracle_ids",
140						))
141						.expect("Infallible by ConstraintSetBuilder invariants."),
142					predicate: constraint.predicate,
143				})
144				.collect();
145
146		Ok(ConstraintSet {
147			n_vars,
148			oracle_ids,
149			constraints,
150		})
151	}
152
153	/// Create one ConstraintSet for every unique n_vars used.
154	///
155	/// Note that you can't mix oracles with different n_vars in a single constraint.
156	pub fn build(
157		self,
158		oracles: &MultilinearOracleSet<impl TowerField>,
159	) -> Result<Vec<ConstraintSet<F>>, Error> {
160		let connected_oracle_chunks = self
161			.constraints
162			.iter()
163			.map(|constraint| constraint.oracle_ids.clone())
164			.chain(oracles.iter().filter_map(|oracle| match oracle.variant {
165				MultilinearPolyVariant::Shifted(ref shifted) => {
166					Some(vec![oracle.id(), shifted.id()])
167				}
168				MultilinearPolyVariant::LinearCombination(ref linear_combination) => {
169					Some(linear_combination.polys().chain([oracle.id()]).collect())
170				}
171				_ => None,
172			}))
173			.collect::<Vec<_>>();
174
175		let groups = binius_utils::graph::connected_components(
176			&connected_oracle_chunks
177				.iter()
178				.map(|x| x.as_slice())
179				.collect::<Vec<_>>(),
180		);
181
182		let n_vars_and_constraints = self
183			.constraints
184			.into_iter()
185			.map(|constraint| {
186				if constraint.oracle_ids.is_empty() {
187					bail!(Error::EmptyConstraintSet);
188				}
189				for id in &constraint.oracle_ids {
190					if !oracles.is_valid_oracle_id(*id) {
191						bail!(Error::InvalidOracleId(*id));
192					}
193				}
194				let n_vars = constraint
195					.oracle_ids
196					.first()
197					.map(|id| oracles.n_vars(*id))
198					.unwrap();
199
200				for id in &constraint.oracle_ids {
201					if oracles.n_vars(*id) != n_vars {
202						bail!(Error::ConstraintSetNvarsMismatch {
203							expected: n_vars,
204							got: oracles.n_vars(*id)
205						});
206					}
207				}
208				Ok::<_, Error>((n_vars, constraint))
209			})
210			.collect::<Result<Vec<_>, _>>()?;
211
212		let grouped_constraints = n_vars_and_constraints
213			.into_iter()
214			.sorted_by_key(|(_, constraint)| groups[constraint.oracle_ids[0]])
215			.chunk_by(|(_, constraint)| groups[constraint.oracle_ids[0]]);
216
217		let constraint_sets = grouped_constraints
218			.into_iter()
219			.map(|(_, grouped_constraints)| {
220				let mut constraints = vec![];
221				let mut oracle_ids = vec![];
222
223				let grouped_constraints = grouped_constraints.into_iter().collect::<Vec<_>>();
224				let (n_vars, _) = grouped_constraints[0];
225
226				for (_, constraint) in grouped_constraints {
227					oracle_ids.extend(&constraint.oracle_ids);
228					constraints.push(constraint);
229				}
230				oracle_ids.sort();
231				oracle_ids.dedup();
232
233				let constraints = constraints
234					.into_iter()
235					.map(|constraint| Constraint {
236						name: constraint.name,
237						composition: constraint
238							.composition
239							.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
240								"precondition: oracle_ids is a superset of constraint.oracle_ids",
241							))
242							.expect("Infallible by ConstraintSetBuilder invariants."),
243						predicate: constraint.predicate,
244					})
245					.collect();
246
247				ConstraintSet {
248					constraints,
249					oracle_ids,
250					n_vars,
251				}
252			})
253			.collect();
254
255		Ok(constraint_sets)
256	}
257}
258
/// Maps each element of `subset` to the index of its first occurrence in `superset`.
///
/// Duplicates in `superset` resolve to the earliest matching index. Returns `None` as soon as
/// an element of `subset` cannot be found in `superset`.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
	let mut indices = Vec::with_capacity(subset.len());
	for needle in subset {
		let index = superset.iter().position(|candidate| candidate == needle)?;
		indices.push(index);
	}
	Some(indices)
}
270		})
271		.collect()
272}