binius_core/oracle/
composite.rs1use std::sync::Arc;
4
5use binius_field::TowerField;
6use binius_math::CompositionPoly;
7use binius_utils::bail;
8
9use crate::oracle::{Error, MultilinearPolyOracle, OracleId};
10
11#[derive(Debug, Clone)]
12pub struct CompositePolyOracle<F: TowerField> {
13 n_vars: usize,
14 inner: Vec<MultilinearPolyOracle<F>>,
15 composition: Arc<dyn CompositionPoly<F>>,
16}
17
18impl<F: TowerField> CompositePolyOracle<F> {
19 pub fn new<C: CompositionPoly<F> + 'static>(
20 n_vars: usize,
21 inner: Vec<MultilinearPolyOracle<F>>,
22 composition: C,
23 ) -> Result<Self, Error> {
24 if inner.len() != composition.n_vars() {
25 bail!(Error::CompositionMismatch);
26 }
27 for poly in &inner {
28 if poly.n_vars() != n_vars {
29 bail!(Error::IncorrectNumberOfVariables { expected: n_vars });
30 }
31 }
32 Ok(Self {
33 n_vars,
34 inner,
35 composition: Arc::new(composition),
36 })
37 }
38
39 pub fn max_individual_degree(&self) -> usize {
40 self.composition.degree()
42 }
43
44 pub fn n_multilinears(&self) -> usize {
45 self.composition.n_vars()
46 }
47
48 pub fn binary_tower_level(&self) -> usize {
49 self.composition.binary_tower_level().max(
50 self.inner
51 .iter()
52 .map(MultilinearPolyOracle::binary_tower_level)
53 .max()
54 .unwrap_or(0),
55 )
56 }
57
58 pub const fn n_vars(&self) -> usize {
59 self.n_vars
60 }
61
62 pub fn inner_polys_oracle_ids(&self) -> impl Iterator<Item = OracleId> + '_ {
63 self.inner.iter().map(|oracle| oracle.id())
64 }
65
66 pub fn inner_polys(&self) -> Vec<MultilinearPolyOracle<F>> {
67 self.inner.clone()
68 }
69
70 pub fn composition(&self) -> Arc<dyn CompositionPoly<F>> {
71 self.composition.clone()
72 }
73}
74
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField128b, BinaryField2b, BinaryField32b, BinaryField8b, TowerField};
    use binius_math::{ArithCircuit, ArithExpr};

    use super::*;
    use crate::oracle::MultilinearOracleSet;

    /// Three-variable test composition `x0 * x1 + x2 * 125` that reports the
    /// tower level of `BinaryField8b` as its own.
    #[derive(Clone, Debug)]
    struct TestByteComposition;

    impl CompositionPoly<BinaryField128b> for TestByteComposition {
        fn n_vars(&self) -> usize {
            3
        }

        // NOTE(review): the expression below contains the product `x0 * x1`,
        // yet the reported degree is 1 — confirm this is intentional for the test.
        fn degree(&self) -> usize {
            1
        }

        fn expression(&self) -> ArithCircuit<BinaryField128b> {
            (ArithExpr::Var(0) * ArithExpr::Var(1)
                + ArithExpr::Var(2) * ArithExpr::Const(BinaryField128b::new(125)))
            .into()
        }

        fn evaluate(
            &self,
            query: &[BinaryField128b],
        ) -> Result<BinaryField128b, binius_math::Error> {
            Ok(query[0] * query[1] + query[2] * BinaryField128b::new(125))
        }

        fn binary_tower_level(&self) -> usize {
            BinaryField8b::TOWER_LEVEL
        }
    }

    /// The composite's tower level is the maximum of the composition's level
    /// and the levels of its inner oracles.
    #[test]
    fn test_composite_tower_level() {
        type F = BinaryField128b;

        let n_vars = 5;

        let mut oracles = MultilinearOracleSet::<F>::new();
        let poly_2b = oracles.add_committed(n_vars, BinaryField2b::TOWER_LEVEL);
        let poly_8b = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);
        let poly_32b = oracles.add_committed(n_vars, BinaryField32b::TOWER_LEVEL);

        // Builds a composite over the three given oracles and reports its level.
        let tower_level_of = |ids: [OracleId; 3]| {
            let inner = ids
                .iter()
                .map(|&id| oracles.oracle(id))
                .collect::<Vec<_>>();
            CompositePolyOracle::new(n_vars, inner, TestByteComposition)
                .unwrap()
                .binary_tower_level()
        };

        // All inner oracles sit below the composition's level.
        assert_eq!(
            tower_level_of([poly_2b, poly_2b, poly_2b]),
            BinaryField8b::TOWER_LEVEL
        );

        // The highest inner level equals the composition's level.
        assert_eq!(
            tower_level_of([poly_2b, poly_8b, poly_8b]),
            BinaryField8b::TOWER_LEVEL
        );

        // One inner oracle exceeds the composition's level.
        assert_eq!(
            tower_level_of([poly_2b, poly_8b, poly_32b]),
            BinaryField32b::TOWER_LEVEL
        );
    }
}