binius_core/
witness.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{fmt::Debug, sync::Arc};
4
5use binius_field::{
6	as_packed_field::{PackScalar, PackedType},
7	underlier::UnderlierType,
8	ExtensionField, Field, PackedExtension, TowerField,
9};
10use binius_math::{MultilinearExtension, MultilinearExtensionBorrowed, MultilinearPoly};
11use binius_utils::bail;
12
13use crate::{oracle::OracleId, polynomial::Error as PolynomialError};
14
15pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;
16
17/// Data structure that indexes multilinear extensions by oracle ID.
18///
19/// A [`crate::oracle::MultilinearOracleSet`] indexes multilinear polynomial oracles by assigning
20/// unique, sequential oracle IDs. The caller can get the [`MultilinearExtension`] defined natively
21/// over a subfield. This is possible because the [`MultilinearExtensionIndex::get`] method is
22/// generic over the subfield type and the struct itself only stores the underlying data.
23#[derive(Default, Debug)]
24pub struct MultilinearExtensionIndex<'a, U, FW>
25where
26	U: UnderlierType + PackScalar<FW>,
27	FW: Field,
28{
29	entries: Vec<Option<MultilinearWitness<'a, PackedType<U, FW>>>>,
30}
31
32#[derive(Debug, thiserror::Error)]
33pub enum Error {
34	#[error("witness not found for oracle {id}")]
35	MissingWitness { id: OracleId },
36	#[error("witness for oracle id {id} does not have an explicit backing multilinear")]
37	NoExplicitBackingMultilinearExtension { id: OracleId },
38	#[error("log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}")]
39	OracleExtensionDegreeMismatch {
40		oracle_id: OracleId,
41		field_log_extension_degree: usize,
42		entry_log_extension_degree: usize,
43	},
44	#[error("polynomial error: {0}")]
45	Polynomial(#[from] PolynomialError),
46	#[error("HAL error: {0}")]
47	HalError(#[from] binius_hal::Error),
48	#[error("Math error: {0}")]
49	MathError(#[from] binius_math::Error),
50}
51
52impl<'a, U, FW> MultilinearExtensionIndex<'a, U, FW>
53where
54	U: UnderlierType + PackScalar<FW>,
55	FW: Field,
56{
57	pub fn new() -> Self {
58		Self::default()
59	}
60
61	pub fn get_multilin_poly(
62		&self,
63		id: OracleId,
64	) -> Result<MultilinearWitness<'a, PackedType<U, FW>>, Error> {
65		let entry = self
66			.entries
67			.get(id)
68			.ok_or(Error::MissingWitness { id })?
69			.as_ref()
70			.ok_or(Error::MissingWitness { id })?;
71		Ok(entry.clone())
72	}
73
74	/// Whether has data for the given oracle id.
75	pub fn has(&self, id: OracleId) -> bool {
76		self.entries.get(id).is_some_and(Option::is_some)
77	}
78
79	pub fn update_multilin_poly(
80		&mut self,
81		witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, PackedType<U, FW>>)>,
82	) -> Result<(), Error> {
83		for (id, witness) in witnesses {
84			if id >= self.entries.len() {
85				self.entries.resize_with(id + 1, || None);
86			}
87			self.entries[id] = Some(witness);
88		}
89		Ok(())
90	}
91
92	/// TODO: Remove once PCS no longer needs this
93	pub fn get<FS>(
94		&self,
95		id: OracleId,
96	) -> Result<MultilinearExtensionBorrowed<PackedType<U, FS>>, Error>
97	where
98		FS: TowerField,
99		FW: ExtensionField<FS>,
100		U: PackScalar<FS>,
101	{
102		let entry = self
103			.entries
104			.get(id)
105			.ok_or(Error::MissingWitness { id })?
106			.as_ref()
107			.ok_or(Error::MissingWitness { id })?;
108
109		if entry.log_extension_degree() != FW::LOG_DEGREE {
110			bail!(Error::OracleExtensionDegreeMismatch {
111				oracle_id: id,
112				field_log_extension_degree: FW::LOG_DEGREE,
113				entry_log_extension_degree: entry.log_extension_degree()
114			})
115		}
116
117		let evals = entry
118			.packed_evals()
119			.map(<PackedType<U, FW>>::cast_bases)
120			.ok_or(Error::NoExplicitBackingMultilinearExtension { id })?;
121
122		Ok(MultilinearExtension::from_values_slice(evals)?)
123	}
124}