1use std::{fmt::Debug, sync::Arc};
4
5use binius_field::{ExtensionField, PackedExtension, PackedField, TowerField};
6use binius_math::{MultilinearExtension, MultilinearExtensionBorrowed, MultilinearPoly};
7use binius_utils::bail;
8
9use crate::{oracle::OracleId, polynomial::Error as PolynomialError};
10
/// Shared, reference-counted handle to a witness multilinear polynomial over packed field `P`.
///
/// The trait object is `Send + Sync`, so the `Arc` handle can be shared across threads, and it
/// may borrow data for the lifetime `'a`.
pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;
12
/// A witness registered in a [`MultilinearExtensionIndex`], paired with the length of its
/// possibly-nonzero scalar prefix.
#[derive(Clone, Debug)]
pub struct IndexEntry<'a, P: PackedField> {
    /// The witness polynomial handle.
    pub multilin_poly: MultilinearWitness<'a, P>,
    /// Length of the leading run of scalars that may be nonzero.
    ///
    /// `MultilinearExtensionIndex::update_multilin_poly` sets this to `1 << n_vars` (the whole
    /// polynomial). NOTE(review): presumably scalars beyond this prefix are guaranteed to be
    /// zero — inferred from the name; confirm with the producers of these entries.
    pub nonzero_scalars_prefix: usize,
}
18
/// Index of witness multilinear polynomials keyed by [`OracleId`].
///
/// Entries live in a vector indexed directly by oracle id. Ids that have never been assigned a
/// witness (including gaps created when a larger id is inserted) hold `None`.
#[derive(Default, Debug)]
pub struct MultilinearExtensionIndex<'a, P>
where
    P: PackedField,
{
    /// Slot `id` holds the entry for oracle `id`, or `None` if no witness was registered.
    entries: Vec<Option<IndexEntry<'a, P>>>,
}
32
/// Errors produced when querying a [`MultilinearExtensionIndex`], plus wrapped errors from the
/// polynomial, HAL, and math layers.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No witness is registered for oracle `id` (the id is out of range or its slot is empty).
    #[error("witness not found for oracle {id}")]
    MissingWitness { id: OracleId },
    /// The witness for oracle `id` exposes no packed evaluations (`packed_evals()` returned
    /// `None`), so it cannot be borrowed as an explicit multilinear extension.
    #[error("witness for oracle id {id} does not have an explicit backing multilinear")]
    NoExplicitBackingMultilinearExtension { id: OracleId },
    /// The subfield requested when borrowing a witness does not match the witness's own
    /// extension degree.
    #[error("log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}")]
    OracleExtensionDegreeMismatch {
        /// Oracle whose witness was requested.
        oracle_id: OracleId,
        /// `LOG_DEGREE` of the requested subfield scalar type.
        field_log_extension_degree: usize,
        /// `log_extension_degree()` reported by the stored witness.
        entry_log_extension_degree: usize,
    },
    /// Wrapped polynomial error (converted via `From`).
    #[error("polynomial error: {0}")]
    Polynomial(#[from] PolynomialError),
    /// Wrapped `binius_hal` error (converted via `From`).
    #[error("HAL error: {0}")]
    HalError(#[from] binius_hal::Error),
    /// Wrapped `binius_math` error, e.g. from building a multilinear extension over a slice.
    #[error("Math error: {0}")]
    MathError(#[from] binius_math::Error),
}
52
53impl<'a, P> MultilinearExtensionIndex<'a, P>
54where
55 P: PackedField,
56{
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn get_index_entry(&self, id: OracleId) -> Result<IndexEntry<'a, P>, Error> {
62 let entry = self
63 .entries
64 .get(id)
65 .ok_or(Error::MissingWitness { id })?
66 .as_ref()
67 .ok_or(Error::MissingWitness { id })?;
68 Ok(entry.clone())
69 }
70
71 pub fn get_multilin_poly(&self, id: OracleId) -> Result<MultilinearWitness<'a, P>, Error> {
72 Ok(self.get_index_entry(id)?.multilin_poly)
73 }
74
75 pub fn has(&self, id: OracleId) -> bool {
77 self.entries.get(id).is_some_and(Option::is_some)
78 }
79
80 pub fn update_multilin_poly(
81 &mut self,
82 witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>)>,
83 ) -> Result<(), Error> {
84 self.update_multilin_poly_with_nonzero_scalars_prefixes(witnesses.into_iter().map(
85 |(id, multilin_poly)| {
86 let nonzero_scalars_prefix = 1 << multilin_poly.n_vars();
87 (id, multilin_poly, nonzero_scalars_prefix)
88 },
89 ))
90 }
91
92 pub fn update_multilin_poly_with_nonzero_scalars_prefixes(
93 &mut self,
94 witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>, usize)>,
95 ) -> Result<(), Error> {
96 for (id, multilin_poly, nonzero_scalars_prefix) in witnesses {
97 if id >= self.entries.len() {
98 self.entries.resize_with(id + 1, || None);
99 }
100 self.entries[id] = Some(IndexEntry {
102 multilin_poly,
103 nonzero_scalars_prefix,
104 });
105 }
106 Ok(())
107 }
108
109 pub fn get<PS>(&self, id: OracleId) -> Result<MultilinearExtensionBorrowed<PS>, Error>
111 where
112 PS: PackedField,
113 P: PackedExtension<PS::Scalar, PackedSubfield = PS>,
114 PS::Scalar: TowerField,
115 P::Scalar: ExtensionField<PS::Scalar>,
116 {
117 let entry = self
118 .entries
119 .get(id)
120 .ok_or(Error::MissingWitness { id })?
121 .as_ref()
122 .ok_or(Error::MissingWitness { id })?;
123
124 let log_extension_degree = entry.multilin_poly.log_extension_degree();
125 if log_extension_degree != PS::Scalar::LOG_DEGREE {
126 bail!(Error::OracleExtensionDegreeMismatch {
127 oracle_id: id,
128 field_log_extension_degree: PS::Scalar::LOG_DEGREE,
129 entry_log_extension_degree: log_extension_degree
130 })
131 }
132
133 let evals = entry
134 .multilin_poly
135 .packed_evals()
136 .map(P::cast_bases)
137 .ok_or(Error::NoExplicitBackingMultilinearExtension { id })?;
138
139 Ok(MultilinearExtension::from_values_slice(evals)?)
140 }
141}