binius_math/
composition_poly.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use auto_impl::auto_impl;
6use binius_field::PackedField;
7use stackalloc::stackalloc_with_default;
8
9use crate::{ArithExpr, Error};
10
11/// A multivariate polynomial that is used as a composition of several multilinear polynomials.
12#[auto_impl(Arc, &)]
13pub trait CompositionPoly<P>: Debug + Send + Sync
14where
15	P: PackedField,
16{
17	/// The number of variables.
18	fn n_vars(&self) -> usize;
19
20	/// Total degree of the polynomial.
21	fn degree(&self) -> usize;
22
23	/// Returns the maximum binary tower level of all constants in the arithmetic expression.
24	fn binary_tower_level(&self) -> usize;
25
26	/// Returns the arithmetic expression representing the polynomial.
27	fn expression(&self) -> ArithExpr<P::Scalar>;
28
29	/// Evaluates the polynomial using packed values, where each packed value may contain multiple scalar values.
30	/// The evaluation follows SIMD semantics, meaning that operations are performed
31	/// element-wise across corresponding scalar values in the packed values.
32	///
33	/// For example, given a polynomial represented as `query[0] + query[1]`:
34	/// - The addition operation is applied element-wise between `query[0]` and `query[1]`.
35	/// - Each scalar value in `query[0]` is added to the corresponding scalar value in `query[1]`.
36	/// - There are no operations performed between scalar values within the same packed value.
37	fn evaluate(&self, query: &[P]) -> Result<P, Error>;
38
39	/// Batch evaluation that admits non-strided argument layout.
40	/// `batch_query` is a slice of slice references of equal length, which furthermore should equal
41	/// the length of `evals` parameter.
42	///
43	/// Evaluation follows SIMD semantics as in `evaluate`:
44	/// - `evals[j] := composition([batch_query[i][j] forall i]) forall j`
45	/// - no crosstalk between evaluations
46	///
47	/// This method has a default implementation.
48	fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> {
49		let row_len = evals.len();
50		if batch_query.iter().any(|row| row.len() != row_len) {
51			return Err(Error::BatchEvaluateSizeMismatch);
52		}
53
54		stackalloc_with_default(batch_query.len(), |query| {
55			for (column, eval) in evals.iter_mut().enumerate() {
56				for (query_elem, batch_query_row) in query.iter_mut().zip(batch_query) {
57					*query_elem = batch_query_row[column];
58				}
59
60				*eval = self.evaluate(query)?;
61			}
62			Ok(())
63		})
64	}
65}