1use std::{fmt::Debug, sync::Arc};
4
5use binius_field::{
6 as_packed_field::{PackScalar, PackedType},
7 underlier::UnderlierType,
8 ExtensionField, Field, PackedExtension, TowerField,
9};
10use binius_math::{MultilinearExtension, MultilinearExtensionBorrowed, MultilinearPoly};
11use binius_utils::bail;
12
13use crate::{oracle::OracleId, polynomial::Error as PolynomialError};
14
15pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;
16
17#[derive(Default, Debug)]
24pub struct MultilinearExtensionIndex<'a, U, FW>
25where
26 U: UnderlierType + PackScalar<FW>,
27 FW: Field,
28{
29 entries: Vec<Option<MultilinearWitness<'a, PackedType<U, FW>>>>,
30}
31
/// Errors produced when storing, looking up, or reinterpreting witnesses in a
/// `MultilinearExtensionIndex`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No witness is stored for oracle `id`.
    #[error("witness not found for oracle {id}")]
    MissingWitness { id: OracleId },
    /// The witness for oracle `id` exposes no packed evaluations that could be borrowed.
    #[error("witness for oracle id {id} does not have an explicit backing multilinear")]
    NoExplicitBackingMultilinearExtension { id: OracleId },
    /// The stored witness's log extension degree differs from the one implied by the
    /// field types requested by the caller.
    #[error("log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}")]
    OracleExtensionDegreeMismatch {
        /// Oracle whose witness was requested.
        oracle_id: OracleId,
        /// Log extension degree implied by the requested field types.
        field_log_extension_degree: usize,
        /// Log extension degree reported by the stored witness.
        entry_log_extension_degree: usize,
    },
    /// Wrapped error from the crate's polynomial module.
    #[error("polynomial error: {0}")]
    Polynomial(#[from] PolynomialError),
    /// Wrapped error from `binius_hal`.
    #[error("HAL error: {0}")]
    HalError(#[from] binius_hal::Error),
    /// Wrapped error from `binius_math`.
    #[error("Math error: {0}")]
    MathError(#[from] binius_math::Error),
}
51
52impl<'a, U, FW> MultilinearExtensionIndex<'a, U, FW>
53where
54 U: UnderlierType + PackScalar<FW>,
55 FW: Field,
56{
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn get_multilin_poly(
62 &self,
63 id: OracleId,
64 ) -> Result<MultilinearWitness<'a, PackedType<U, FW>>, Error> {
65 let entry = self
66 .entries
67 .get(id)
68 .ok_or(Error::MissingWitness { id })?
69 .as_ref()
70 .ok_or(Error::MissingWitness { id })?;
71 Ok(entry.clone())
72 }
73
74 pub fn has(&self, id: OracleId) -> bool {
76 self.entries.get(id).is_some_and(Option::is_some)
77 }
78
79 pub fn update_multilin_poly(
80 &mut self,
81 witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, PackedType<U, FW>>)>,
82 ) -> Result<(), Error> {
83 for (id, witness) in witnesses {
84 if id >= self.entries.len() {
85 self.entries.resize_with(id + 1, || None);
86 }
87 self.entries[id] = Some(witness);
88 }
89 Ok(())
90 }
91
92 pub fn get<FS>(
94 &self,
95 id: OracleId,
96 ) -> Result<MultilinearExtensionBorrowed<PackedType<U, FS>>, Error>
97 where
98 FS: TowerField,
99 FW: ExtensionField<FS>,
100 U: PackScalar<FS>,
101 {
102 let entry = self
103 .entries
104 .get(id)
105 .ok_or(Error::MissingWitness { id })?
106 .as_ref()
107 .ok_or(Error::MissingWitness { id })?;
108
109 if entry.log_extension_degree() != FW::LOG_DEGREE {
110 bail!(Error::OracleExtensionDegreeMismatch {
111 oracle_id: id,
112 field_log_extension_degree: FW::LOG_DEGREE,
113 entry_log_extension_degree: entry.log_extension_degree()
114 })
115 }
116
117 let evals = entry
118 .packed_evals()
119 .map(<PackedType<U, FW>>::cast_bases)
120 .ok_or(Error::NoExplicitBackingMultilinearExtension { id })?;
121
122 Ok(MultilinearExtension::from_values_slice(evals)?)
123 }
124}