binius_core/oracle/
composite.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::sync::Arc;
4
5use binius_field::TowerField;
6use binius_math::CompositionPoly;
7use binius_utils::bail;
8
9use crate::oracle::{Error, MultilinearPolyOracle, OracleId};
10
11#[derive(Debug, Clone)]
12pub struct CompositePolyOracle<F: TowerField> {
13	n_vars: usize,
14	inner: Vec<MultilinearPolyOracle<F>>,
15	composition: Arc<dyn CompositionPoly<F>>,
16}
17
18impl<F: TowerField> CompositePolyOracle<F> {
19	pub fn new<C: CompositionPoly<F> + 'static>(
20		n_vars: usize,
21		inner: Vec<MultilinearPolyOracle<F>>,
22		composition: C,
23	) -> Result<Self, Error> {
24		if inner.len() != composition.n_vars() {
25			bail!(Error::CompositionMismatch);
26		}
27		for poly in &inner {
28			if poly.n_vars() != n_vars {
29				bail!(Error::IncorrectNumberOfVariables { expected: n_vars });
30			}
31		}
32		Ok(Self {
33			n_vars,
34			inner,
35			composition: Arc::new(composition),
36		})
37	}
38
39	pub fn max_individual_degree(&self) -> usize {
40		// Maximum individual degree of the multilinear composite equals composition degree
41		self.composition.degree()
42	}
43
44	pub fn n_multilinears(&self) -> usize {
45		self.composition.n_vars()
46	}
47
48	pub fn binary_tower_level(&self) -> usize {
49		self.composition.binary_tower_level().max(
50			self.inner
51				.iter()
52				.map(MultilinearPolyOracle::binary_tower_level)
53				.max()
54				.unwrap_or(0),
55		)
56	}
57
58	pub const fn n_vars(&self) -> usize {
59		self.n_vars
60	}
61
62	pub fn inner_polys_oracle_ids(&self) -> impl Iterator<Item = OracleId> + '_ {
63		self.inner.iter().map(|oracle| oracle.id())
64	}
65
66	pub fn inner_polys(&self) -> Vec<MultilinearPolyOracle<F>> {
67		self.inner.clone()
68	}
69
70	pub fn composition(&self) -> Arc<dyn CompositionPoly<F>> {
71		self.composition.clone()
72	}
73}
74
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField2b, BinaryField32b, BinaryField8b, TowerField};
	use binius_math::ArithExpr;

	use super::*;
	use crate::oracle::MultilinearOracleSet;

	/// Test composition over three variables: `x0 * x1 + x2 * 125`.
	///
	/// It reports the 8-bit tower level so the tests can check how the composition's level
	/// combines with the levels of the inner oracles.
	#[derive(Clone, Debug)]
	struct TestByteComposition;

	impl CompositionPoly<BinaryField128b> for TestByteComposition {
		fn n_vars(&self) -> usize {
			3
		}

		fn degree(&self) -> usize {
			// The `x0 * x1` term makes the total degree 2, matching `expression` and `evaluate`.
			2
		}

		fn expression(&self) -> ArithExpr<BinaryField128b> {
			ArithExpr::Add(
				Box::new(ArithExpr::Mul(Box::new(ArithExpr::Var(0)), Box::new(ArithExpr::Var(1)))),
				Box::new(ArithExpr::Mul(
					Box::new(ArithExpr::Var(2)),
					Box::new(ArithExpr::Const(BinaryField128b::new(125))),
				)),
			)
		}

		fn evaluate(
			&self,
			query: &[BinaryField128b],
		) -> Result<BinaryField128b, binius_math::Error> {
			Ok(query[0] * query[1] + query[2] * BinaryField128b::new(125))
		}

		fn binary_tower_level(&self) -> usize {
			BinaryField8b::TOWER_LEVEL
		}
	}

	/// The composite's tower level is the maximum of the composition's level and the levels
	/// of all inner oracles.
	#[test]
	fn test_composite_tower_level() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly_2b = oracles.add_committed(n_vars, BinaryField2b::TOWER_LEVEL);
		let poly_8b = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);
		let poly_32b = oracles.add_committed(n_vars, BinaryField32b::TOWER_LEVEL);

		let composition = TestByteComposition;

		// All inner oracles sit below the composition's level: the composition dominates.
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_2b),
				oracles.oracle(poly_2b),
			],
			composition.clone(),
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField8b::TOWER_LEVEL);

		// Inner oracles reach, but do not exceed, the composition's level.
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_8b),
				oracles.oracle(poly_8b),
			],
			composition.clone(),
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField8b::TOWER_LEVEL);

		// One inner oracle exceeds the composition's level: that oracle dominates.
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_8b),
				oracles.oracle(poly_32b),
			],
			composition,
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField32b::TOWER_LEVEL);
	}
}