binius_core/composition/
index.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::PackedField;
6use binius_math::{ArithExpr, CompositionPoly, RowsBatchRef};
7use binius_utils::bail;
8
9use crate::polynomial::Error;
10
/// Wraps an inner composition so that it can be evaluated against a wider (outer) query.
///
/// The inner composition's variable `i` is read from position `indices[i]` of an outer query
/// that has `n_vars` entries. Build one with [`IndexComposition::new`] from explicit positions,
/// or with [`index_composition`] from identifier lists.
#[derive(Debug, Clone)]
pub struct IndexComposition<C, const N: usize> {
	/// Length of the outer (larger) query.
	n_vars: usize,
	/// `indices[i]` is the outer-query position that feeds inner variable `i`.
	indices: [usize; N],
	/// The wrapped composition, evaluated on the `N`-element subquery.
	composition: C,
}
22
23impl<C, const N: usize> IndexComposition<C, N> {
24	pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
25		if indices.iter().any(|&index| index >= n_vars) {
26			bail!(Error::IndexCompositionIndicesOutOfBounds);
27		}
28
29		Ok(Self {
30			n_vars,
31			indices,
32			composition,
33		})
34	}
35}
36
37impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
38	for IndexComposition<C, N>
39{
40	fn n_vars(&self) -> usize {
41		self.n_vars
42	}
43
44	fn degree(&self) -> usize {
45		self.composition.degree()
46	}
47
48	fn expression(&self) -> ArithExpr<<P as PackedField>::Scalar> {
49		self.composition
50			.expression()
51			.remap_vars(&self.indices)
52			.expect("remapping must be valid")
53	}
54
55	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
56		if query.len() != self.n_vars {
57			bail!(binius_math::Error::IncorrectQuerySize {
58				expected: self.n_vars,
59			});
60		}
61
62		let subquery = self.indices.map(|index| query[index]);
63		self.composition.evaluate(&subquery)
64	}
65
66	fn binary_tower_level(&self) -> usize {
67		self.composition.binary_tower_level()
68	}
69
70	fn batch_evaluate(
71		&self,
72		batch_query: &RowsBatchRef<P>,
73		evals: &mut [P],
74	) -> Result<(), binius_math::Error> {
75		let batch_subquery = batch_query.map(self.indices);
76		self.composition
77			.batch_evaluate(&batch_subquery.get_ref(), evals)
78	}
79}
80
81/// A factory helper method to create an [`IndexComposition`] by looking at
82///  * `superset` - a set of identifiers of a greater (outer) query
83///  * `subset` - a set of identifiers of a smaller query, the one which corresponds to the inner composition directly
84///
85/// Identifiers may be anything `Eq` - `OracleId`, `MultilinearPolyOracle<F>`, etc.
86pub fn index_composition<E, C, const N: usize>(
87	superset: &[E],
88	subset: [E; N],
89	composition: C,
90) -> Result<IndexComposition<C, N>, Error>
91where
92	E: PartialEq,
93{
94	let n_vars = superset.len();
95
96	// array_try_map is unstable as of 03/24, check the condition beforehand
97	let proper_subset = subset.iter().all(|subset_item| {
98		superset
99			.iter()
100			.any(|superset_item| superset_item == subset_item)
101	});
102
103	if !proper_subset {
104		bail!(Error::MixedMultilinearNotFound);
105	}
106
107	let indices = subset.map(|subset_item| {
108		superset
109			.iter()
110			.position(|superset_item| superset_item == &subset_item)
111			.expect("Superset condition checked above.")
112	});
113
114	Ok(IndexComposition {
115		n_vars,
116		indices,
117		composition,
118	})
119}
120
121#[derive(Debug)]
122pub enum FixedDimIndexCompositions<C> {
123	Trivariate(IndexComposition<C, 3>),
124	Bivariate(IndexComposition<C, 2>),
125}
126
127impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
128	for FixedDimIndexCompositions<C>
129{
130	fn n_vars(&self) -> usize {
131		match self {
132			Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
133			Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
134		}
135	}
136
137	fn degree(&self) -> usize {
138		match self {
139			Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
140			Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
141		}
142	}
143
144	fn binary_tower_level(&self) -> usize {
145		match self {
146			Self::Trivariate(index_composition) => {
147				CompositionPoly::<P>::binary_tower_level(index_composition)
148			}
149			Self::Bivariate(index_composition) => {
150				CompositionPoly::<P>::binary_tower_level(index_composition)
151			}
152		}
153	}
154
155	fn expression(&self) -> ArithExpr<P::Scalar> {
156		match self {
157			Self::Trivariate(index_composition) => {
158				CompositionPoly::<P>::expression(index_composition)
159			}
160			Self::Bivariate(index_composition) => {
161				CompositionPoly::<P>::expression(index_composition)
162			}
163		}
164	}
165
166	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
167		match self {
168			Self::Trivariate(index_composition) => index_composition.evaluate(query),
169			Self::Bivariate(index_composition) => index_composition.evaluate(query),
170		}
171	}
172}
173
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{BinaryField1b, Field};

	use super::*;
	use crate::polynomial::ArithCircuitPoly;

	/// The inner expression used throughout: `x0 * (x1 + 1)` over `BinaryField1b`.
	fn sample_expr() -> ArithExpr<BinaryField1b> {
		ArithExpr::Mul(
			Arc::new(ArithExpr::Var(0)),
			Arc::new(ArithExpr::Add(
				Arc::new(ArithExpr::Var(1)),
				Arc::new(ArithExpr::Const(BinaryField1b::ONE)),
			)),
		)
	}

	/// Evaluates `composition` on `query`, discarding the error payload.
	fn eval<C: CompositionPoly<BinaryField1b>>(
		composition: &C,
		query: &[BinaryField1b],
	) -> Option<BinaryField1b> {
		composition.evaluate(query).ok()
	}

	#[test]
	fn tests_expr() {
		let expr = sample_expr();
		let circuit = ArithCircuitPoly::new(&expr);

		let composition = IndexComposition {
			n_vars: 3,
			indices: [1, 2],
			composition: circuit,
		};

		assert_eq!(
			(&composition as &dyn CompositionPoly<BinaryField1b>).expression(),
			ArithExpr::Mul(
				Arc::new(ArithExpr::Var(1)),
				Arc::new(ArithExpr::Add(
					Arc::new(ArithExpr::Var(2)),
					Arc::new(ArithExpr::Const(BinaryField1b::ONE)),
				)),
			)
		);
	}

	#[test]
	fn test_evaluate_reads_indexed_positions() {
		let expr = sample_expr();
		let Ok(composition) = IndexComposition::new(3, [1, 2], ArithCircuitPoly::new(&expr))
		else {
			panic!("indices [1, 2] are in bounds for 3 variables");
		};
		let (zero, one) = (BinaryField1b::ZERO, BinaryField1b::ONE);

		// Only positions 1 and 2 are read: result = q[1] * (q[2] + 1).
		assert_eq!(eval(&composition, &[zero, one, zero]), Some(one));
		assert_eq!(eval(&composition, &[one, one, one]), Some(zero));
		assert_eq!(eval(&composition, &[one, zero, zero]), Some(zero));

		// The query length must equal the outer `n_vars`.
		assert_eq!(eval(&composition, &[one, one]), None);
	}

	#[test]
	fn test_new_rejects_out_of_bounds_indices() {
		let expr = sample_expr();
		assert!(IndexComposition::new(3, [0, 2], ArithCircuitPoly::new(&expr)).is_ok());
		assert!(IndexComposition::new(3, [1, 3], ArithCircuitPoly::new(&expr)).is_err());
	}

	#[test]
	fn test_index_composition_factory() {
		let expr = sample_expr();

		let Ok(composition) =
			index_composition(&[10, 20, 30], [20, 30], ArithCircuitPoly::new(&expr))
		else {
			panic!("every subset item occurs in the superset");
		};
		assert_eq!(composition.n_vars, 3);
		assert_eq!(composition.indices, [1, 2]);

		// Duplicated superset identifiers resolve to their first position.
		let Ok(with_duplicates) =
			index_composition(&[5, 5, 7], [7, 5], ArithCircuitPoly::new(&expr))
		else {
			panic!("every subset item occurs in the superset");
		};
		assert_eq!(with_duplicates.indices, [2, 0]);

		let missing = index_composition(&[10, 20, 30], [20, 40], ArithCircuitPoly::new(&expr));
		assert!(missing.is_err());
	}
}