binius_core/
witness.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{fmt::Debug, sync::Arc};
4
5use binius_field::PackedField;
6use binius_math::MultilinearPoly;
7
8use crate::{oracle::OracleId, polynomial::Error as PolynomialError};
9
10pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;
11
12#[derive(Clone, Debug)]
13pub struct IndexEntry<'a, P: PackedField> {
14	pub multilin_poly: MultilinearWitness<'a, P>,
15	pub nonzero_scalars_prefix: usize,
16}
17
18/// Data structure that indexes multilinear extensions by oracle ID.
19///
20/// A [`crate::oracle::MultilinearOracleSet`] indexes multilinear polynomial oracles by assigning
21/// unique, sequential oracle IDs.
22#[derive(Default, Debug)]
23pub struct MultilinearExtensionIndex<'a, P>
24where
25	P: PackedField,
26{
27	entries: Vec<Option<IndexEntry<'a, P>>>,
28}
29
30#[derive(Debug, thiserror::Error)]
31pub enum Error {
32	#[error("witness not found for oracle {id}")]
33	MissingWitness { id: OracleId },
34	#[error("witness for oracle id {id} does not have an explicit backing multilinear")]
35	NoExplicitBackingMultilinearExtension { id: OracleId },
36	#[error(
37		"log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}"
38	)]
39	OracleExtensionDegreeMismatch {
40		oracle_id: OracleId,
41		field_log_extension_degree: usize,
42		entry_log_extension_degree: usize,
43	},
44	#[error("polynomial error: {0}")]
45	Polynomial(#[from] PolynomialError),
46	#[error("HAL error: {0}")]
47	HalError(#[from] binius_hal::Error),
48	#[error("Math error: {0}")]
49	MathError(#[from] binius_math::Error),
50}
51
52impl<'a, P> MultilinearExtensionIndex<'a, P>
53where
54	P: PackedField,
55{
56	pub fn new() -> Self {
57		Self::default()
58	}
59
60	pub fn get_index_entry(&self, id: OracleId) -> Result<IndexEntry<'a, P>, Error> {
61		let entry = self
62			.entries
63			.get(id.index())
64			.ok_or(Error::MissingWitness { id })?
65			.as_ref()
66			.ok_or(Error::MissingWitness { id })?;
67		Ok(entry.clone())
68	}
69
70	pub fn get_multilin_poly(&self, id: OracleId) -> Result<MultilinearWitness<'a, P>, Error> {
71		Ok(self.get_index_entry(id)?.multilin_poly)
72	}
73
74	/// Whether has data for the given oracle id.
75	pub fn has(&self, id: OracleId) -> bool {
76		self.entries.get(id.index()).is_some_and(Option::is_some)
77	}
78
79	pub fn update_multilin_poly(
80		&mut self,
81		witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>)>,
82	) -> Result<(), Error> {
83		self.update_multilin_poly_with_nonzero_scalars_prefixes(witnesses.into_iter().map(
84			|(id, multilin_poly)| {
85				let nonzero_scalars_prefix = 1 << multilin_poly.n_vars();
86				(id, multilin_poly, nonzero_scalars_prefix)
87			},
88		))
89	}
90
91	pub fn update_multilin_poly_with_nonzero_scalars_prefixes(
92		&mut self,
93		witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>, usize)>,
94	) -> Result<(), Error> {
95		for (id, multilin_poly, nonzero_scalars_prefix) in witnesses {
96			let id_index = id.index();
97			if id_index >= self.entries.len() {
98				self.entries.resize_with(id_index + 1, || None);
99			}
100			// TODO: validate nonzero_scalars_prefix
101			self.entries[id_index] = Some(IndexEntry {
102				multilin_poly,
103				nonzero_scalars_prefix,
104			});
105		}
106		Ok(())
107	}
108}