use crate::polynomial::{CompositionPoly, Error};
use binius_field::PackedField;
use binius_utils::bail;
use std::fmt::Debug;
/// Adapter that applies a wrapped composition to a fixed selection of `N` entries
/// taken from a larger query.
///
/// The selection is described by positions into the larger query; the wrapped
/// composition only ever sees the `N` selected entries, in the order given by
/// `indices`.
#[derive(Debug, Clone)]
pub struct IndexComposition<C, const N: usize> {
    /// Length of the full query this adapter expects to receive.
    n_vars: usize,
    /// Positions in the full query whose entries are forwarded, in order, to
    /// `composition`.
    indices: [usize; N],
    /// The wrapped composition that is evaluated on the selected entries.
    composition: C,
}
/// Evaluates the wrapped composition on the query entries selected by `indices`.
impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
    for IndexComposition<C, N>
{
    /// Returns the length of the full query expected by [`Self::evaluate`]
    /// (not the arity of the wrapped composition).
    fn n_vars(&self) -> usize {
        self.n_vars
    }

    /// Returns the degree reported by the wrapped composition.
    fn degree(&self) -> usize {
        self.composition.degree()
    }

    /// Evaluates the wrapped composition on the entries of `query` located at
    /// the stored indices.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IncorrectQuerySize`] when `query.len()` differs from
    /// `self.n_vars`; otherwise forwards whatever the wrapped composition returns.
    fn evaluate(&self, query: &[P]) -> Result<P, Error> {
        if query.len() != self.n_vars {
            bail!(Error::IncorrectQuerySize {
                expected: self.n_vars,
            });
        }

        let subquery = self.indices.map(|index| query[index]);
        self.composition.evaluate(&subquery)
    }

    /// Returns the binary tower level reported by the wrapped composition.
    fn binary_tower_level(&self) -> usize {
        self.composition.binary_tower_level()
    }

    /// Forwards the slices of `sparse_batch_query` located at the stored indices
    /// to the wrapped composition's `sparse_batch_evaluate`, writing into `evals`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IncorrectQuerySize`] when `sparse_batch_query` is too
    /// short to contain every stored index; otherwise forwards whatever the
    /// wrapped composition returns.
    fn sparse_batch_evaluate(
        &self,
        sparse_batch_query: &[&[P]],
        evals: &mut [P],
    ) -> Result<(), Error> {
        // Report an undersized query as an error instead of panicking on the
        // slice index below. Only out-of-range indices are rejected, so every
        // input that was accepted before is still accepted.
        let query_len = sparse_batch_query.len();
        if self.indices.iter().any(|&index| index >= query_len) {
            bail!(Error::IncorrectQuerySize {
                expected: self.n_vars,
            });
        }

        let sparse_batch_subquery = self.indices.map(|index| sparse_batch_query[index]);
        self.composition
            .sparse_batch_evaluate(sparse_batch_subquery.as_slice(), evals)
    }
}
/// Builds an [`IndexComposition`] that applies `composition` to the entries of a
/// `superset`-shaped query corresponding to the elements of `subset`.
///
/// For each element of `subset`, the position of the first equal element in
/// `superset` is recorded; the resulting adapter expects queries of length
/// `superset.len()`.
///
/// # Errors
///
/// Returns [`Error::MixedMultilinearNotFound`] if some element of `subset` has no
/// equal element in `superset`.
pub fn index_composition<E, C, const N: usize>(
    superset: &[E],
    subset: [E; N],
    composition: C,
) -> Result<IndexComposition<C, N>, Error>
where
    E: PartialEq,
{
    // Resolve every subset element in a single pass: a missing element is
    // reported as an error right away, so no separate membership check (and no
    // `expect`) is needed.
    let mut indices = [0usize; N];
    for (slot, subset_item) in indices.iter_mut().zip(subset.iter()) {
        match superset
            .iter()
            .position(|superset_item| superset_item == subset_item)
        {
            Some(position) => *slot = position,
            None => {
                bail!(Error::MixedMultilinearNotFound);
            }
        }
    }

    Ok(IndexComposition {
        n_vars: superset.len(),
        indices,
        composition,
    })
}