binius_core/composition/
index.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::{Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly, RowsBatchRef};
7use binius_utils::bail;
8
9use crate::polynomial::Error;
10
/// An adapter which allows evaluating a composition over a larger query by indexing into it.
/// See [`index_composition`] for a factory method.
///
/// The adapter presents itself as a composition over `n_vars` variables, while the wrapped
/// composition only receives the `N` query entries selected by `indices`.
#[derive(Clone, Debug)]
pub struct IndexComposition<C, const N: usize> {
	/// Number of variables in a larger (outer) query; `evaluate` rejects queries of any other
	/// length.
	n_vars: usize,
	/// Mapping from the inner composition query variables to outer query variables:
	/// entry `i` of the inner query is read from position `indices[i]` of the outer query.
	indices: [usize; N],
	/// Inner composition, evaluated on the `N` selected entries.
	composition: C,
}
22
23impl<C, const N: usize> IndexComposition<C, N> {
24	pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
25		if indices.iter().any(|&index| index >= n_vars) {
26			bail!(Error::IndexCompositionIndicesOutOfBounds);
27		}
28
29		Ok(Self {
30			n_vars,
31			indices,
32			composition,
33		})
34	}
35}
36
37impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
38	for IndexComposition<C, N>
39{
40	fn n_vars(&self) -> usize {
41		self.n_vars
42	}
43
44	fn degree(&self) -> usize {
45		self.composition.degree()
46	}
47
48	fn expression(&self) -> ArithExpr<<P as PackedField>::Scalar> {
49		fn map_variables<F: Field, const M: usize>(
50			index_map: &[usize; M],
51			expr: &ArithExpr<F>,
52		) -> ArithExpr<F> {
53			match expr {
54				ArithExpr::Var(i) => ArithExpr::Var(index_map[*i]),
55				ArithExpr::Const(c) => ArithExpr::Const(*c),
56				ArithExpr::Add(a, b) => ArithExpr::Add(
57					Box::new(map_variables(index_map, a)),
58					Box::new(map_variables(index_map, b)),
59				),
60				ArithExpr::Mul(a, b) => ArithExpr::Mul(
61					Box::new(map_variables(index_map, a)),
62					Box::new(map_variables(index_map, b)),
63				),
64				ArithExpr::Pow(a, n) => ArithExpr::Pow(Box::new(map_variables(index_map, a)), *n),
65			}
66		}
67
68		map_variables(&self.indices, &self.composition.expression())
69	}
70
71	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
72		if query.len() != self.n_vars {
73			bail!(binius_math::Error::IncorrectQuerySize {
74				expected: self.n_vars,
75			});
76		}
77
78		let subquery = self.indices.map(|index| query[index]);
79		self.composition.evaluate(&subquery)
80	}
81
82	fn binary_tower_level(&self) -> usize {
83		self.composition.binary_tower_level()
84	}
85
86	fn batch_evaluate(
87		&self,
88		batch_query: &RowsBatchRef<P>,
89		evals: &mut [P],
90	) -> Result<(), binius_math::Error> {
91		let batch_subquery = batch_query.map(self.indices);
92		self.composition
93			.batch_evaluate(&batch_subquery.get_ref(), evals)
94	}
95}
96
97/// A factory helper method to create an [`IndexComposition`] by looking at
98///  * `superset` - a set of identifiers of a greater (outer) query
99///  * `subset` - a set of identifiers of a smaller query, the one which corresponds to the inner composition directly
100///
101/// Identifiers may be anything `Eq` - `OracleId`, `MultilinearPolyOracle<F>`, etc.
102pub fn index_composition<E, C, const N: usize>(
103	superset: &[E],
104	subset: [E; N],
105	composition: C,
106) -> Result<IndexComposition<C, N>, Error>
107where
108	E: PartialEq,
109{
110	let n_vars = superset.len();
111
112	// array_try_map is unstable as of 03/24, check the condition beforehand
113	let proper_subset = subset.iter().all(|subset_item| {
114		superset
115			.iter()
116			.any(|superset_item| superset_item == subset_item)
117	});
118
119	if !proper_subset {
120		bail!(Error::MixedMultilinearNotFound);
121	}
122
123	let indices = subset.map(|subset_item| {
124		superset
125			.iter()
126			.position(|superset_item| superset_item == &subset_item)
127			.expect("Superset condition checked above.")
128	});
129
130	Ok(IndexComposition {
131		n_vars,
132		indices,
133		composition,
134	})
135}
136
/// Either a three-input or a two-input [`IndexComposition`] over the same inner composition
/// type `C`, so that both arities can be handled through a single type.
#[derive(Debug)]
pub enum FixedDimIndexCompositions<C> {
	/// Adapter that feeds three outer query entries to the inner composition.
	Trivariate(IndexComposition<C, 3>),
	/// Adapter that feeds two outer query entries to the inner composition.
	Bivariate(IndexComposition<C, 2>),
}
142
143impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
144	for FixedDimIndexCompositions<C>
145{
146	fn n_vars(&self) -> usize {
147		match self {
148			Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
149			Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
150		}
151	}
152
153	fn degree(&self) -> usize {
154		match self {
155			Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
156			Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
157		}
158	}
159
160	fn binary_tower_level(&self) -> usize {
161		match self {
162			Self::Trivariate(index_composition) => {
163				CompositionPoly::<P>::binary_tower_level(index_composition)
164			}
165			Self::Bivariate(index_composition) => {
166				CompositionPoly::<P>::binary_tower_level(index_composition)
167			}
168		}
169	}
170
171	fn expression(&self) -> ArithExpr<P::Scalar> {
172		match self {
173			Self::Trivariate(index_composition) => {
174				CompositionPoly::<P>::expression(index_composition)
175			}
176			Self::Bivariate(index_composition) => {
177				CompositionPoly::<P>::expression(index_composition)
178			}
179		}
180	}
181
182	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
183		match self {
184			Self::Trivariate(index_composition) => index_composition.evaluate(query),
185			Self::Bivariate(index_composition) => index_composition.evaluate(query),
186		}
187	}
188}
189
#[cfg(test)]
mod tests {
	use binius_field::BinaryField1b;

	use super::*;
	use crate::polynomial::ArithCircuitPoly;

	/// Checks that `expression()` renames inner variables 0 and 1 to the outer variables 1 and 2
	/// given by `indices`.
	#[test]
	fn tests_expr() {
		/// Builds `Var(first) + Var(second) * 1` over `BinaryField1b`.
		fn sum_with_unit_product(first: usize, second: usize) -> ArithExpr<BinaryField1b> {
			let product = ArithExpr::Mul(
				Box::new(ArithExpr::Var(second)),
				Box::new(ArithExpr::Const(BinaryField1b::ONE)),
			);
			ArithExpr::Add(Box::new(ArithExpr::Var(first)), Box::new(product))
		}

		let inner = ArithCircuitPoly::new(sum_with_unit_product(0, 1));
		let composition = IndexComposition {
			n_vars: 3,
			indices: [1, 2],
			composition: inner,
		};

		let poly: &dyn CompositionPoly<BinaryField1b> = &composition;
		let expected = sum_with_unit_product(1, 2);

		assert_eq!(poly.expression(), expected);
	}
}