binius_core/oracle/
constraint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use core::iter::IntoIterator;
4use std::sync::Arc;
5
6use binius_field::{Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::{ArithCircuit, CompositionPoly};
9use binius_utils::bail;
10use itertools::Itertools;
11
12use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId};
13
/// Composition trait object that can be used to create lists of compositions of differing
/// concrete types.
///
/// The [`Arc`] wrapper provides shared ownership, so cloning such a list clones reference-counted
/// pointers rather than the compositions themselves.
pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
17
/// A named arithmetic-circuit composition together with a predicate on its values on the boolean
/// hypercube.
// NOTE(review): field order presumably determines the `SerializeBytes` encoding — confirm before
// reordering fields.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Constraint<F: Field> {
	/// Human-readable label; `ConstraintSetBuilder` uses `"sumcheck"` for every sumcheck.
	pub name: String,
	/// The composition. When produced by `ConstraintSetBuilder`, it has been remapped onto the
	/// `oracle_ids` of the owning `ConstraintSet`.
	pub composition: ArithCircuit<F>,
	/// The claim the composition's values must satisfy.
	pub predicate: ConstraintPredicate<F>,
}
26
/// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality
/// to zero on the hypercube (zerocheck).
// NOTE(review): variant order presumably determines the `SerializeBytes` encoding — confirm before
// reordering variants.
#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)]
pub enum ConstraintPredicate<F: Field> {
	/// Sumcheck: the composition's values over the hypercube sum to this field element.
	Sum(F),
	/// Zerocheck: the composition is zero at every point of the hypercube.
	Zero,
}
34
/// Constraint set is a group of constraints that operate over the same set of oracle-identified
/// multilinears.
// NOTE(review): field order presumably determines the `SerializeBytes` encoding — confirm before
// reordering fields.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct ConstraintSet<F: Field> {
	/// Number of variables of the oracles in `oracle_ids`.
	pub n_vars: usize,
	/// Oracles the constraints are expressed over. Sorted and deduplicated when produced by
	/// `ConstraintSetBuilder`.
	pub oracle_ids: Vec<OracleId>,
	/// Constraints whose compositions refer to the oracles of `oracle_ids`.
	pub constraints: Vec<Constraint<F>>,
}
43
44// A deferred constraint constructor that instantiates index composition after the superset of
45// oracles is known
46#[allow(clippy::type_complexity)]
47struct UngroupedConstraint<F: Field> {
48	name: String,
49	oracle_ids: Vec<OracleId>,
50	composition: ArithCircuit<F>,
51	predicate: ConstraintPredicate<F>,
52}
53
/// A builder that collects individual compositions over oraclized multilinears and turns them into
/// [`ConstraintSet`]s.
///
/// When built, each set's oracle list is the sorted, deduplicated union of the oracles referenced
/// by its constraints, and each constraint's `ArithCircuit` composition is remapped onto that list.
#[derive(Default)]
pub struct ConstraintSetBuilder<F: Field> {
	constraints: Vec<UngroupedConstraint<F>>,
}
61
62impl<F: Field> ConstraintSetBuilder<F> {
63	pub const fn new() -> Self {
64		Self {
65			constraints: Vec::new(),
66		}
67	}
68
69	pub fn add_sumcheck(
70		&mut self,
71		oracle_ids: impl IntoIterator<Item = OracleId>,
72		composition: ArithCircuit<F>,
73		sum: F,
74	) {
75		self.constraints.push(UngroupedConstraint {
76			name: "sumcheck".into(),
77			oracle_ids: oracle_ids.into_iter().collect(),
78			composition,
79			predicate: ConstraintPredicate::Sum(sum),
80		});
81	}
82
83	pub fn add_zerocheck(
84		&mut self,
85		name: impl ToString,
86		oracle_ids: impl IntoIterator<Item = OracleId>,
87		composition: ArithCircuit<F>,
88	) {
89		self.constraints.push(UngroupedConstraint {
90			name: name.to_string(),
91			oracle_ids: oracle_ids.into_iter().collect(),
92			composition,
93			predicate: ConstraintPredicate::Zero,
94		});
95	}
96
97	/// Build a single constraint set, requiring that all included oracle n_vars are the same
98	pub fn build_one(
99		self,
100		oracles: &MultilinearOracleSet<impl TowerField>,
101	) -> Result<ConstraintSet<F>, Error> {
102		let mut oracle_ids = self
103			.constraints
104			.iter()
105			.flat_map(|constraint| constraint.oracle_ids.clone())
106			.collect::<Vec<_>>();
107		if oracle_ids.is_empty() {
108			// Do not bail!, this error is handled in evalcheck.
109			return Err(Error::EmptyConstraintSet);
110		}
111		for id in &oracle_ids {
112			if !oracles.is_valid_oracle_id(*id) {
113				bail!(Error::InvalidOracleId(*id));
114			}
115		}
116		oracle_ids.sort();
117		oracle_ids.dedup();
118
119		let n_vars = oracle_ids
120			.first()
121			.map(|id| oracles.n_vars(*id))
122			.unwrap_or_default();
123
124		for id in &oracle_ids {
125			if oracles.n_vars(*id) != n_vars {
126				bail!(Error::ConstraintSetNvarsMismatch {
127					expected: n_vars,
128					got: oracles.n_vars(*id)
129				});
130			}
131		}
132
133		// at this point the superset of oracles is known and index compositions
134		// may be finally instantiated
135		let constraints =
136			self.constraints
137				.into_iter()
138				.map(|constraint| Constraint {
139					name: constraint.name,
140					composition: constraint
141						.composition
142						.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
143							"precondition: oracle_ids is a superset of constraint.oracle_ids",
144						))
145						.expect("Infallible by ConstraintSetBuilder invariants."),
146					predicate: constraint.predicate,
147				})
148				.collect();
149
150		Ok(ConstraintSet {
151			n_vars,
152			oracle_ids,
153			constraints,
154		})
155	}
156
157	/// Create one ConstraintSet for every unique n_vars used.
158	///
159	/// Note that you can't mix oracles with different n_vars in a single constraint.
160	pub fn build(
161		self,
162		oracles: &MultilinearOracleSet<impl TowerField>,
163	) -> Result<Vec<ConstraintSet<F>>, Error> {
164		let connected_oracle_chunks = self
165			.constraints
166			.iter()
167			.map(|constraint| constraint.oracle_ids.clone())
168			.chain(oracles.polys().filter_map(|oracle| match &oracle.variant {
169				MultilinearPolyVariant::Shifted(shifted) => Some(vec![oracle.id(), shifted.id()]),
170				MultilinearPolyVariant::LinearCombination(linear_combination) => {
171					Some(linear_combination.polys().chain([oracle.id()]).collect())
172				}
173				_ => None,
174			}))
175			.collect::<Vec<_>>();
176
177		let connected_oracle_chunks = connected_oracle_chunks
178			.iter()
179			.map(|x| x.iter().map(|y| y.index()).collect::<Vec<usize>>())
180			.collect::<Vec<Vec<usize>>>();
181
182		let groups = binius_utils::graph::connected_components(&connected_oracle_chunks);
183
184		let n_vars_and_constraints = self
185			.constraints
186			.into_iter()
187			.map(|constraint| {
188				if constraint.oracle_ids.is_empty() {
189					bail!(Error::EmptyConstraintSet);
190				}
191				for id in &constraint.oracle_ids {
192					if !oracles.is_valid_oracle_id(*id) {
193						bail!(Error::InvalidOracleId(*id));
194					}
195				}
196				let n_vars = constraint
197					.oracle_ids
198					.first()
199					.map(|id| oracles.n_vars(*id))
200					.unwrap();
201
202				for id in &constraint.oracle_ids {
203					if oracles.n_vars(*id) != n_vars {
204						bail!(Error::ConstraintSetNvarsMismatch {
205							expected: n_vars,
206							got: oracles.n_vars(*id)
207						});
208					}
209				}
210				Ok::<_, Error>((n_vars, constraint))
211			})
212			.collect::<Result<Vec<_>, _>>()?;
213
214		let grouped_constraints = n_vars_and_constraints
215			.into_iter()
216			.sorted_by_key(|(_, constraint)| groups[constraint.oracle_ids[0].index()])
217			.chunk_by(|(_, constraint)| groups[constraint.oracle_ids[0].index()]);
218
219		let constraint_sets = grouped_constraints
220			.into_iter()
221			.map(|(_, grouped_constraints)| {
222				let mut constraints = vec![];
223				let mut oracle_ids = vec![];
224
225				let grouped_constraints = grouped_constraints.into_iter().collect::<Vec<_>>();
226				let (n_vars, _) = grouped_constraints[0];
227
228				for (_, constraint) in grouped_constraints {
229					oracle_ids.extend(&constraint.oracle_ids);
230					constraints.push(constraint);
231				}
232				oracle_ids.sort();
233				oracle_ids.dedup();
234
235				let constraints = constraints
236					.into_iter()
237					.map(|constraint| Constraint {
238						name: constraint.name,
239						composition: constraint
240							.composition
241							.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
242								"precondition: oracle_ids is a superset of constraint.oracle_ids",
243							))
244							.expect("Infallible by ConstraintSetBuilder invariants."),
245						predicate: constraint.predicate,
246					})
247					.collect();
248
249				ConstraintSet {
250					constraints,
251					oracle_ids,
252					n_vars,
253				}
254			})
255			.collect();
256
257		Ok(constraint_sets)
258	}
259}
260
/// Maps every element of `subset` to the index of its first occurrence in `superset`.
///
/// If `superset` holds duplicates, the earliest matching index is reported. Returns `None` as soon
/// as an element of `subset` has no match in `superset`. Each lookup is a linear scan of
/// `superset`.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
	let mut indices = Vec::with_capacity(subset.len());
	for needle in subset {
		let index = superset
			.iter()
			.position(|candidate| candidate == needle)?;
		indices.push(index);
	}
	Some(indices)
}