binius_core/oracle/
constraint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use core::iter::IntoIterator;
4use std::sync::Arc;
5
6use binius_field::{Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::{ArithCircuit, CompositionPoly};
9use binius_utils::bail;
10
11use super::{Error, MultilinearOracleSet, OracleId};
12use crate::constraint_system::TableId;
13
/// Composition trait object that can be used to create lists of compositions of differing
/// concrete types.
///
/// The trait object is held in an [`Arc`], so cloning a `TypeErasedComposition` shares the same
/// underlying composition instead of copying it.
pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
17
/// A named constraint: an arithmetic-circuit composition together with a predicate on its values
/// on the boolean hypercube.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Constraint<F: Field> {
	/// Label of the constraint (`ConstraintSetBuilder::add_sumcheck` uses `"sumcheck"`).
	pub name: String,
	/// The composition whose values on the hypercube the predicate constrains.
	pub composition: ArithCircuit<F>,
	/// What must hold for the composition's values on the hypercube.
	pub predicate: ConstraintPredicate<F>,
}
26
/// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality
/// to zero on the hypercube (zerocheck).
#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)]
pub enum ConstraintPredicate<F: Field> {
	/// Sumcheck: the composition's values summed over the hypercube equal the given element.
	Sum(F),
	/// Zerocheck: the composition is zero everywhere on the hypercube.
	Zero,
}
34
/// Constraint set is a group of constraints that operate over the same set of oracle-identified
/// multilinears, all of which are expected to have `n_vars` variables.
///
/// The difference from the [`ConstraintSet`] is that the latter is for the public API and this
/// one is supposed to be used within the core only.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct SizedConstraintSet<F: Field> {
	/// Number of variables of the multilinears. `ConstraintSetBuilder::build_one` checks that
	/// every oracle in `oracle_ids` has this many; `SizedConstraintSet::new` takes it on trust.
	pub n_vars: usize,
	/// Oracles the constraints range over. Sorted and deduplicated when produced by
	/// `ConstraintSetBuilder::build_one`.
	pub oracle_ids: Vec<OracleId>,
	/// The constraints. When produced by `ConstraintSetBuilder::build_one`, each composition has
	/// been remapped so its variables index into `oracle_ids`.
	pub constraints: Vec<Constraint<F>>,
}
46
47impl<F: Field> SizedConstraintSet<F> {
48	pub fn new(n_vars: usize, u: ConstraintSet<F>) -> Self {
49		let oracle_ids = u.oracle_ids;
50		let constraints = u.constraints;
51
52		Self {
53			n_vars,
54			oracle_ids,
55			constraints,
56		}
57	}
58}
59
/// Constraint set is a group of constraints that operate over the same set of oracle-identified
/// multilinears. The multilinears are expected to be of the same size.
///
/// Unlike [`SizedConstraintSet`], this type does not record the multilinears' number of
/// variables; `SizedConstraintSet::new` supplies it and discards `table_id` and
/// `log_values_per_row`.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct ConstraintSet<F: Field> {
	// NOTE(review): presumably the table this constraint set belongs to — confirm.
	pub table_id: TableId,
	// NOTE(review): presumably log2 of the number of values per table row — confirm.
	pub log_values_per_row: usize,
	pub oracle_ids: Vec<OracleId>,
	pub constraints: Vec<Constraint<F>>,
}
69
70// A deferred constraint constructor that instantiates index composition after the superset of
71// oracles is known
72#[allow(clippy::type_complexity)]
73struct UngroupedConstraint<F: Field> {
74	name: String,
75	oracle_ids: Vec<OracleId>,
76	composition: ArithCircuit<F>,
77	predicate: ConstraintPredicate<F>,
78}
79
/// A builder that collects individual constraints over oracle-identified multilinears and, in
/// [`ConstraintSetBuilder::build_one`], turns them into a single [`SizedConstraintSet`] whose
/// compositions all operate over the sorted, deduplicated superset of the oracles referenced by
/// every constraint.
#[derive(Default)]
pub struct ConstraintSetBuilder<F: Field> {
	// Recorded constraints in insertion order, each still expressed over its own oracle list.
	constraints: Vec<UngroupedConstraint<F>>,
}
87
88impl<F: Field> ConstraintSetBuilder<F> {
89	pub const fn new() -> Self {
90		Self {
91			constraints: Vec::new(),
92		}
93	}
94
95	pub fn add_sumcheck(
96		&mut self,
97		oracle_ids: impl IntoIterator<Item = OracleId>,
98		composition: ArithCircuit<F>,
99		sum: F,
100	) {
101		self.constraints.push(UngroupedConstraint {
102			name: "sumcheck".into(),
103			oracle_ids: oracle_ids.into_iter().collect(),
104			composition,
105			predicate: ConstraintPredicate::Sum(sum),
106		});
107	}
108
109	/// Build a single constraint set, requiring that all included oracle n_vars are the same
110	pub fn build_one(
111		self,
112		oracles: &MultilinearOracleSet<impl TowerField>,
113	) -> Result<SizedConstraintSet<F>, Error> {
114		let mut oracle_ids = self
115			.constraints
116			.iter()
117			.flat_map(|constraint| constraint.oracle_ids.clone())
118			.collect::<Vec<_>>();
119		if oracle_ids.is_empty() {
120			// Do not bail!, this error is handled in evalcheck.
121			return Err(Error::EmptyConstraintSet);
122		}
123		for id in &oracle_ids {
124			if !oracles.is_valid_oracle_id(*id) {
125				bail!(Error::InvalidOracleId(*id));
126			}
127		}
128		oracle_ids.sort();
129		oracle_ids.dedup();
130
131		let n_vars = oracle_ids
132			.first()
133			.map(|id| oracles.n_vars(*id))
134			.unwrap_or_default();
135
136		for id in &oracle_ids {
137			if oracles.n_vars(*id) != n_vars {
138				bail!(Error::ConstraintSetNvarsMismatch {
139					expected: n_vars,
140					got: oracles.n_vars(*id)
141				});
142			}
143		}
144
145		// at this point the superset of oracles is known and index compositions
146		// may be finally instantiated
147		let constraints =
148			self.constraints
149				.into_iter()
150				.map(|constraint| Constraint {
151					name: constraint.name,
152					composition: constraint
153						.composition
154						.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
155							"precondition: oracle_ids is a superset of constraint.oracle_ids",
156						))
157						.expect("Infallible by ConstraintSetBuilder invariants."),
158					predicate: constraint.predicate,
159				})
160				.collect();
161
162		Ok(SizedConstraintSet {
163			n_vars,
164			oracle_ids,
165			constraints,
166		})
167	}
168}
169
/// Maps every element of `subset` to the index at which it occurs in `superset`.
///
/// If an element occurs more than once in `superset`, the index of its first occurrence is used.
/// Returns `None` as soon as some element of `subset` is missing from `superset`.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
	let mut indices = Vec::with_capacity(subset.len());
	for needle in subset {
		let index = superset.iter().position(|candidate| candidate == needle)?;
		indices.push(index);
	}
	Some(indices)
}