1use std::fmt::Debug;
4
5use binius_field::{Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly, RowsBatchRef};
7use binius_utils::bail;
8
9use crate::polynomial::Error;
10
/// Composition that feeds a selection of `N` variables, taken from a wider `n_vars`-variate
/// query, into an inner composition `C`.
///
/// On evaluation the inner composition receives
/// `[query[indices[0]], query[indices[1]], ..., query[indices[N - 1]]]`.
#[derive(Debug, Clone)]
pub struct IndexComposition<C, const N: usize> {
    /// Number of variables of the outer query accepted by this composition.
    n_vars: usize,
    /// For inner variable `i`, the position in the outer query that supplies its value.
    indices: [usize; N],
    /// Inner composition, evaluated on the `N` selected values.
    composition: C,
}
22
23impl<C, const N: usize> IndexComposition<C, N> {
24 pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
25 if indices.iter().any(|&index| index >= n_vars) {
26 bail!(Error::IndexCompositionIndicesOutOfBounds);
27 }
28
29 Ok(Self {
30 n_vars,
31 indices,
32 composition,
33 })
34 }
35}
36
37impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
38 for IndexComposition<C, N>
39{
40 fn n_vars(&self) -> usize {
41 self.n_vars
42 }
43
44 fn degree(&self) -> usize {
45 self.composition.degree()
46 }
47
48 fn expression(&self) -> ArithExpr<<P as PackedField>::Scalar> {
49 fn map_variables<F: Field, const M: usize>(
50 index_map: &[usize; M],
51 expr: &ArithExpr<F>,
52 ) -> ArithExpr<F> {
53 match expr {
54 ArithExpr::Var(i) => ArithExpr::Var(index_map[*i]),
55 ArithExpr::Const(c) => ArithExpr::Const(*c),
56 ArithExpr::Add(a, b) => ArithExpr::Add(
57 Box::new(map_variables(index_map, a)),
58 Box::new(map_variables(index_map, b)),
59 ),
60 ArithExpr::Mul(a, b) => ArithExpr::Mul(
61 Box::new(map_variables(index_map, a)),
62 Box::new(map_variables(index_map, b)),
63 ),
64 ArithExpr::Pow(a, n) => ArithExpr::Pow(Box::new(map_variables(index_map, a)), *n),
65 }
66 }
67
68 map_variables(&self.indices, &self.composition.expression())
69 }
70
71 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
72 if query.len() != self.n_vars {
73 bail!(binius_math::Error::IncorrectQuerySize {
74 expected: self.n_vars,
75 });
76 }
77
78 let subquery = self.indices.map(|index| query[index]);
79 self.composition.evaluate(&subquery)
80 }
81
82 fn binary_tower_level(&self) -> usize {
83 self.composition.binary_tower_level()
84 }
85
86 fn batch_evaluate(
87 &self,
88 batch_query: &RowsBatchRef<P>,
89 evals: &mut [P],
90 ) -> Result<(), binius_math::Error> {
91 let batch_subquery = batch_query.map(self.indices);
92 self.composition
93 .batch_evaluate(&batch_subquery.get_ref(), evals)
94 }
95}
96
97pub fn index_composition<E, C, const N: usize>(
103 superset: &[E],
104 subset: [E; N],
105 composition: C,
106) -> Result<IndexComposition<C, N>, Error>
107where
108 E: PartialEq,
109{
110 let n_vars = superset.len();
111
112 let proper_subset = subset.iter().all(|subset_item| {
114 superset
115 .iter()
116 .any(|superset_item| superset_item == subset_item)
117 });
118
119 if !proper_subset {
120 bail!(Error::MixedMultilinearNotFound);
121 }
122
123 let indices = subset.map(|subset_item| {
124 superset
125 .iter()
126 .position(|superset_item| superset_item == &subset_item)
127 .expect("Superset condition checked above.")
128 });
129
130 Ok(IndexComposition {
131 n_vars,
132 indices,
133 composition,
134 })
135}
136
137#[derive(Debug)]
138pub enum FixedDimIndexCompositions<C> {
139 Trivariate(IndexComposition<C, 3>),
140 Bivariate(IndexComposition<C, 2>),
141}
142
143impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
144 for FixedDimIndexCompositions<C>
145{
146 fn n_vars(&self) -> usize {
147 match self {
148 Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
149 Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
150 }
151 }
152
153 fn degree(&self) -> usize {
154 match self {
155 Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
156 Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
157 }
158 }
159
160 fn binary_tower_level(&self) -> usize {
161 match self {
162 Self::Trivariate(index_composition) => {
163 CompositionPoly::<P>::binary_tower_level(index_composition)
164 }
165 Self::Bivariate(index_composition) => {
166 CompositionPoly::<P>::binary_tower_level(index_composition)
167 }
168 }
169 }
170
171 fn expression(&self) -> ArithExpr<P::Scalar> {
172 match self {
173 Self::Trivariate(index_composition) => {
174 CompositionPoly::<P>::expression(index_composition)
175 }
176 Self::Bivariate(index_composition) => {
177 CompositionPoly::<P>::expression(index_composition)
178 }
179 }
180 }
181
182 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
183 match self {
184 Self::Trivariate(index_composition) => index_composition.evaluate(query),
185 Self::Bivariate(index_composition) => index_composition.evaluate(query),
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use binius_field::BinaryField1b;

    use super::*;
    use crate::polynomial::ArithCircuitPoly;

    /// Builds the 2-variate expression `x0 + x1 * 1` used as the inner composition.
    fn sample_expr() -> ArithExpr<BinaryField1b> {
        ArithExpr::Add(
            Box::new(ArithExpr::Var(0)),
            Box::new(ArithExpr::Mul(
                Box::new(ArithExpr::Var(1)),
                Box::new(ArithExpr::Const(BinaryField1b::ONE)),
            )),
        )
    }

    #[test]
    fn tests_expr() {
        let circuit = ArithCircuitPoly::new(sample_expr());

        let composition = IndexComposition {
            n_vars: 3,
            indices: [1, 2],
            composition: circuit,
        };

        assert_eq!(
            (&composition as &dyn CompositionPoly<BinaryField1b>).expression(),
            ArithExpr::Add(
                Box::new(ArithExpr::Var(1)),
                Box::new(ArithExpr::Mul(
                    Box::new(ArithExpr::Var(2)),
                    Box::new(ArithExpr::Const(BinaryField1b::ONE)),
                )),
            )
        );
    }

    #[test]
    fn tests_evaluate() {
        let composition = IndexComposition {
            n_vars: 3,
            indices: [1, 2],
            composition: ArithCircuitPoly::new(sample_expr()),
        };
        let poly = &composition as &dyn CompositionPoly<BinaryField1b>;

        // x0 <- query[1] = 1 and x1 <- query[2] = 0, so x0 + x1 * 1 = 1. With indices [0, 1]
        // the result would be 1 + 1 = 0, so this pins that the projection is applied.
        let query = [BinaryField1b::ONE, BinaryField1b::ONE, BinaryField1b::ZERO];
        assert_eq!(poly.evaluate(&query).ok(), Some(BinaryField1b::ONE));

        // A query whose length differs from `n_vars` is rejected.
        assert!(poly.evaluate(&query[..2]).is_err());
    }

    #[test]
    fn new_checks_index_bounds() {
        assert!(IndexComposition::new(3, [1, 2], ()).is_ok());
        assert!(IndexComposition::new(2, [1, 2], ()).is_err());
    }

    #[test]
    fn index_composition_locates_subset() {
        let superset = ["a", "b", "c"];

        let Ok(composition) = index_composition(superset.as_slice(), ["c", "a"], ()) else {
            panic!("every subset element occurs in the superset");
        };
        assert_eq!(composition.n_vars, 3);
        assert_eq!(composition.indices, [2, 0]);

        assert!(index_composition(superset.as_slice(), ["d", "a"], ()).is_err());
    }
}