binius_core/composition/
index.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::{Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly};
7use binius_utils::bail;
8
9use crate::polynomial::Error;
10
/// An adapter which allows evaluating a composition over a larger query by indexing into it.
/// See [`index_composition`] for a factory method.
///
/// The const parameter `N` is the length of `indices`, i.e. the number of outer query entries
/// that are picked out and handed to the inner composition.
#[derive(Clone, Debug)]
pub struct IndexComposition<C, const N: usize> {
	/// Number of variables in a larger query
	n_vars: usize,
	/// Mapping from the inner composition query variables to outer query variables:
	/// entry `i` is the outer query position supplied as inner variable `i`.
	indices: [usize; N],
	/// Inner composition
	composition: C,
}
22
23impl<C, const N: usize> IndexComposition<C, N> {
24	pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
25		if indices.iter().any(|&index| index >= n_vars) {
26			bail!(Error::IndexCompositionIndicesOutOfBounds);
27		}
28
29		Ok(Self {
30			n_vars,
31			indices,
32			composition,
33		})
34	}
35}
36
37impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
38	for IndexComposition<C, N>
39{
40	fn n_vars(&self) -> usize {
41		self.n_vars
42	}
43
44	fn degree(&self) -> usize {
45		self.composition.degree()
46	}
47
48	fn expression(&self) -> ArithExpr<<P as PackedField>::Scalar> {
49		fn map_variables<F: Field, const M: usize>(
50			index_map: &[usize; M],
51			expr: &ArithExpr<F>,
52		) -> ArithExpr<F> {
53			match expr {
54				ArithExpr::Var(i) => ArithExpr::Var(index_map[*i]),
55				ArithExpr::Const(c) => ArithExpr::Const(*c),
56				ArithExpr::Add(a, b) => ArithExpr::Add(
57					Box::new(map_variables(index_map, a)),
58					Box::new(map_variables(index_map, b)),
59				),
60				ArithExpr::Mul(a, b) => ArithExpr::Mul(
61					Box::new(map_variables(index_map, a)),
62					Box::new(map_variables(index_map, b)),
63				),
64				ArithExpr::Pow(a, n) => ArithExpr::Pow(Box::new(map_variables(index_map, a)), *n),
65			}
66		}
67
68		map_variables(&self.indices, &self.composition.expression())
69	}
70
71	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
72		if query.len() != self.n_vars {
73			bail!(binius_math::Error::IncorrectQuerySize {
74				expected: self.n_vars,
75			});
76		}
77
78		let subquery = self.indices.map(|index| query[index]);
79		self.composition.evaluate(&subquery)
80	}
81
82	fn binary_tower_level(&self) -> usize {
83		self.composition.binary_tower_level()
84	}
85
86	fn batch_evaluate(
87		&self,
88		batch_query: &[&[P]],
89		evals: &mut [P],
90	) -> Result<(), binius_math::Error> {
91		let batch_subquery = self.indices.map(|index| batch_query[index]);
92		self.composition
93			.batch_evaluate(batch_subquery.as_slice(), evals)
94	}
95}
96
97/// A factory helper method to create an [`IndexComposition`] by looking at
98///  * `superset` - a set of identifiers of a greater (outer) query
99///  * `subset` - a set of identifiers of a smaller query, the one which corresponds to the inner composition directly
100///
101/// Identifiers may be anything `Eq` - `OracleId`, `MultilinearPolyOracle<F>`, etc.
102pub fn index_composition<E, C, const N: usize>(
103	superset: &[E],
104	subset: [E; N],
105	composition: C,
106) -> Result<IndexComposition<C, N>, Error>
107where
108	E: PartialEq,
109{
110	let n_vars = superset.len();
111
112	// array_try_map is unstable as of 03/24, check the condition beforehand
113	let proper_subset = subset.iter().all(|subset_item| {
114		superset
115			.iter()
116			.any(|superset_item| superset_item == subset_item)
117	});
118
119	if !proper_subset {
120		bail!(Error::MixedMultilinearNotFound);
121	}
122
123	let indices = subset.map(|subset_item| {
124		superset
125			.iter()
126			.position(|superset_item| superset_item == &subset_item)
127			.expect("Superset condition checked above.")
128	});
129
130	Ok(IndexComposition {
131		n_vars,
132		indices,
133		composition,
134	})
135}
136
/// A single type covering [`IndexComposition`] adapters whose inner query has two, three or
/// four entries, all sharing the same inner composition type `C`.
#[derive(Debug)]
pub enum FixedDimIndexCompositions<C> {
	/// Adapter whose inner composition is queried with four indexed entries.
	Quadrivariate(IndexComposition<C, 4>),
	/// Adapter whose inner composition is queried with three indexed entries.
	Trivariate(IndexComposition<C, 3>),
	/// Adapter whose inner composition is queried with two indexed entries.
	Bivariate(IndexComposition<C, 2>),
}
143
144impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
145	for FixedDimIndexCompositions<C>
146{
147	fn n_vars(&self) -> usize {
148		match self {
149			Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
150			Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
151			Self::Quadrivariate(index_composition) => {
152				CompositionPoly::<P>::n_vars(index_composition)
153			}
154		}
155	}
156
157	fn degree(&self) -> usize {
158		match self {
159			Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
160			Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
161			Self::Quadrivariate(index_composition) => {
162				CompositionPoly::<P>::degree(index_composition)
163			}
164		}
165	}
166
167	fn binary_tower_level(&self) -> usize {
168		match self {
169			Self::Trivariate(index_composition) => {
170				CompositionPoly::<P>::binary_tower_level(index_composition)
171			}
172			Self::Bivariate(index_composition) => {
173				CompositionPoly::<P>::binary_tower_level(index_composition)
174			}
175			Self::Quadrivariate(index_composition) => {
176				CompositionPoly::<P>::binary_tower_level(index_composition)
177			}
178		}
179	}
180
181	fn expression(&self) -> ArithExpr<P::Scalar> {
182		match self {
183			Self::Trivariate(index_composition) => {
184				CompositionPoly::<P>::expression(index_composition)
185			}
186			Self::Bivariate(index_composition) => {
187				CompositionPoly::<P>::expression(index_composition)
188			}
189			Self::Quadrivariate(index_composition) => {
190				CompositionPoly::<P>::expression(index_composition)
191			}
192		}
193	}
194
195	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
196		match self {
197			Self::Trivariate(index_composition) => index_composition.evaluate(query),
198			Self::Bivariate(index_composition) => index_composition.evaluate(query),
199			Self::Quadrivariate(index_composition) => index_composition.evaluate(query),
200		}
201	}
202}
203
#[cfg(test)]
mod tests {
	use binius_field::BinaryField1b;

	use super::*;
	use crate::polynomial::ArithCircuitPoly;

	/// `expression` must renumber the inner variables 0 and 1 to the outer positions 1 and 2
	/// given by `indices`, leaving constants and operators untouched.
	#[test]
	fn tests_expr() {
		// Inner expression: x0 + x1 * 1
		let inner_product = ArithExpr::Mul(
			Box::new(ArithExpr::Var(1)),
			Box::new(ArithExpr::Const(BinaryField1b::ONE)),
		);
		let inner_expr = ArithExpr::Add(Box::new(ArithExpr::Var(0)), Box::new(inner_product));

		let composition = IndexComposition {
			n_vars: 3,
			indices: [1, 2],
			composition: ArithCircuitPoly::new(inner_expr),
		};
		let as_poly: &dyn CompositionPoly<BinaryField1b> = &composition;

		// Expected outer expression: x1 + x2 * 1
		let outer_product = ArithExpr::Mul(
			Box::new(ArithExpr::Var(2)),
			Box::new(ArithExpr::Const(BinaryField1b::ONE)),
		);
		let expected = ArithExpr::Add(Box::new(ArithExpr::Var(1)), Box::new(outer_product));

		assert_eq!(as_poly.expression(), expected);
	}
}