binius_core/composition/index.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::PackedField;
6use binius_math::{ArithCircuit, CompositionPoly, RowsBatchRef};
7use binius_utils::bail;
8use getset::Getters;
9
10use crate::polynomial::Error;
11
12/// An adapter which allows evaluating a composition over a larger query by indexing into it.
13/// See [`index_composition`] for a factory method.
14#[derive(Clone, Debug, Getters)]
15pub struct IndexComposition<C, const N: usize> {
16	/// Number of variables in a larger query
17	n_vars: usize,
18	/// Mapping from the inner composition query variables to outer query variables
19	#[get = "pub"]
20	indices: [usize; N],
21	/// Inner composition
22	composition: C,
23}
24
25impl<C, const N: usize> IndexComposition<C, N> {
26	pub fn new(n_vars: usize, indices: [usize; N], composition: C) -> Result<Self, Error> {
27		if indices.iter().any(|&index| index >= n_vars) {
28			bail!(Error::IndexCompositionIndicesOutOfBounds);
29		}
30
31		Ok(Self {
32			n_vars,
33			indices,
34			composition,
35		})
36	}
37}
38
39impl<P: PackedField, C: CompositionPoly<P>, const N: usize> CompositionPoly<P>
40	for IndexComposition<C, N>
41{
42	fn n_vars(&self) -> usize {
43		self.n_vars
44	}
45
46	fn degree(&self) -> usize {
47		self.composition.degree()
48	}
49
50	fn expression(&self) -> ArithCircuit<<P as PackedField>::Scalar> {
51		self.composition
52			.expression()
53			.remap_vars(&self.indices)
54			.expect("remapping must be valid")
55	}
56
57	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
58		if query.len() != self.n_vars {
59			bail!(binius_math::Error::IncorrectQuerySize {
60				expected: self.n_vars,
61				actual: query.len(),
62			});
63		}
64
65		let subquery = self.indices.map(|index| query[index]);
66		self.composition.evaluate(&subquery)
67	}
68
69	fn binary_tower_level(&self) -> usize {
70		self.composition.binary_tower_level()
71	}
72
73	fn batch_evaluate(
74		&self,
75		batch_query: &RowsBatchRef<P>,
76		evals: &mut [P],
77	) -> Result<(), binius_math::Error> {
78		let batch_subquery = batch_query.map(self.indices);
79		self.composition
80			.batch_evaluate(&batch_subquery.get_ref(), evals)
81	}
82}
83
84/// A factory helper method to create an [`IndexComposition`] by looking at
85///  * `superset` - a set of identifiers of a greater (outer) query
86///  * `subset` - a set of identifiers of a smaller query, the one which corresponds to the inner
87///    composition directly
88///
89/// Identifiers may be anything `Eq` - `OracleId`, `MultilinearPolyOracle<F>`, etc.
90pub fn index_composition<E, C, const N: usize>(
91	superset: &[E],
92	subset: [E; N],
93	composition: C,
94) -> Result<IndexComposition<C, N>, Error>
95where
96	E: PartialEq,
97{
98	let n_vars = superset.len();
99
100	// array_try_map is unstable as of 03/24, check the condition beforehand
101	let proper_subset = subset.iter().all(|subset_item| {
102		superset
103			.iter()
104			.any(|superset_item| superset_item == subset_item)
105	});
106
107	if !proper_subset {
108		bail!(Error::MixedMultilinearNotFound);
109	}
110
111	let indices = subset.map(|subset_item| {
112		superset
113			.iter()
114			.position(|superset_item| superset_item == &subset_item)
115			.expect("Superset condition checked above.")
116	});
117
118	Ok(IndexComposition {
119		n_vars,
120		indices,
121		composition,
122	})
123}
124
125#[derive(Debug)]
126pub enum FixedDimIndexCompositions<C> {
127	Trivariate(IndexComposition<C, 3>),
128	Bivariate(IndexComposition<C, 2>),
129}
130
131impl<P: PackedField, C: CompositionPoly<P> + Debug + Send + Sync> CompositionPoly<P>
132	for FixedDimIndexCompositions<C>
133{
134	fn n_vars(&self) -> usize {
135		match self {
136			Self::Trivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
137			Self::Bivariate(index_composition) => CompositionPoly::<P>::n_vars(index_composition),
138		}
139	}
140
141	fn degree(&self) -> usize {
142		match self {
143			Self::Trivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
144			Self::Bivariate(index_composition) => CompositionPoly::<P>::degree(index_composition),
145		}
146	}
147
148	fn binary_tower_level(&self) -> usize {
149		match self {
150			Self::Trivariate(index_composition) => {
151				CompositionPoly::<P>::binary_tower_level(index_composition)
152			}
153			Self::Bivariate(index_composition) => {
154				CompositionPoly::<P>::binary_tower_level(index_composition)
155			}
156		}
157	}
158
159	fn expression(&self) -> ArithCircuit<P::Scalar> {
160		match self {
161			Self::Trivariate(index_composition) => {
162				CompositionPoly::<P>::expression(index_composition)
163			}
164			Self::Bivariate(index_composition) => {
165				CompositionPoly::<P>::expression(index_composition)
166			}
167		}
168	}
169
170	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
171		match self {
172			Self::Trivariate(index_composition) => index_composition.evaluate(query),
173			Self::Bivariate(index_composition) => index_composition.evaluate(query),
174		}
175	}
176}
177
#[cfg(test)]
mod tests {
	use binius_fast_compute::arith_circuit::ArithCircuitPoly;
	use binius_field::{BinaryField1b, Field};
	use binius_math::ArithExpr;

