binius_core/composition/
index.rs1use std::fmt::Debug;
4
5use binius_field::PackedField;
6use binius_math::{ArithCircuit, CompositionPoly, RowsBatchRef};
7use binius_utils::bail;
8use getset::Getters;
9
10use crate::polynomial::Error;
11
/// A composition that evaluates an inner `N`-variate composition on a selection of the
/// outer query's variables.
///
/// The outer composition has `n_vars` variables; inner variable `i` reads outer variable
/// `indices[i]`.
#[derive(Clone, Debug, Getters)]
pub struct IndexComposition<C, const N: usize> {
	/// Number of variables of the outer composition; `evaluate` requires queries of exactly
	/// this length.
	n_vars: usize,
	/// For each of the `N` inner variables, the position of the outer variable it reads.
	/// `new` checks that every entry is `< n_vars`.
	#[get = "pub"]
	indices: [usize; N],
	/// The inner composition, evaluated on the `N` values selected by `indices`.
	composition: C,
}
24
25impl<C, const N: usize> IndexComposition<C, N> {
26 pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
27 if indices.iter().any(|&index| index >= n_vars) {
28 bail!(Error::IndexCompositionIndicesOutOfBounds);
29 }
30
31 Ok(Self {
32 n_vars,
33 indices,
34 composition,
35 })
36 }
37}
38
39impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
40 for IndexComposition<C, N>
41{
42 fn n_vars(&self) -> usize {
43 self.n_vars
44 }
45
46 fn degree(&self) -> usize {
47 self.composition.degree()
48 }
49
50 fn expression(&self) -> ArithCircuit<<P as PackedField>::Scalar> {
51 self.composition
52 .expression()
53 .remap_vars(&self.indices)
54 .expect("remapping must be valid")
55 }
56
57 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
58 if query.len() != self.n_vars {
59 bail!(binius_math::Error::IncorrectQuerySize {
60 expected: self.n_vars,
61 actual: query.len(),
62 });
63 }
64
65 let subquery = self.indices.map(|index| query[index]);
66 self.composition.evaluate(&subquery)
67 }
68
69 fn binary_tower_level(&self) -> usize {
70 self.composition.binary_tower_level()
71 }
72
73 fn batch_evaluate(
74 &self,
75 batch_query: &RowsBatchRef<P>,
76 evals: &mut [P],
77 ) -> Result<(), binius_math::Error> {
78 let batch_subquery = batch_query.map(self.indices);
79 self.composition
80 .batch_evaluate(&batch_subquery.get_ref(), evals)
81 }
82}
83
84pub fn index_composition<E, C, const N: usize>(
91 superset: &[E],
92 subset: [E; N],
93 composition: C,
94) -> Result<IndexComposition<C, N>, Error>
95where
96 E: PartialEq,
97{
98 let n_vars = superset.len();
99
100 let proper_subset = subset.iter().all(|subset_item| {
102 superset
103 .iter()
104 .any(|superset_item| superset_item == subset_item)
105 });
106
107 if !proper_subset {
108 bail!(Error::MixedMultilinearNotFound);
109 }
110
111 let indices = subset.map(|subset_item| {
112 superset
113 .iter()
114 .position(|superset_item| superset_item == &subset_item)
115 .expect("Superset condition checked above.")
116 });
117
118 Ok(IndexComposition {
119 n_vars,
120 indices,
121 composition,
122 })
123}
124
/// An [`IndexComposition`] with either three or two inner variables.
///
/// This wraps index compositions of different arity `N` in a single type; its
/// `CompositionPoly` implementation dispatches to the wrapped variant.
#[derive(Debug)]
pub enum FixedDimIndexCompositions<C> {
	/// Index composition with three inner variables.
	Trivariate(IndexComposition<C, 3>),
	/// Index composition with two inner variables.
	Bivariate(IndexComposition<C, 2>),
}
130
131impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
132 for FixedDimIndexCompositions<C>
133{
134 fn n_vars(&self) -> usize {
135 match self {
136 Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
137 Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
138 }
139 }
140
141 fn degree(&self) -> usize {
142 match self {
143 Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
144 Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
145 }
146 }
147
148 fn binary_tower_level(&self) -> usize {
149 match self {
150 Self::Trivariate(index_composition) => {
151 CompositionPoly::<P>::binary_tower_level(index_composition)
152 }
153 Self::Bivariate(index_composition) => {
154 CompositionPoly::<P>::binary_tower_level(index_composition)
155 }
156 }
157 }
158
159 fn expression(&self) -> ArithCircuit<P::Scalar> {
160 match self {
161 Self::Trivariate(index_composition) => {
162 CompositionPoly::<P>::expression(index_composition)
163 }
164 Self::Bivariate(index_composition) => {
165 CompositionPoly::<P>::expression(index_composition)
166 }
167 }
168 }
169
170 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
171 match self {
172 Self::Trivariate(index_composition) => index_composition.evaluate(query),
173 Self::Bivariate(index_composition) => index_composition.evaluate(query),
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
	use binius_fast_compute::arith_circuit::ArithCircuitPoly;
	use binius_field::{BinaryField1b, Field};
	use binius_math::ArithExpr;

	use super::*;

	#[test]
	fn tests_expr() {
		let expr = ArithExpr::Var(0) * (ArithExpr::Var(1) + ArithExpr::Const(BinaryField1b::ONE));
		let circuit = ArithCircuitPoly::new((&expr).into());

		let composition = IndexComposition {
			n_vars: 3,
			indices: [1, 2],
			composition: circuit,
		};

		assert_eq!(
			(&composition as &dyn CompositionPoly<BinaryField1b>).expression(),
			ArithCircuit::from(
				&(ArithExpr::Var(1) * (ArithExpr::Var(2) + ArithExpr::Const(BinaryField1b::ONE)))
			),
		);
	}

	#[test]
	fn new_rejects_out_of_bounds_indices() {
		let expr = ArithExpr::Var(0) * (ArithExpr::Var(1) + ArithExpr::Const(BinaryField1b::ONE));
		let circuit = ArithCircuitPoly::new((&expr).into());

		// Index 2 is not a valid variable of a 2-variable outer composition.
		let result = IndexComposition::new(2, [1, 2], circuit);
		assert!(matches!(result, Err(Error::IndexCompositionIndicesOutOfBounds)));
	}

	#[test]
	fn index_composition_resolves_positions() {
		let expr = ArithExpr::Var(0) * (ArithExpr::Var(1) + ArithExpr::Const(BinaryField1b::ONE));
		let circuit = ArithCircuitPoly::new((&expr).into());

		let composition = index_composition(&["a", "b", "c"], ["b", "c"], circuit).unwrap();
		assert_eq!(composition.indices(), &[1, 2]);
		assert_eq!(CompositionPoly::<BinaryField1b>::n_vars(&composition), 3);
	}

	#[test]
	fn index_composition_rejects_missing_element() {
		let expr = ArithExpr::Var(0) * (ArithExpr::Var(1) + ArithExpr::Const(BinaryField1b::ONE));
		let circuit = ArithCircuitPoly::new((&expr).into());

		let result = index_composition(&["a", "b"], ["b", "z"], circuit);
		assert!(matches!(result, Err(Error::MixedMultilinearNotFound)));
	}

	#[test]
	fn evaluate_selects_indexed_values() {
		let expr = ArithExpr::Var(0) * (ArithExpr::Var(1) + ArithExpr::Const(BinaryField1b::ONE));
		let circuit = ArithCircuitPoly::new((&expr).into());
		let composition = IndexComposition::new(3, [1, 2], circuit).unwrap();

		// Inner x0 = query[1] = 1, inner x1 = query[2] = 0, so x0 * (x1 + 1) = 1.
		let query = [BinaryField1b::ZERO, BinaryField1b::ONE, BinaryField1b::ZERO];
		assert_eq!(
			CompositionPoly::<BinaryField1b>::evaluate(&composition, &query).unwrap(),
			BinaryField1b::ONE,
		);

		// A query whose length differs from `n_vars` is rejected.
		assert!(CompositionPoly::<BinaryField1b>::evaluate(&composition, &query[..2]).is_err());
	}
}