use super::OracleId;
use crate::{composition::index_composition, polynomial::CompositionPoly};
use binius_field::{Field, PackedField};
use std::sync::Arc;
/// A composition polynomial behind a shared `Arc<dyn ...>` pointer, so that compositions of
/// different concrete types can be stored side by side (e.g. in [`Constraint`]).
pub type TypeErasedComposition<P> = Arc<dyn CompositionPoly<P>>;
/// A single constraint: a composition polynomial paired with the predicate it must satisfy.
#[derive(Clone)]
pub struct Constraint<P: PackedField> {
/// The type-erased composition. When produced by [`ConstraintSetBuilder::build`], it has
/// already been indexed against the resulting [`ConstraintSet`]'s `oracle_ids`.
pub composition: TypeErasedComposition<P>,
/// The predicate (claimed sum, or zero) attached to this constraint.
pub predicate: ConstraintPredicate<P::Scalar>,
}
/// The condition attached to a [`Constraint`]'s composition.
#[derive(Clone, Debug)]
pub enum ConstraintPredicate<F: Field> {
/// A claimed sum; created by [`ConstraintSetBuilder::add_sumcheck`].
Sum(F),
/// A zero predicate; created by [`ConstraintSetBuilder::add_zerocheck`].
Zero,
}
impl<F: Field> ConstraintPredicate<F> {
    /// Re-expresses this predicate over the field `FI`.
    ///
    /// `Zero` maps to `Zero`; a `Sum` carries its value across via `FI: From<F>`.
    pub fn isomorphic<FI: Field + From<F>>(self) -> ConstraintPredicate<FI> {
        match self {
            Self::Zero => ConstraintPredicate::Zero,
            Self::Sum(claimed) => ConstraintPredicate::Sum(FI::from(claimed)),
        }
    }
}
/// A batch of constraints together with the list of oracles they reference.
#[derive(Clone)]
pub struct ConstraintSet<P: PackedField> {
/// Oracle IDs referenced by the constraints. When produced by
/// [`ConstraintSetBuilder::build`], this list is sorted and deduplicated.
pub oracle_ids: Vec<OracleId>,
/// The constraints in this set.
pub constraints: Vec<Constraint<P>>,
}
/// A constraint whose composition has not been indexed yet.
///
/// The final oracle list is only known after every constraint has been added, so indexing
/// is deferred until [`ConstraintSetBuilder::build`] invokes `composition_thunk`.
// The boxed closure type of `composition_thunk` trips clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
struct ConstraintThunk<P: PackedField> {
/// Called once in `build` with the sorted, deduplicated oracle list; returns the indexed composition.
composition_thunk: Box<dyn FnOnce(&[OracleId]) -> TypeErasedComposition<P>>,
/// Carried unchanged into the built [`Constraint`].
predicate: ConstraintPredicate<P::Scalar>,
}
/// Incrementally collects constraints and the oracles they reference, producing a
/// [`ConstraintSet`] via `build`.
pub struct ConstraintSetBuilder<P: PackedField> {
/// Every oracle ID passed to the `add_*` methods, possibly with duplicates;
/// sorted and deduplicated in `build`.
oracle_ids: Vec<OracleId>,
/// Pending constraints, indexed in `build` once the final oracle list is known.
constraint_thunks: Vec<ConstraintThunk<P>>,
}
impl<P: PackedField> ConstraintSetBuilder<P> {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
oracle_ids: Vec::new(),
constraint_thunks: Vec::new(),
}
}
pub fn add_sumcheck<Composition, const N: usize>(
&mut self,
oracle_ids: [OracleId; N],
composition: Composition,
sum: P::Scalar,
) where
Composition: CompositionPoly<P> + 'static,
{
self.oracle_ids.extend(&oracle_ids);
self.constraint_thunks.push(ConstraintThunk {
composition_thunk: thunk(oracle_ids, composition),
predicate: ConstraintPredicate::Sum(sum),
});
}
pub fn add_zerocheck<Composition, const N: usize>(
&mut self,
oracle_ids: [OracleId; N],
composition: Composition,
) where
Composition: CompositionPoly<P> + 'static,
{
self.oracle_ids.extend(&oracle_ids);
self.constraint_thunks.push(ConstraintThunk {
composition_thunk: thunk(oracle_ids, composition),
predicate: ConstraintPredicate::Zero,
});
}
pub fn build(self) -> ConstraintSet<P> {
let mut oracle_ids = self.oracle_ids;
oracle_ids.sort();
oracle_ids.dedup();
let constraints = self
.constraint_thunks
.into_iter()
.map(|constraint_thunk| {
let composition = (constraint_thunk.composition_thunk)(&oracle_ids);
Constraint {
composition,
predicate: constraint_thunk.predicate,
}
})
.collect();
ConstraintSet {
oracle_ids,
constraints,
}
}
}
/// Defers indexing of `composition` until the set-wide oracle list is known.
///
/// The returned closure takes that full list, maps the constraint-local `oracle_ids` into
/// it via `index_composition`, and type-erases the result behind an `Arc`.
///
/// # Panics
///
/// The returned closure panics if `index_composition` returns an error.
// The boxed closure return type trips clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
fn thunk<P, Composition, const N: usize>(
    oracle_ids: [OracleId; N],
    composition: Composition,
) -> Box<dyn FnOnce(&[OracleId]) -> TypeErasedComposition<P>>
where
    P: PackedField,
    Composition: CompositionPoly<P> + 'static,
{
    Box::new(
        move |batch_oracle_ids: &[OracleId]| -> TypeErasedComposition<P> {
            Arc::new(
                index_composition(batch_oracle_ids, oracle_ids, composition)
                    .expect("Infallible by ConstraintSetBuilder invariants."),
            )
        },
    )
}