	use super::*;

	/// Circuit for `x0 * (x1 + 1)` over `BinaryField1b`.
	fn sample_circuit() -> ArithCircuit<BinaryField1b> {
		let expr = ArithExpr::Var(0) * (ArithExpr::Var(1) + ArithExpr::Const(BinaryField1b::ONE));
		(&expr).into()
	}

	#[test]
	fn tests_expr() {
		let composition = IndexComposition {
			n_vars: 3,
			indices: [1, 2],
			composition: ArithCircuitPoly::new(sample_circuit()),
		};

		assert_eq!(
			(&composition as &dyn CompositionPoly<BinaryField1b>).expression(),
			ArithCircuit::from(
				&(ArithExpr::Var(1) * (ArithExpr::Var(2) + ArithExpr::Const(BinaryField1b::ONE)))
			),
		);
	}

	#[test]
	fn new_rejects_out_of_bounds_indices() {
		let result = IndexComposition::new(3, [1, 3], ArithCircuitPoly::new(sample_circuit()));
		assert!(matches!(result, Err(Error::IndexCompositionIndicesOutOfBounds)));
	}

	#[test]
	fn new_accepts_in_bounds_indices() {
		let composition =
			IndexComposition::new(3, [2, 0], ArithCircuitPoly::new(sample_circuit()))
				.expect("indices are within bounds");
		assert_eq!(composition.indices(), &[2, 0]);
	}

	#[test]
	fn index_composition_maps_subset_to_superset_positions() {
		let composition =
			index_composition(&[10, 20, 30], [30, 10], ArithCircuitPoly::new(sample_circuit()))
				.expect("subset is contained in superset");
		assert_eq!(composition.indices(), &[2, 0]);
		assert_eq!(CompositionPoly::<BinaryField1b>::n_vars(&composition), 3);
	}

	#[test]
	fn index_composition_rejects_missing_identifier() {
		let result =
			index_composition(&[10, 20], [20, 40], ArithCircuitPoly::new(sample_circuit()));
		assert!(matches!(result, Err(Error::MixedMultilinearNotFound)));
	}

	#[test]
	fn evaluate_reads_through_indices() {
		let composition =
			IndexComposition::new(3, [1, 2], ArithCircuitPoly::new(sample_circuit()))
				.expect("indices are within bounds");
		let composition = &composition as &dyn CompositionPoly<BinaryField1b>;
		let (zero, one) = (BinaryField1b::ZERO, BinaryField1b::ONE);

		// Evaluates x1 * (x2 + 1); query[0] is not read by the inner composition.
		assert_eq!(composition.evaluate(&[one, one, zero]).unwrap(), one);
		assert_eq!(composition.evaluate(&[zero, one, zero]).unwrap(), one);
		assert_eq!(composition.evaluate(&[one, one, one]).unwrap(), zero);
		assert_eq!(composition.evaluate(&[one, zero, zero]).unwrap(), zero);

		// The outer query must have exactly `n_vars` entries.
		assert!(composition.evaluate(&[one, one]).is_err());
	}
}
204}