// binius_core/oracle/constraint.rs
use core::iter::IntoIterator;
use std::sync::Arc;
use binius_field::{Field, TowerField};
use binius_math::{ArithExpr, CompositionPolyOS};
use binius_utils::bail;
use itertools::Itertools;
use super::{Error, MultilinearOracleSet, MultilinearPolyOracle, OracleId};
/// Shared, dynamically dispatched [`CompositionPolyOS`] over `P`.
///
/// The `'static` bound is spelled out here; it is the default object lifetime for `Arc<dyn _>`.
pub type TypeErasedComposition<P> = Arc<dyn CompositionPolyOS<P> + 'static>;
/// A single named constraint inside a [`ConstraintSet`].
///
/// When produced by [`ConstraintSetBuilder`], the variables of `composition` index into the
/// enclosing set's `oracle_ids`.
#[derive(Clone, Debug)]
pub struct Constraint<F: Field> {
    /// Label of the constraint (`"sumcheck"` for sumchecks, caller-chosen for zerochecks).
    pub name: Arc<str>,
    /// Arithmetic expression over the set's oracles.
    pub composition: ArithExpr<F>,
    /// Claim made about `composition`.
    pub predicate: ConstraintPredicate<F>,
}
/// The kind of claim a [`Constraint`] makes about its composition.
#[derive(Debug, Clone)]
pub enum ConstraintPredicate<F: Field> {
    /// Sum claim with the given target value (queued by `ConstraintSetBuilder::add_sumcheck`).
    Sum(F),
    /// Zero claim (queued by `ConstraintSetBuilder::add_zerocheck`).
    Zero,
}
/// A group of constraints that all range over the same list of oracles.
#[derive(Clone, Debug)]
pub struct ConstraintSet<F: Field> {
    /// Number of variables; [`ConstraintSetBuilder`] checks it is shared by the oracles it
    /// validates for this set.
    pub n_vars: usize,
    /// Oracles referenced by the constraints (sorted and de-duplicated when produced by
    /// [`ConstraintSetBuilder`]).
    pub oracle_ids: Vec<OracleId>,
    /// The constraints, with compositions indexed by position in `oracle_ids`.
    pub constraints: Vec<Constraint<F>>,
}
/// A constraint queued in a [`ConstraintSetBuilder`] before it has been assigned to a set.
///
/// Variable `i` of `composition` is later remapped using the position of `oracle_ids[i]` in the
/// merged id list of the set the constraint ends up in.
struct UngroupedConstraint<F: Field> {
    /// Label copied into the resulting [`Constraint`].
    name: Arc<str>,
    /// Oracles this constraint reads, in the order its composition's variables refer to them.
    oracle_ids: Vec<OracleId>,
    /// Composition over the constraint-local variables.
    composition: ArithExpr<F>,
    /// Claim attached to the composition.
    predicate: ConstraintPredicate<F>,
}
/// Accumulates sumcheck and zerocheck constraints, each over its own list of oracles, and turns
/// them into [`ConstraintSet`]s via `build` (one set per related group) or `build_one` (a single
/// set).
#[derive(Default)]
pub struct ConstraintSetBuilder<F: Field> {
    /// Queued constraints in insertion order, compositions still indexed by their own oracle list.
    constraints: Vec<UngroupedConstraint<F>>,
}
impl<F: Field> ConstraintSetBuilder<F> {
pub fn new() -> Self {
Self {
constraints: Vec::new(),
}
}
pub fn add_sumcheck(
&mut self,
oracle_ids: impl IntoIterator<Item = OracleId>,
composition: ArithExpr<F>,
sum: F,
) {
self.constraints.push(UngroupedConstraint {
name: "sumcheck".into(),
oracle_ids: oracle_ids.into_iter().collect(),
composition,
predicate: ConstraintPredicate::Sum(sum),
});
}
pub fn add_zerocheck(
&mut self,
name: impl ToString,
oracle_ids: impl IntoIterator<Item = OracleId>,
composition: ArithExpr<F>,
) {
self.constraints.push(UngroupedConstraint {
name: name.to_string().into(),
oracle_ids: oracle_ids.into_iter().collect(),
composition,
predicate: ConstraintPredicate::Zero,
});
}
pub fn build_one(
self,
oracles: &MultilinearOracleSet<impl TowerField>,
) -> Result<ConstraintSet<F>, Error> {
let mut oracle_ids = self
.constraints
.iter()
.flat_map(|constraint| constraint.oracle_ids.clone())
.collect::<Vec<_>>();
if oracle_ids.is_empty() {
return Err(Error::EmptyConstraintSet);
}
for id in oracle_ids.iter() {
if !oracles.is_valid_oracle_id(*id) {
bail!(Error::InvalidOracleId(*id));
}
}
oracle_ids.sort();
oracle_ids.dedup();
let n_vars = oracle_ids
.first()
.map(|id| oracles.n_vars(*id))
.unwrap_or_default();
for id in oracle_ids.iter() {
if oracles.n_vars(*id) != n_vars {
bail!(Error::ConstraintSetNvarsMismatch {
expected: n_vars,
got: oracles.n_vars(*id)
});
}
}
let constraints =
self.constraints
.into_iter()
.map(|constraint| Constraint {
name: constraint.name,
composition: constraint
.composition
.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
"precondition: oracle_ids is a superset of constraint.oracle_ids",
))
.expect("Infallible by ConstraintSetBuilder invariants."),
predicate: constraint.predicate,
})
.collect();
Ok(ConstraintSet {
n_vars,
oracle_ids,
constraints,
})
}
pub fn build(
self,
oracles: &MultilinearOracleSet<impl TowerField>,
) -> Result<Vec<ConstraintSet<F>>, Error> {
let connected_oracle_chunks = self
.constraints
.iter()
.map(|constraint| constraint.oracle_ids.clone())
.chain(oracles.iter().filter_map(|oracle| {
match oracle {
MultilinearPolyOracle::Shifted { id, shifted, .. } => {
Some(vec![id, shifted.inner().id()])
}
MultilinearPolyOracle::LinearCombination {
id,
linear_combination,
..
} => Some(
linear_combination
.polys()
.map(|p| p.id())
.chain([id])
.collect(),
),
_ => None,
}
}))
.collect::<Vec<_>>();
let groups = binius_utils::graph::connected_components(
&connected_oracle_chunks
.iter()
.map(|x| x.as_slice())
.collect::<Vec<_>>(),
);
let n_vars_and_constraints = self
.constraints
.into_iter()
.map(|constraint| {
if constraint.oracle_ids.is_empty() {
bail!(Error::EmptyConstraintSet);
}
for id in constraint.oracle_ids.iter() {
if !oracles.is_valid_oracle_id(*id) {
bail!(Error::InvalidOracleId(*id));
}
}
let n_vars = constraint
.oracle_ids
.first()
.map(|id| oracles.n_vars(*id))
.unwrap();
for id in constraint.oracle_ids.iter() {
if oracles.n_vars(*id) != n_vars {
bail!(Error::ConstraintSetNvarsMismatch {
expected: n_vars,
got: oracles.n_vars(*id)
});
}
}
Ok::<_, Error>((n_vars, constraint))
})
.collect::<Result<Vec<_>, _>>()?;
let grouped_constraints = n_vars_and_constraints
.into_iter()
.sorted_by_key(|(_, constraint)| groups[constraint.oracle_ids[0]])
.chunk_by(|(_, constraint)| groups[constraint.oracle_ids[0]]);
let constraint_sets = grouped_constraints
.into_iter()
.map(|(_, grouped_constraints)| {
let mut constraints = vec![];
let mut oracle_ids = vec![];
let grouped_constraints = grouped_constraints.into_iter().collect::<Vec<_>>();
let (n_vars, _) = grouped_constraints[0];
for (_, constraint) in grouped_constraints {
oracle_ids.extend(&constraint.oracle_ids);
constraints.push(constraint);
}
oracle_ids.sort();
oracle_ids.dedup();
let constraints = constraints
.into_iter()
.map(|constraint| Constraint {
name: constraint.name,
composition: constraint
.composition
.remap_vars(&positions(&constraint.oracle_ids, &oracle_ids).expect(
"precondition: oracle_ids is a superset of constraint.oracle_ids",
))
.expect("Infallible by ConstraintSetBuilder invariants."),
predicate: constraint.predicate,
})
.collect();
ConstraintSet {
constraints,
oracle_ids,
n_vars,
}
})
.collect();
Ok(constraint_sets)
}
}
/// Maps every element of `subset` to the index of its first occurrence in `superset`.
///
/// Returns `None` as soon as an element of `subset` is absent from `superset`. Each lookup is a
/// linear scan, so the cost is `O(subset.len() * superset.len())`.
fn positions<T: Eq>(subset: &[T], superset: &[T]) -> Option<Vec<usize>> {
    let mut indices = Vec::with_capacity(subset.len());
    for needle in subset {
        let index = superset.iter().position(|candidate| candidate == needle)?;
        indices.push(index);
    }
    Some(indices)
}