binius_core/oracle/
composite.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::sync::Arc;
4
5use binius_field::TowerField;
6use binius_math::CompositionPoly;
7use binius_utils::bail;
8
9use crate::oracle::{Error, MultilinearPolyOracle, OracleId};
10
11#[derive(Debug, Clone)]
12pub struct CompositePolyOracle<F: TowerField> {
13	n_vars: usize,
14	inner: Vec<MultilinearPolyOracle<F>>,
15	composition: Arc<dyn CompositionPoly<F>>,
16}
17
18impl<F: TowerField> CompositePolyOracle<F> {
19	pub fn new<C: CompositionPoly<F> + 'static>(
20		n_vars: usize,
21		inner: Vec<MultilinearPolyOracle<F>>,
22		composition: C,
23	) -> Result<Self, Error> {
24		if inner.len() != composition.n_vars() {
25			bail!(Error::CompositionMismatch);
26		}
27		for poly in &inner {
28			if poly.n_vars() != n_vars {
29				bail!(Error::IncorrectNumberOfVariables { expected: n_vars });
30			}
31		}
32		Ok(Self {
33			n_vars,
34			inner,
35			composition: Arc::new(composition),
36		})
37	}
38
39	pub fn max_individual_degree(&self) -> usize {
40		// Maximum individual degree of the multilinear composite equals composition degree
41		self.composition.degree()
42	}
43
44	pub fn n_multilinears(&self) -> usize {
45		self.composition.n_vars()
46	}
47
48	pub fn binary_tower_level(&self) -> usize {
49		self.composition.binary_tower_level().max(
50			self.inner
51				.iter()
52				.map(MultilinearPolyOracle::binary_tower_level)
53				.max()
54				.unwrap_or(0),
55		)
56	}
57
58	pub const fn n_vars(&self) -> usize {
59		self.n_vars
60	}
61
62	pub fn inner_polys_oracle_ids(&self) -> impl Iterator<Item = OracleId> + '_ {
63		self.inner.iter().map(|oracle| oracle.id())
64	}
65
66	pub fn inner_polys(&self) -> Vec<MultilinearPolyOracle<F>> {
67		self.inner.clone()
68	}
69
70	pub fn composition(&self) -> Arc<dyn CompositionPoly<F>> {
71		self.composition.clone()
72	}
73}
74
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField2b, BinaryField32b, BinaryField8b, TowerField};
	use binius_math::{ArithCircuit, ArithExpr};

	use super::*;
	use crate::oracle::MultilinearOracleSet;

	/// Test composition `x0 * x1 + 125 * x2` over three variables.
	///
	/// Its only non-trivial constant (125) fits in a byte, so it reports the 8-bit tower
	/// level. The `x0 * x1` term makes its total degree 2.
	#[derive(Clone, Debug)]
	struct TestByteComposition;
	impl CompositionPoly<BinaryField128b> for TestByteComposition {
		fn n_vars(&self) -> usize {
			3
		}

		fn degree(&self) -> usize {
			// `x0 * x1` is a degree-2 monomial; the previous value of 1 disagreed with
			// `expression()` and `evaluate()`.
			2
		}

		fn expression(&self) -> ArithCircuit<BinaryField128b> {
			(ArithExpr::Var(0) * ArithExpr::Var(1)
				+ ArithExpr::Var(2) * ArithExpr::Const(BinaryField128b::new(125)))
			.into()
		}

		fn evaluate(
			&self,
			query: &[BinaryField128b],
		) -> Result<BinaryField128b, binius_math::Error> {
			Ok(query[0] * query[1] + query[2] * BinaryField128b::new(125))
		}

		fn binary_tower_level(&self) -> usize {
			BinaryField8b::TOWER_LEVEL
		}
	}

	#[test]
	fn test_composite_tower_level() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly_2b = oracles.add_committed(n_vars, BinaryField2b::TOWER_LEVEL);
		let poly_8b = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);
		let poly_32b = oracles.add_committed(n_vars, BinaryField32b::TOWER_LEVEL);

		let composition = TestByteComposition;
		// Inner oracles below the composition's level: the composition dominates.
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_2b),
				oracles.oracle(poly_2b),
			],
			composition.clone(),
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField8b::TOWER_LEVEL);

		// Inner oracles at the composition's level: the level is unchanged.
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_8b),
				oracles.oracle(poly_8b),
			],
			composition.clone(),
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField8b::TOWER_LEVEL);

		// One inner oracle above the composition's level: that oracle dominates.
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_8b),
				oracles.oracle(poly_32b),
			],
			composition,
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField32b::TOWER_LEVEL);
	}

	#[test]
	fn test_composite_accessors() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);

		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly),
				oracles.oracle(poly),
				oracles.oracle(poly),
			],
			TestByteComposition,
		)
		.unwrap();

		assert_eq!(composite.n_vars(), n_vars);
		assert_eq!(composite.n_multilinears(), 3);
		assert_eq!(composite.max_individual_degree(), 2);
		assert_eq!(composite.inner_polys().len(), 3);
		assert_eq!(composite.inner_polys_oracle_ids().count(), 3);
	}

	#[test]
	fn test_composite_rejects_composition_arity_mismatch() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);

		// TestByteComposition takes three arguments, but only two oracles are supplied.
		let result = CompositePolyOracle::new(
			n_vars,
			vec![oracles.oracle(poly), oracles.oracle(poly)],
			TestByteComposition,
		);
		assert!(matches!(result, Err(Error::CompositionMismatch)));
	}

	#[test]
	fn test_composite_rejects_mismatched_n_vars() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);
		let wider_poly = oracles.add_committed(n_vars + 1, BinaryField8b::TOWER_LEVEL);

		// The arity matches, but one inner oracle has a different number of variables.
		let result = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly),
				oracles.oracle(poly),
				oracles.oracle(wider_poly),
			],
			TestByteComposition,
		);
		assert!(matches!(
			result,
			Err(Error::IncorrectNumberOfVariables { expected, .. }) if expected == n_vars
		));
	}
}
160}