binius_core/composition/
index.rs1use std::fmt::Debug;
4
5use binius_field::PackedField;
6use binius_math::{ArithExpr, CompositionPoly, RowsBatchRef};
7use binius_utils::bail;
8
9use crate::polynomial::Error;
10
/// Adapter that evaluates an `N`-variate inner composition over a larger, `n_vars`-variate
/// query by picking out only the query entries listed in `indices`.
///
/// Construct it with [`IndexComposition::new`] or the [`index_composition`] factory.
#[derive(Clone, Debug)]
pub struct IndexComposition<C, const N: usize> {
    /// Number of variables in the outer (larger) query.
    n_vars: usize,
    /// For each inner variable `i`, the position `indices[i]` of that variable in the outer query.
    indices: [usize; N],
    /// The inner composition, evaluated on the `N` selected query entries.
    composition: C,
}
22
23impl<C, const N: usize> IndexComposition<C, N> {
24 pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
25 if indices.iter().any(|&index| index >= n_vars) {
26 bail!(Error::IndexCompositionIndicesOutOfBounds);
27 }
28
29 Ok(Self {
30 n_vars,
31 indices,
32 composition,
33 })
34 }
35}
36
37impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
38 for IndexComposition<C, N>
39{
40 fn n_vars(&self) -> usize {
41 self.n_vars
42 }
43
44 fn degree(&self) -> usize {
45 self.composition.degree()
46 }
47
48 fn expression(&self) -> ArithExpr<<P as PackedField>::Scalar> {
49 self.composition
50 .expression()
51 .remap_vars(&self.indices)
52 .expect("remapping must be valid")
53 }
54
55 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
56 if query.len() != self.n_vars {
57 bail!(binius_math::Error::IncorrectQuerySize {
58 expected: self.n_vars,
59 });
60 }
61
62 let subquery = self.indices.map(|index| query[index]);
63 self.composition.evaluate(&subquery)
64 }
65
66 fn binary_tower_level(&self) -> usize {
67 self.composition.binary_tower_level()
68 }
69
70 fn batch_evaluate(
71 &self,
72 batch_query: &RowsBatchRef<P>,
73 evals: &mut [P],
74 ) -> Result<(), binius_math::Error> {
75 let batch_subquery = batch_query.map(self.indices);
76 self.composition
77 .batch_evaluate(&batch_subquery.get_ref(), evals)
78 }
79}
80
81pub fn index_composition<E, C, const N: usize>(
87 superset: &[E],
88 subset: [E; N],
89 composition: C,
90) -> Result<IndexComposition<C, N>, Error>
91where
92 E: PartialEq,
93{
94 let n_vars = superset.len();
95
96 let proper_subset = subset.iter().all(|subset_item| {
98 superset
99 .iter()
100 .any(|superset_item| superset_item == subset_item)
101 });
102
103 if !proper_subset {
104 bail!(Error::MixedMultilinearNotFound);
105 }
106
107 let indices = subset.map(|subset_item| {
108 superset
109 .iter()
110 .position(|superset_item| superset_item == &subset_item)
111 .expect("Superset condition checked above.")
112 });
113
114 Ok(IndexComposition {
115 n_vars,
116 indices,
117 composition,
118 })
119}
120
121#[derive(Debug)]
122pub enum FixedDimIndexCompositions<C> {
123 Trivariate(IndexComposition<C, 3>),
124 Bivariate(IndexComposition<C, 2>),
125}
126
127impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
128 for FixedDimIndexCompositions<C>
129{
130 fn n_vars(&self) -> usize {
131 match self {
132 Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
133 Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
134 }
135 }
136
137 fn degree(&self) -> usize {
138 match self {
139 Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
140 Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
141 }
142 }
143
144 fn binary_tower_level(&self) -> usize {
145 match self {
146 Self::Trivariate(index_composition) => {
147 CompositionPoly::<P>::binary_tower_level(index_composition)
148 }
149 Self::Bivariate(index_composition) => {
150 CompositionPoly::<P>::binary_tower_level(index_composition)
151 }
152 }
153 }
154
155 fn expression(&self) -> ArithExpr<P::Scalar> {
156 match self {
157 Self::Trivariate(index_composition) => {
158 CompositionPoly::<P>::expression(index_composition)
159 }
160 Self::Bivariate(index_composition) => {
161 CompositionPoly::<P>::expression(index_composition)
162 }
163 }
164 }
165
166 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
167 match self {
168 Self::Trivariate(index_composition) => index_composition.evaluate(query),
169 Self::Bivariate(index_composition) => index_composition.evaluate(query),
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use binius_field::{BinaryField1b, Field};

    use super::*;
    use crate::polynomial::ArithCircuitPoly;

    /// `x0 * (x1 + 1)` over GF(2).
    fn x0_times_x1_plus_one() -> ArithExpr<BinaryField1b> {
        ArithExpr::Mul(
            Arc::new(ArithExpr::Var(0)),
            Arc::new(ArithExpr::Add(
                Arc::new(ArithExpr::Var(1)),
                Arc::new(ArithExpr::Const(BinaryField1b::ONE)),
            )),
        )
    }

    #[test]
    fn tests_expr() {
        let circuit = ArithCircuitPoly::new(&x0_times_x1_plus_one());
        let composition = IndexComposition::new(3, [1, 2], circuit).unwrap();

        assert_eq!(
            (&composition as &dyn CompositionPoly<BinaryField1b>).expression(),
            ArithExpr::Mul(
                Arc::new(ArithExpr::Var(1)),
                Arc::new(ArithExpr::Add(
                    Arc::new(ArithExpr::Var(2)),
                    Arc::new(ArithExpr::Const(BinaryField1b::ONE)),
                )),
            )
        );
    }

    #[test]
    fn test_new_rejects_out_of_bounds_indices() {
        let circuit = ArithCircuitPoly::new(&x0_times_x1_plus_one());
        assert!(IndexComposition::new(2, [1, 2], circuit).is_err());
    }

    #[test]
    fn test_evaluate_selects_indexed_variables() {
        let circuit = ArithCircuitPoly::new(&x0_times_x1_plus_one());
        let composition = IndexComposition::new(3, [1, 2], circuit).unwrap();
        let (zero, one) = (BinaryField1b::ZERO, BinaryField1b::ONE);
        let eval = |query: [BinaryField1b; 3]| {
            CompositionPoly::<BinaryField1b>::evaluate(&composition, &query).unwrap()
        };

        // Only query[1] and query[2] feed the circuit, so query[0] must not matter.
        for ignored in [zero, one] {
            assert_eq!(eval([ignored, one, zero]), one);
            assert_eq!(eval([ignored, one, one]), zero);
            assert_eq!(eval([ignored, zero, zero]), zero);
        }

        // A query whose length differs from n_vars is rejected.
        assert!(CompositionPoly::<BinaryField1b>::evaluate(&composition, &[zero, one]).is_err());
    }

    #[test]
    fn test_index_composition_resolves_positions() {
        let superset = ["a", "b", "c"];

        let circuit = ArithCircuitPoly::new(&x0_times_x1_plus_one());
        let composition = index_composition(&superset, ["c", "a"], circuit).unwrap();
        assert_eq!(composition.n_vars, 3);
        assert_eq!(composition.indices, [2, 0]);

        let circuit = ArithCircuitPoly::new(&x0_times_x1_plus_one());
        assert!(index_composition(&superset, ["a", "z"], circuit).is_err());
    }
}