1use std::fmt::Debug;
4
5use binius_field::{Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly};
7use binius_utils::bail;
8
9use crate::polynomial::Error;
10
/// A composition polynomial over `n_vars` variables that evaluates an inner composition `C`
/// on `N` of them.
///
/// Inner variable `i` is bound to outer variable `indices[i]`: evaluation gathers
/// `query[indices[i]]` for each `i` and passes those `N` values to the inner composition.
#[derive(Clone, Debug)]
pub struct IndexComposition<C, const N: usize> {
    /// Number of variables in the outer query.
    n_vars: usize,
    /// For each inner variable, its position in the outer query (each entry < `n_vars` when
    /// built through `new` or `index_composition`).
    indices: [usize; N],
    /// The inner composition, evaluated on the `N` selected values.
    composition: C,
}
22
23impl<C, const N: usize> IndexComposition<C, N> {
24 pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
25 if indices.iter().any(|&index| index >= n_vars) {
26 bail!(Error::IndexCompositionIndicesOutOfBounds);
27 }
28
29 Ok(Self {
30 n_vars,
31 indices,
32 composition,
33 })
34 }
35}
36
37impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
38 for IndexComposition<C, N>
39{
40 fn n_vars(&self) -> usize {
41 self.n_vars
42 }
43
44 fn degree(&self) -> usize {
45 self.composition.degree()
46 }
47
48 fn expression(&self) -> ArithExpr<<P as PackedField>::Scalar> {
49 fn map_variables<F: Field, const M: usize>(
50 index_map: &[usize; M],
51 expr: &ArithExpr<F>,
52 ) -> ArithExpr<F> {
53 match expr {
54 ArithExpr::Var(i) => ArithExpr::Var(index_map[*i]),
55 ArithExpr::Const(c) => ArithExpr::Const(*c),
56 ArithExpr::Add(a, b) => ArithExpr::Add(
57 Box::new(map_variables(index_map, a)),
58 Box::new(map_variables(index_map, b)),
59 ),
60 ArithExpr::Mul(a, b) => ArithExpr::Mul(
61 Box::new(map_variables(index_map, a)),
62 Box::new(map_variables(index_map, b)),
63 ),
64 ArithExpr::Pow(a, n) => ArithExpr::Pow(Box::new(map_variables(index_map, a)), *n),
65 }
66 }
67
68 map_variables(&self.indices, &self.composition.expression())
69 }
70
71 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
72 if query.len() != self.n_vars {
73 bail!(binius_math::Error::IncorrectQuerySize {
74 expected: self.n_vars,
75 });
76 }
77
78 let subquery = self.indices.map(|index| query[index]);
79 self.composition.evaluate(&subquery)
80 }
81
82 fn binary_tower_level(&self) -> usize {
83 self.composition.binary_tower_level()
84 }
85
86 fn batch_evaluate(
87 &self,
88 batch_query: &[&[P]],
89 evals: &mut [P],
90 ) -> Result<(), binius_math::Error> {
91 let batch_subquery = self.indices.map(|index| batch_query[index]);
92 self.composition
93 .batch_evaluate(batch_subquery.as_slice(), evals)
94 }
95}
96
97pub fn index_composition<E, C, const N: usize>(
103 superset: &[E],
104 subset: [E; N],
105 composition: C,
106) -> Result<IndexComposition<C, N>, Error>
107where
108 E: PartialEq,
109{
110 let n_vars = superset.len();
111
112 let proper_subset = subset.iter().all(|subset_item| {
114 superset
115 .iter()
116 .any(|superset_item| superset_item == subset_item)
117 });
118
119 if !proper_subset {
120 bail!(Error::MixedMultilinearNotFound);
121 }
122
123 let indices = subset.map(|subset_item| {
124 superset
125 .iter()
126 .position(|superset_item| superset_item == &subset_item)
127 .expect("Superset condition checked above.")
128 });
129
130 Ok(IndexComposition {
131 n_vars,
132 indices,
133 composition,
134 })
135}
136
137#[derive(Debug)]
138pub enum FixedDimIndexCompositions<C> {
139 Quadrivariate(IndexComposition<C, 4>),
140 Trivariate(IndexComposition<C, 3>),
141 Bivariate(IndexComposition<C, 2>),
142}
143
144impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
145 for FixedDimIndexCompositions<C>
146{
147 fn n_vars(&self) -> usize {
148 match self {
149 Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
150 Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
151 Self::Quadrivariate(index_composition) => {
152 CompositionPoly::<P>::n_vars(index_composition)
153 }
154 }
155 }
156
157 fn degree(&self) -> usize {
158 match self {
159 Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
160 Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
161 Self::Quadrivariate(index_composition) => {
162 CompositionPoly::<P>::degree(index_composition)
163 }
164 }
165 }
166
167 fn binary_tower_level(&self) -> usize {
168 match self {
169 Self::Trivariate(index_composition) => {
170 CompositionPoly::<P>::binary_tower_level(index_composition)
171 }
172 Self::Bivariate(index_composition) => {
173 CompositionPoly::<P>::binary_tower_level(index_composition)
174 }
175 Self::Quadrivariate(index_composition) => {
176 CompositionPoly::<P>::binary_tower_level(index_composition)
177 }
178 }
179 }
180
181 fn expression(&self) -> ArithExpr<P::Scalar> {
182 match self {
183 Self::Trivariate(index_composition) => {
184 CompositionPoly::<P>::expression(index_composition)
185 }
186 Self::Bivariate(index_composition) => {
187 CompositionPoly::<P>::expression(index_composition)
188 }
189 Self::Quadrivariate(index_composition) => {
190 CompositionPoly::<P>::expression(index_composition)
191 }
192 }
193 }
194
195 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
196 match self {
197 Self::Trivariate(index_composition) => index_composition.evaluate(query),
198 Self::Bivariate(index_composition) => index_composition.evaluate(query),
199 Self::Quadrivariate(index_composition) => index_composition.evaluate(query),
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use binius_field::BinaryField1b;

    use super::*;
    use crate::polynomial::ArithCircuitPoly;

    #[test]
    fn tests_expr() {
        let one = || Box::new(ArithExpr::Const(BinaryField1b::ONE));

        // Inner expression over two variables: x0 + x1 * 1.
        let inner_expr = ArithExpr::Add(
            Box::new(ArithExpr::Var(0)),
            Box::new(ArithExpr::Mul(Box::new(ArithExpr::Var(1)), one())),
        );

        // Bind inner variables 0 and 1 to outer variables 1 and 2 of a 3-variable query.
        let composition = IndexComposition {
            n_vars: 3,
            indices: [1, 2],
            composition: ArithCircuitPoly::new(inner_expr),
        };

        let actual = (&composition as &dyn CompositionPoly<BinaryField1b>).expression();
        let expected = ArithExpr::Add(
            Box::new(ArithExpr::Var(1)),
            Box::new(ArithExpr::Mul(Box::new(ArithExpr::Var(2)), one())),
        );
        assert_eq!(actual, expected);
    }
}