binius_core/
witness.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{fmt::Debug, sync::Arc};
4
5use binius_field::{ExtensionField, PackedExtension, PackedField, TowerField};
6use binius_math::{MultilinearExtension, MultilinearExtensionBorrowed, MultilinearPoly};
7use binius_utils::bail;
8
9use crate::{oracle::OracleId, polynomial::Error as PolynomialError};
10
11pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;
12
13#[derive(Clone, Debug)]
14pub struct IndexEntry<'a, P: PackedField> {
15	pub multilin_poly: MultilinearWitness<'a, P>,
16	pub nonzero_scalars_prefix: usize,
17}
18
19/// Data structure that indexes multilinear extensions by oracle ID.
20///
21/// A [`crate::oracle::MultilinearOracleSet`] indexes multilinear polynomial oracles by assigning
22/// unique, sequential oracle IDs. The caller can get the [`MultilinearExtension`] defined natively
23/// over a subfield. This is possible because the [`MultilinearExtensionIndex::get`] method is
24/// generic over the subfield type and the struct itself only stores the underlying data.
25#[derive(Default, Debug)]
26pub struct MultilinearExtensionIndex<'a, P>
27where
28	P: PackedField,
29{
30	entries: Vec<Option<IndexEntry<'a, P>>>,
31}
32
33#[derive(Debug, thiserror::Error)]
34pub enum Error {
35	#[error("witness not found for oracle {id}")]
36	MissingWitness { id: OracleId },
37	#[error("witness for oracle id {id} does not have an explicit backing multilinear")]
38	NoExplicitBackingMultilinearExtension { id: OracleId },
39	#[error("log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}")]
40	OracleExtensionDegreeMismatch {
41		oracle_id: OracleId,
42		field_log_extension_degree: usize,
43		entry_log_extension_degree: usize,
44	},
45	#[error("polynomial error: {0}")]
46	Polynomial(#[from] PolynomialError),
47	#[error("HAL error: {0}")]
48	HalError(#[from] binius_hal::Error),
49	#[error("Math error: {0}")]
50	MathError(#[from] binius_math::Error),
51}
52
53impl<'a, P> MultilinearExtensionIndex<'a, P>
54where
55	P: PackedField,
56{
57	pub fn new() -> Self {
58		Self::default()
59	}
60
61	pub fn get_index_entry(&self, id: OracleId) -> Result<IndexEntry<'a, P>, Error> {
62		let entry = self
63			.entries
64			.get(id)
65			.ok_or(Error::MissingWitness { id })?
66			.as_ref()
67			.ok_or(Error::MissingWitness { id })?;
68		Ok(entry.clone())
69	}
70
71	pub fn get_multilin_poly(&self, id: OracleId) -> Result<MultilinearWitness<'a, P>, Error> {
72		Ok(self.get_index_entry(id)?.multilin_poly)
73	}
74
75	/// Whether has data for the given oracle id.
76	pub fn has(&self, id: OracleId) -> bool {
77		self.entries.get(id).is_some_and(Option::is_some)
78	}
79
80	pub fn update_multilin_poly(
81		&mut self,
82		witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>)>,
83	) -> Result<(), Error> {
84		self.update_multilin_poly_with_nonzero_scalars_prefixes(witnesses.into_iter().map(
85			|(id, multilin_poly)| {
86				let nonzero_scalars_prefix = 1 << multilin_poly.n_vars();
87				(id, multilin_poly, nonzero_scalars_prefix)
88			},
89		))
90	}
91
92	pub fn update_multilin_poly_with_nonzero_scalars_prefixes(
93		&mut self,
94		witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, P>, usize)>,
95	) -> Result<(), Error> {
96		for (id, multilin_poly, nonzero_scalars_prefix) in witnesses {
97			if id >= self.entries.len() {
98				self.entries.resize_with(id + 1, || None);
99			}
100			// TODO: validate nonzero_scalars_prefix
101			self.entries[id] = Some(IndexEntry {
102				multilin_poly,
103				nonzero_scalars_prefix,
104			});
105		}
106		Ok(())
107	}
108
109	/// TODO: Remove once PCS no longer needs this
110	pub fn get<PS>(&self, id: OracleId) -> Result<MultilinearExtensionBorrowed<PS>, Error>
111	where
112		PS: PackedField,
113		P: PackedExtension<PS::Scalar, PackedSubfield = PS>,
114		PS::Scalar: TowerField,
115		P::Scalar: ExtensionField<PS::Scalar>,
116	{
117		let entry = self
118			.entries
119			.get(id)
120			.ok_or(Error::MissingWitness { id })?
121			.as_ref()
122			.ok_or(Error::MissingWitness { id })?;
123
124		let log_extension_degree = entry.multilin_poly.log_extension_degree();
125		if log_extension_degree != PS::Scalar::LOG_DEGREE {
126			bail!(Error::OracleExtensionDegreeMismatch {
127				oracle_id: id,
128				field_log_extension_degree: PS::Scalar::LOG_DEGREE,
129				entry_log_extension_degree: log_extension_degree
130			})
131		}
132
133		let evals = entry
134			.multilin_poly
135			.packed_evals()
136			.map(P::cast_bases)
137			.ok_or(Error::NoExplicitBackingMultilinearExtension { id })?;
138
139		Ok(MultilinearExtension::from_values_slice(evals)?)
140	}
141}