// binius_core/ring_switch/common.rs
use std::sync::Arc;
use binius_field::{Field, TowerField};
use binius_utils::sparse_index::SparseIndex;
use super::error::Error;
use crate::{
oracle::MultilinearPolyOracle, piop::CommitMeta,
protocols::evalcheck::EvalcheckMultilinearClaim,
};
/// A distinct evaluation-point prefix shared by one or more evaluation claims.
///
/// `group_claims_by_eval_point` splits each claim's evaluation point at `kappa`; claims whose
/// leading coordinates are equal (which, since `Vec` equality compares lengths, implies equal
/// `kappa`) share one descriptor.
#[derive(Debug)]
pub struct EvalClaimPrefixDesc<F: Field> {
    /// The first `kappa` coordinates of the evaluation point.
    pub prefix: Vec<F>,
}

impl<F: Field> EvalClaimPrefixDesc<F> {
    /// Number of coordinates in the prefix, i.e. the `kappa` the point was split at.
    pub fn kappa(&self) -> usize {
        let Self { prefix } = self;
        prefix.len()
    }
}
/// A distinct `(suffix, kappa)` pair taken from one or more evaluation-claim points.
///
/// Built by `group_claims_by_eval_point`: each claim's evaluation point is split at `kappa`,
/// and claims whose trailing coordinates and `kappa` both match share one descriptor.
#[derive(Debug)]
pub struct EvalClaimSuffixDesc<F: Field> {
    /// The evaluation-point coordinates after the first `kappa` (`eval_point[kappa..]`),
    /// stored behind an `Arc` so clones only bump a reference count.
    pub suffix: Arc<[F]>,
    /// Number of leading coordinates split off into the prefix; computed as
    /// `F::TOWER_LEVEL - tower_level` of the claim's committed oracle.
    pub kappa: usize,
}
/// Per-claim descriptor produced by `EvalClaimSystem::new`, one for each input claim.
#[derive(Debug)]
pub struct PIOPSumcheckClaimDesc<'a, F: Field> {
    /// Commitment index of the claim's committed oracle, looked up in the
    /// `oracle_to_commit_index` passed to `EvalClaimSystem::new`.
    pub committed_idx: usize,
    /// Index into `EvalClaimSystem::suffix_descs` of this claim's evaluation-point suffix.
    pub suffix_desc_idx: usize,
    /// The original evaluation claim, borrowed from the caller's slice.
    pub eval_claim: &'a EvalcheckMultilinearClaim<F>,
}
/// Evaluation claims on committed oracles, with their evaluation points de-duplicated into
/// shared prefix and suffix descriptors. Constructed by `EvalClaimSystem::new`.
#[derive(Debug)]
pub struct EvalClaimSystem<'a, F: Field> {
    /// Commitment metadata supplied by the caller; stored unchanged.
    pub commit_meta: &'a CommitMeta,
    /// Distinct evaluation-point prefixes, in order of first appearance among the sorted claims.
    pub prefix_descs: Vec<EvalClaimPrefixDesc<F>>,
    /// Distinct `(suffix, kappa)` pairs, in order of first appearance among the sorted claims.
    pub suffix_descs: Vec<EvalClaimSuffixDesc<F>>,
    /// One descriptor per claim, ordered by the committed oracle's `n_vars + tower_level`
    /// (ascending; ties keep the caller's input order).
    pub sumcheck_claim_descs: Vec<PIOPSumcheckClaimDesc<'a, F>>,
    /// For each entry of `sumcheck_claim_descs` (same order), the index of its prefix in
    /// `prefix_descs`.
    pub eval_claim_to_prefix_desc_index: Vec<usize>,
}
impl<'a, F: TowerField> EvalClaimSystem<'a, F> {
pub fn new(
commit_meta: &'a CommitMeta,
oracle_to_commit_index: SparseIndex<usize>,
eval_claims: &'a [EvalcheckMultilinearClaim<F>],
) -> Result<Self, Error> {
let mut eval_claims = eval_claims.iter().collect::<Vec<_>>();
eval_claims.sort_by_key(|claim| match claim.poly {
MultilinearPolyOracle::Committed {
n_vars,
tower_level,
..
} => n_vars + tower_level,
_ => 0,
});
let (
prefix_descs,
eval_claim_to_prefix_desc_index,
suffix_descs,
eval_claim_to_suffix_desc_index,
) = group_claims_by_eval_point(&eval_claims)?;
let sumcheck_claim_descs = eval_claims
.into_iter()
.enumerate()
.map(|(i, eval_claim)| {
let MultilinearPolyOracle::Committed { oracle_id, .. } = eval_claim.poly else {
return Err(Error::EvalcheckClaimForDerivedPoly {
id: eval_claim.poly.id(),
});
};
let committed_idx =
oracle_to_commit_index
.get(oracle_id)
.copied()
.ok_or_else(|| Error::OracleToCommitIndexMissingEntry {
id: eval_claim.poly.id(),
})?;
let suffix_desc_idx = eval_claim_to_suffix_desc_index[i];
Ok(PIOPSumcheckClaimDesc {
committed_idx,
suffix_desc_idx,
eval_claim,
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self {
commit_meta,
prefix_descs,
suffix_descs,
sumcheck_claim_descs,
eval_claim_to_prefix_desc_index,
})
}
pub fn max_claim_kappa(&self) -> usize {
self.prefix_descs
.iter()
.map(|desc| desc.kappa())
.max()
.unwrap_or(0)
}
}
#[allow(clippy::type_complexity)]
fn group_claims_by_eval_point<F: TowerField>(
claims: &[&EvalcheckMultilinearClaim<F>],
) -> Result<(Vec<EvalClaimPrefixDesc<F>>, Vec<usize>, Vec<EvalClaimSuffixDesc<F>>, Vec<usize>), Error>
{
let mut prefix_descs = Vec::<EvalClaimPrefixDesc<F>>::new();
let mut suffix_descs = Vec::<EvalClaimSuffixDesc<F>>::new();
let mut claim_to_prefix_index = Vec::with_capacity(claims.len());
let mut claim_to_suffix_index = Vec::with_capacity(claims.len());
for claim in claims {
let MultilinearPolyOracle::Committed {
oracle_id: id,
tower_level,
..
} = claim.poly
else {
return Err(Error::EvalcheckClaimForDerivedPoly {
id: claim.poly.id(),
});
};
let kappa = F::TOWER_LEVEL.checked_sub(tower_level).ok_or_else(|| {
Error::OracleTowerLevelTooHigh {
id,
max: F::TOWER_LEVEL,
}
})?;
let (prefix, suffix) = claim.eval_point.split_at(kappa);
let prefix_id = prefix_descs
.iter()
.position(|desc| desc.prefix == prefix)
.unwrap_or_else(|| {
let index = prefix_descs.len();
prefix_descs.push(EvalClaimPrefixDesc {
prefix: prefix.to_vec(),
});
index
});
claim_to_prefix_index.push(prefix_id);
let suffix_id = suffix_descs
.iter()
.position(|desc| &*desc.suffix == suffix && desc.kappa == kappa)
.unwrap_or_else(|| {
let index = suffix_descs.len();
suffix_descs.push(EvalClaimSuffixDesc {
suffix: suffix.to_vec().into(),
kappa,
});
index
});
claim_to_suffix_index.push(suffix_id);
}
Ok((prefix_descs, claim_to_prefix_index, suffix_descs, claim_to_suffix_index))
}