binius_core/oracle/
composite.rs1use std::sync::Arc;
4
5use binius_field::TowerField;
6use binius_math::CompositionPoly;
7use binius_utils::bail;
8
9use crate::oracle::{Error, MultilinearPolyOracle, OracleId};
10
11#[derive(Debug, Clone)]
12pub struct CompositePolyOracle<F: TowerField> {
13 n_vars: usize,
14 inner: Vec<MultilinearPolyOracle<F>>,
15 composition: Arc<dyn CompositionPoly<F>>,
16}
17
18impl<F: TowerField> CompositePolyOracle<F> {
19 pub fn new<C: CompositionPoly<F> + 'static>(
20 n_vars: usize,
21 inner: Vec<MultilinearPolyOracle<F>>,
22 composition: C,
23 ) -> Result<Self, Error> {
24 if inner.len() != composition.n_vars() {
25 bail!(Error::CompositionMismatch);
26 }
27 for poly in &inner {
28 if poly.n_vars() != n_vars {
29 bail!(Error::IncorrectNumberOfVariables { expected: n_vars });
30 }
31 }
32 Ok(Self {
33 n_vars,
34 inner,
35 composition: Arc::new(composition),
36 })
37 }
38
39 pub fn max_individual_degree(&self) -> usize {
40 self.composition.degree()
42 }
43
44 pub fn n_multilinears(&self) -> usize {
45 self.composition.n_vars()
46 }
47
48 pub fn binary_tower_level(&self) -> usize {
49 self.composition.binary_tower_level().max(
50 self.inner
51 .iter()
52 .map(MultilinearPolyOracle::binary_tower_level)
53 .max()
54 .unwrap_or(0),
55 )
56 }
57
58 pub const fn n_vars(&self) -> usize {
59 self.n_vars
60 }
61
62 pub fn inner_polys_oracle_ids(&self) -> impl Iterator<Item = OracleId> + '_ {
63 self.inner.iter().map(|oracle| oracle.id())
64 }
65
66 pub fn inner_polys(&self) -> Vec<MultilinearPolyOracle<F>> {
67 self.inner.clone()
68 }
69
70 pub fn composition(&self) -> Arc<dyn CompositionPoly<F>> {
71 self.composition.clone()
72 }
73}
74
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField2b, BinaryField32b, BinaryField8b, TowerField};
	use binius_math::ArithExpr;

	use super::*;
	use crate::oracle::MultilinearOracleSet;

	/// Test composition `x0 * x1 + 125 * x2` whose coefficients fit in an 8-bit tower field.
	#[derive(Clone, Debug)]
	struct TestByteComposition;
	impl CompositionPoly<BinaryField128b> for TestByteComposition {
		fn n_vars(&self) -> usize {
			3
		}

		fn degree(&self) -> usize {
			// `x0 * x1` is a degree-2 monomial, so the total degree is 2 (previously
			// misreported as 1, contradicting `expression` and `evaluate`).
			2
		}

		fn expression(&self) -> ArithExpr<BinaryField128b> {
			ArithExpr::Add(
				Box::new(ArithExpr::Mul(Box::new(ArithExpr::Var(0)), Box::new(ArithExpr::Var(1)))),
				Box::new(ArithExpr::Mul(
					Box::new(ArithExpr::Var(2)),
					Box::new(ArithExpr::Const(BinaryField128b::new(125))),
				)),
			)
		}

		fn evaluate(
			&self,
			query: &[BinaryField128b],
		) -> Result<BinaryField128b, binius_math::Error> {
			Ok(query[0] * query[1] + query[2] * BinaryField128b::new(125))
		}

		fn binary_tower_level(&self) -> usize {
			BinaryField8b::TOWER_LEVEL
		}
	}

	#[test]
	fn test_composite_tower_level() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly_2b = oracles.add_committed(n_vars, BinaryField2b::TOWER_LEVEL);
		let poly_8b = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);
		let poly_32b = oracles.add_committed(n_vars, BinaryField32b::TOWER_LEVEL);

		let composition = TestByteComposition;
		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_2b),
				oracles.oracle(poly_2b),
			],
			composition.clone(),
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField8b::TOWER_LEVEL);

		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_8b),
				oracles.oracle(poly_8b),
			],
			composition.clone(),
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField8b::TOWER_LEVEL);

		let composite = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_2b),
				oracles.oracle(poly_8b),
				oracles.oracle(poly_32b),
			],
			composition,
		)
		.unwrap();
		assert_eq!(composite.binary_tower_level(), BinaryField32b::TOWER_LEVEL);
	}

	#[test]
	fn test_composite_rejects_arity_mismatch() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly_8b = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);

		// The composition expects 3 inputs; only 2 are supplied.
		let result = CompositePolyOracle::new(
			n_vars,
			vec![oracles.oracle(poly_8b), oracles.oracle(poly_8b)],
			TestByteComposition,
		);
		assert!(matches!(result, Err(Error::CompositionMismatch)));
	}

	#[test]
	fn test_composite_rejects_n_vars_mismatch() {
		type F = BinaryField128b;

		let n_vars = 5;

		let mut oracles = MultilinearOracleSet::<F>::new();
		let poly_ok = oracles.add_committed(n_vars, BinaryField8b::TOWER_LEVEL);
		let poly_wide = oracles.add_committed(n_vars + 1, BinaryField8b::TOWER_LEVEL);

		let result = CompositePolyOracle::new(
			n_vars,
			vec![
				oracles.oracle(poly_ok),
				oracles.oracle(poly_ok),
				oracles.oracle(poly_wide),
			],
			TestByteComposition,
		);
		assert!(matches!(
			result,
			Err(Error::IncorrectNumberOfVariables { expected }) if expected == n_vars
		));
	}
}