1use std::{fmt::Debug, sync::Arc};
4
5use binius_field::PackedField;
6use binius_math::MultilinearPoly;
7
8use crate::{oracle::OracleId, polynomial::Error as PolynomialError};
9
10pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;
11
12#[derive(Clone, Debug)]
13pub struct IndexEntry<'a, P: PackedField> {
14 pub multilin_poly: MultilinearWitness<'a, P>,
15 pub nonzero_scalars_prefix: usize,
16}
17
18#[derive(Default, Debug)]
23pub struct MultilinearExtensionIndex<'a, P>
24where
25 P: PackedField,
26{
27 entries: Vec<Option<IndexEntry<'a, P>>>,
28}
29
30#[derive(Debug, thiserror::Error)]
31pub enum Error {
32 #[error("witness not found for oracle {id}")]
33 MissingWitness { id: OracleId },
34 #[error("witness for oracle id {id} does not have an explicit backing multilinear")]
35 NoExplicitBackingMultilinearExtension { id: OracleId },
36 #[error("log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}")]
37 OracleExtensionDegreeMismatch {
38 oracle_id: OracleId,
39 field_log_extension_degree: usize,
40 entry_log_extension_degree: usize,
41 },
42 #[error("polynomial error: {0}")]
43 Polynomial(#[from] PolynomialError),
44 #[error("HAL error: {0}")]
45 HalError(#[from] binius_hal::Error),
46 #[error("Math error: {0}")]
47 MathError(#[from] binius_math::Error),
48}
49
50impl<'a, P> MultilinearExtensionIndex<'a, P>
51where
52 P: PackedField,
53{
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn get_index_entry(&self, id: OracleId) -> Result<IndexEntry<'a, P>, Error> {
59 let entry = self
60 .entries
61 .get(id)
62 .ok_or(Error::MissingWitness { id })?
63 .as_ref()
64 .ok_or(Error::MissingWitness { id })?;
65 Ok(entry.clone())
66 }
67
68 pub fn get_multilin_poly(&self, id: OracleId) -> Result<MultilinearWitness<'a, P>, Error> {
69 Ok(self.get_index_entry(id)?.multilin_poly)
70 }
71
72 pub fn has(&self, id: OracleId) -> bool {
74 self.entries.get(id).is_some_and(Option::is_some)
75 }
76
77 pub fn update_multilin_poly(
78 &mut self,
79 witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>)>,
80 ) -> Result<(), Error> {
81 self.update_multilin_poly_with_nonzero_scalars_prefixes(witnesses.into_iter().map(
82 |(id, multilin_poly)| {
83 let nonzero_scalars_prefix = 1 << multilin_poly.n_vars();
84 (id, multilin_poly, nonzero_scalars_prefix)
85 },
86 ))
87 }
88
89 pub fn update_multilin_poly_with_nonzero_scalars_prefixes(
90 &mut self,
91 witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>, usize)>,
92 ) -> Result<(), Error> {
93 for (id, multilin_poly, nonzero_scalars_prefix) in witnesses {
94 if id >= self.entries.len() {
95 self.entries.resize_with(id + 1, || None);
96 }
97 self.entries[id] = Some(IndexEntry {
99 multilin_poly,
100 nonzero_scalars_prefix,
101 });
102 }
103 Ok(())
104 }
105}