binius_core/oracle/
constraint.rs1use core::iter::IntoIterator;
4use std::sync::Arc;
5
6use binius_field::{Field, TowerField};
7use binius_macros::{DeserializeBytes, SerializeBytes};
8use binius_math::{ArithCircuit, CompositionPoly};
9use binius_utils::bail;
10
11use super::{Error, MultilinearOracleSet, OracleId};
12use crate::constraint_system::TableId;
13
/// A shared, dynamically dispatched [`CompositionPoly`] over `P`.
///
/// Wrapping the trait object in an `Arc` makes cloning a reference-count bump, so a single
/// composition can be held by several owners without copying it.
pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
17
18#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
21pub struct Constraint<F: Field> {
22 pub name: String,
23 pub composition: ArithCircuit<F>,
24 pub predicate: ConstraintPredicate<F>,
25}
26
27#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)]
30pub enum ConstraintPredicate<F: Field> {
31 Sum(F),
32 Zero,
33}
34
35#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
41pub struct SizedConstraintSet<F: Field> {
42 pub n_vars: usize,
43 pub oracle_ids: Vec<OracleId>,
44 pub constraints: Vec<Constraint<F>>,
45}
46
47impl<F: Field> SizedConstraintSet<F> {
48 pub fn new(n_vars: usize, u: ConstraintSet<F>) -> Self {
49 let oracle_ids = u.oracle_ids;
50 let constraints = u.constraints;
51
52 Self {
53 n_vars,
54 oracle_ids,
55 constraints,
56 }
57 }
58}
59
60#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
63pub struct ConstraintSet<F: Field> {
64 pub table_id: TableId,
65 pub log_values_per_row: usize,
66 pub oracle_ids: Vec<OracleId>,
67 pub constraints: Vec<Constraint<F>>,
68}
69
70#[allow(clippy::type_complexity)]
73struct UngroupedConstraint<F: Field> {
74 name: String,
75 oracle_ids: Vec<OracleId>,
76 composition: ArithCircuit<F>,
77 predicate: ConstraintPredicate<F>,
78}
79
/// Collects constraints that may each reference a different subset of oracles, then groups
/// them into a single [`SizedConstraintSet`] via `build_one`.
#[derive(Default)]
pub struct ConstraintSetBuilder<F: Field> {
    /// Recorded constraints in insertion order, each with its own local oracle list.
    constraints: Vec<UngroupedConstraint<F>>,
}
87
88impl<F: Field> ConstraintSetBuilder<F> {
89 pub const fn new() -> Self {
90 Self {
91 constraints: Vec::new(),
92 }
93 }
94
95 pub fn add_sumcheck(
96 &mut self,
97 oracle_ids: impl IntoIterator<Item = OracleId>,
98 composition: ArithCircuit<F>,
99 sum: F,
100 ) {
101 self.constraints.push(UngroupedConstraint {
102 name: "sumcheck".into(),
103 oracle_ids: oracle_ids.into_iter().collect(),
104 composition,
105 predicate: ConstraintPredicate::Sum(sum),
106 });
107 }
108
109 pub fn build_one(
111 self,
112 oracles: &MultilinearOracleSet<impl TowerField>,
113 ) -> Result<SizedConstraintSet<F>, Error> {
114 let mut oracle_ids = self
115 .constraints
116 .iter()
117 .flat_map(|constraint| constraint.oracle_ids.clone())
118 .collect::<Vec<_>>();
119 if oracle_ids.is_empty() {
120 return Err(Error::EmptyConstraintSet);
122 }
123 for id in &oracle_ids {
124 if !oracles.is_valid_oracle_id(*id) {
125 bail!(Error::InvalidOracleId(*id));
126 }
127 }
128 oracle_ids.sort();
129 oracle_ids.dedup();
130
131 let n_vars = oracle_ids
132 .first()
133 .map(|id| oracles.n_vars(*id))
134 .unwrap_or_default();
135
136 for id in &oracle_ids {
137 if oracles.n_vars(*id) != n_vars {
138 bail!(Error::ConstraintSetNvarsMismatch {
139 expected: n_vars,
140 got: oracles.n_vars(*id)
141 });
142 }
143 }
144
145 let constraints =
148 self.constraints
149 .into_iter()
150 .map(|constraint| Constraint {
151 name: constraint.name,
152 composition: constraint
153 .composition
154 .remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
155 "precondition: oracle_ids is a superset of constraint.oracle_ids",
156 ))
157 .expect("Infallible by ConstraintSetBuilder invariants."),
158 predicate: constraint.predicate,
159 })
160 .collect();
161
162 Ok(SizedConstraintSet {
163 n_vars,
164 oracle_ids,
165 constraints,
166 })
167 }
168}
169
/// For every item of `subset`, returns the index of its first equal element in `superset`.
///
/// Returns `None` as soon as some item of `subset` has no equal element in `superset`.
/// Duplicate items in `subset` map to the same index. Runs in `O(subset.len() * superset.len())`.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
    let mut indices = Vec::with_capacity(subset.len());
    for needle in subset {
        let index = superset.iter().position(|candidate| candidate == needle)?;
        indices.push(index);
    }
    Some(indices)
}