binius_math/composition_poly.rs
1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use auto_impl::auto_impl;
6use binius_field::PackedField;
7use binius_utils::bail;
8use stackalloc::stackalloc_with_default;
9
10use crate::{ArithExpr, Error, RowsBatchRef};
11
12/// A multivariate polynomial that is used as a composition of several multilinear polynomials.
13#[auto_impl(Arc, &)]
14pub trait CompositionPoly<P>: Debug + Send + Sync
15where
16 P: PackedField,
17{
18 /// The number of variables.
19 fn n_vars(&self) -> usize;
20
21 /// Total degree of the polynomial.
22 fn degree(&self) -> usize;
23
24 /// Returns the maximum binary tower level of all constants in the arithmetic expression.
25 fn binary_tower_level(&self) -> usize;
26
27 /// Returns the arithmetic expression representing the polynomial.
28 fn expression(&self) -> ArithExpr<P::Scalar>;
29
30 /// Evaluates the polynomial using packed values, where each packed value may contain multiple scalar values.
31 /// The evaluation follows SIMD semantics, meaning that operations are performed
32 /// element-wise across corresponding scalar values in the packed values.
33 ///
34 /// For example, given a polynomial represented as `query[0] + query[1]`:
35 /// - The addition operation is applied element-wise between `query[0]` and `query[1]`.
36 /// - Each scalar value in `query[0]` is added to the corresponding scalar value in `query[1]`.
37 /// - There are no operations performed between scalar values within the same packed value.
38 fn evaluate(&self, query: &[P]) -> Result<P, Error>;
39
40 /// Batch evaluation that admits non-strided argument layout.
41 /// `batch_query` is a slice of slice references of equal length, which furthermore should equal
42 /// the length of `evals` parameter.
43 ///
44 /// Evaluation follows SIMD semantics as in `evaluate`:
45 /// - `evals[j] := composition([batch_query[i][j] forall i]) forall j`
46 /// - no crosstalk between evaluations
47 ///
48 /// This method has a default implementation.
49 fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
50 let row_len = evals.len();
51 if batch_query.row_len() != row_len {
52 bail!(Error::BatchEvaluateSizeMismatch {
53 expected: row_len,
54 actual: batch_query.row_len(),
55 });
56 }
57
58 stackalloc_with_default(batch_query.n_rows(), |query| {
59 for (column, eval) in evals.iter_mut().enumerate() {
60 for (query_elem, batch_query_row) in query.iter_mut().zip(batch_query.iter()) {
61 *query_elem = batch_query_row[column];
62 }
63
64 *eval = self.evaluate(query)?;
65 }
66 Ok(())
67 })
68 }
69}