binius_core/
witness.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright 2024-2025 Irreducible Inc.

use std::{fmt::Debug, sync::Arc};

use binius_field::{
	as_packed_field::{PackScalar, PackedType},
	underlier::UnderlierType,
	ExtensionField, Field, PackedExtension, TowerField,
};
use binius_math::{MultilinearExtension, MultilinearExtensionBorrowed, MultilinearPoly};
use binius_utils::bail;

use crate::{oracle::OracleId, polynomial::Error as PolynomialError};

/// Shared, type-erased multilinear polynomial witness over the packed field type `P`.
///
/// The polynomial is held behind an [`Arc`], so cloning a witness only bumps a reference count
/// rather than copying polynomial data. The trait object is `Send + Sync`, so witnesses can be
/// shared across threads, and the `'a` lifetime allows the underlying polynomial to borrow data.
pub type MultilinearWitness<'a, P> = Arc<dyn MultilinearPoly<P> + Send + Sync + 'a>;

/// Index of multilinear polynomial witnesses, keyed by oracle ID.
///
/// A [`crate::oracle::MultilinearOracleSet`] indexes multilinear polynomial oracles by assigning
/// unique, sequential oracle IDs. This structure keeps one slot per oracle ID, holding the
/// witness for that oracle or `None` if no witness has been stored for it yet.
///
/// Witnesses are stored as type-erased [`MultilinearWitness`] objects over the packed field
/// `PackedType<U, FW>`. The [`MultilinearExtensionIndex::get`] method is generic over a subfield
/// `FS` of `FW`, so a caller can recover a [`MultilinearExtension`] defined natively over that
/// subfield, provided the stored witness exposes its packed evaluations.
#[derive(Default, Debug)]
pub struct MultilinearExtensionIndex<'a, U, FW>
where
	U: UnderlierType + PackScalar<FW>,
	FW: Field,
{
	/// Slot `i` holds the witness for oracle ID `i`, or `None` if it has not been set.
	entries: Vec<Option<MultilinearWitness<'a, PackedType<U, FW>>>>,
}

/// Errors returned by witness operations, such as the [`MultilinearExtensionIndex`] accessors.
#[derive(Debug, thiserror::Error)]
pub enum Error {
	/// No witness is stored for oracle `id`.
	#[error("witness not found for oracle {id}")]
	MissingWitness { id: OracleId },
	/// The witness for oracle `id` has no explicit backing multilinear extension (for example,
	/// [`MultilinearExtensionIndex::get`] raises this when the witness's `packed_evals()`
	/// returns `None`).
	#[error("witness for oracle id {id} does not have an explicit backing multilinear")]
	NoExplicitBackingMultilinearExtension { id: OracleId },
	/// The log extension degree reported by the stored witness differs from the one the caller
	/// expected.
	#[error("log degree mismatch for oracle id {oracle_id}. field_log_extension_degree = {field_log_extension_degree} entry_log_extension_degree = {entry_log_extension_degree}")]
	OracleExtensionDegreeMismatch {
		/// Oracle whose witness was requested.
		oracle_id: OracleId,
		/// Log extension degree expected by the caller.
		field_log_extension_degree: usize,
		/// Log extension degree reported by the stored witness.
		entry_log_extension_degree: usize,
	},
	/// Wrapped polynomial error (converted automatically via `From`).
	#[error("polynomial error: {0}")]
	Polynomial(#[from] PolynomialError),
	/// Wrapped error from `binius_hal` (converted automatically via `From`).
	#[error("HAL error: {0}")]
	HalError(#[from] binius_hal::Error),
	/// Wrapped error from `binius_math` (converted automatically via `From`).
	#[error("Math error: {0}")]
	MathError(#[from] binius_math::Error),
}

impl<'a, U, FW> MultilinearExtensionIndex<'a, U, FW>
where
	U: UnderlierType + PackScalar<FW>,
	FW: Field,
{
	pub fn new() -> Self {
		Self::default()
	}

	pub fn get_multilin_poly(
		&self,
		id: OracleId,
	) -> Result<MultilinearWitness<'a, PackedType<U, FW>>, Error> {
		let entry = self
			.entries
			.get(id)
			.ok_or(Error::MissingWitness { id })?
			.as_ref()
			.ok_or(Error::MissingWitness { id })?;
		Ok(entry.clone())
	}

	/// Whether has data for the given oracle id.
	pub fn has(&self, id: OracleId) -> bool {
		self.entries.get(id).is_some_and(Option::is_some)
	}

	pub fn update_multilin_poly(
		&mut self,
		witnesses: impl IntoIterator<Item = (OracleId, MultilinearWitness<'a, PackedType<U, FW>>)>,
	) -> Result<(), Error> {
		for (id, witness) in witnesses {
			if id >= self.entries.len() {
				self.entries.resize_with(id + 1, || None);
			}
			self.entries[id] = Some(witness);
		}
		Ok(())
	}

	/// TODO: Remove once PCS no longer needs this
	pub fn get<FS>(
		&self,
		id: OracleId,
	) -> Result<MultilinearExtensionBorrowed<PackedType<U, FS>>, Error>
	where
		FS: TowerField,
		FW: ExtensionField<FS>,
		U: PackScalar<FS>,
	{
		let entry = self
			.entries
			.get(id)
			.ok_or(Error::MissingWitness { id })?
			.as_ref()
			.ok_or(Error::MissingWitness { id })?;

		if entry.log_extension_degree() != FW::LOG_DEGREE {
			bail!(Error::OracleExtensionDegreeMismatch {
				oracle_id: id,
				field_log_extension_degree: FW::LOG_DEGREE,
				entry_log_extension_degree: entry.log_extension_degree()
			})
		}

		let evals = entry
			.packed_evals()
			.map(<PackedType<U, FW>>::cast_bases)
			.ok_or(Error::NoExplicitBackingMultilinearExtension { id })?;

		Ok(MultilinearExtension::from_values_slice(evals)?)
	}
}