binius_math/
composition_poly.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use auto_impl::auto_impl;
6use binius_field::PackedField;
7use binius_utils::bail;
8use stackalloc::stackalloc_with_default;
9
10use crate::{ArithExpr, Error, RowsBatchRef};
11
12/// A multivariate polynomial that is used as a composition of several multilinear polynomials.
13#[auto_impl(Arc, &)]
14pub trait CompositionPoly<P>: Debug + Send + Sync
15where
16	P: PackedField,
17{
18	/// The number of variables.
19	fn n_vars(&self) -> usize;
20
21	/// Total degree of the polynomial.
22	fn degree(&self) -> usize;
23
24	/// Returns the maximum binary tower level of all constants in the arithmetic expression.
25	fn binary_tower_level(&self) -> usize;
26
27	/// Returns the arithmetic expression representing the polynomial.
28	fn expression(&self) -> ArithExpr<P::Scalar>;
29
30	/// Evaluates the polynomial using packed values, where each packed value may contain multiple scalar values.
31	/// The evaluation follows SIMD semantics, meaning that operations are performed
32	/// element-wise across corresponding scalar values in the packed values.
33	///
34	/// For example, given a polynomial represented as `query[0] + query[1]`:
35	/// - The addition operation is applied element-wise between `query[0]` and `query[1]`.
36	/// - Each scalar value in `query[0]` is added to the corresponding scalar value in `query[1]`.
37	/// - There are no operations performed between scalar values within the same packed value.
38	fn evaluate(&self, query: &[P]) -> Result<P, Error>;
39
40	/// Batch evaluation that admits non-strided argument layout.
41	/// `batch_query` is a slice of slice references of equal length, which furthermore should equal
42	/// the length of `evals` parameter.
43	///
44	/// Evaluation follows SIMD semantics as in `evaluate`:
45	/// - `evals[j] := composition([batch_query[i][j] forall i]) forall j`
46	/// - no crosstalk between evaluations
47	///
48	/// This method has a default implementation.
49	fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
50		let row_len = evals.len();
51		if batch_query.row_len() != row_len {
52			bail!(Error::BatchEvaluateSizeMismatch {
53				expected: row_len,
54				actual: batch_query.row_len(),
55			});
56		}
57
58		stackalloc_with_default(batch_query.n_rows(), |query| {
59			for (column, eval) in evals.iter_mut().enumerate() {
60				for (query_elem, batch_query_row) in query.iter_mut().zip(batch_query.iter()) {
61					*query_elem = batch_query_row[column];
62				}
63
64				*eval = self.evaluate(query)?;
65			}
66			Ok(())
67		})
68	}
69}