binius_math/composition_poly.rs
1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use auto_impl::auto_impl;
6use binius_field::PackedField;
7use stackalloc::stackalloc_with_default;
8
9use crate::{ArithExpr, Error};
10
11/// A multivariate polynomial that is used as a composition of several multilinear polynomials.
12#[auto_impl(Arc, &)]
13pub trait CompositionPoly<P>: Debug + Send + Sync
14where
15 P: PackedField,
16{
17 /// The number of variables.
18 fn n_vars(&self) -> usize;
19
20 /// Total degree of the polynomial.
21 fn degree(&self) -> usize;
22
23 /// Returns the maximum binary tower level of all constants in the arithmetic expression.
24 fn binary_tower_level(&self) -> usize;
25
26 /// Returns the arithmetic expression representing the polynomial.
27 fn expression(&self) -> ArithExpr<P::Scalar>;
28
29 /// Evaluates the polynomial using packed values, where each packed value may contain multiple scalar values.
30 /// The evaluation follows SIMD semantics, meaning that operations are performed
31 /// element-wise across corresponding scalar values in the packed values.
32 ///
33 /// For example, given a polynomial represented as `query[0] + query[1]`:
34 /// - The addition operation is applied element-wise between `query[0]` and `query[1]`.
35 /// - Each scalar value in `query[0]` is added to the corresponding scalar value in `query[1]`.
36 /// - There are no operations performed between scalar values within the same packed value.
37 fn evaluate(&self, query: &[P]) -> Result<P, Error>;
38
39 /// Batch evaluation that admits non-strided argument layout.
40 /// `batch_query` is a slice of slice references of equal length, which furthermore should equal
41 /// the length of `evals` parameter.
42 ///
43 /// Evaluation follows SIMD semantics as in `evaluate`:
44 /// - `evals[j] := composition([batch_query[i][j] forall i]) forall j`
45 /// - no crosstalk between evaluations
46 ///
47 /// This method has a default implementation.
48 fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> {
49 let row_len = evals.len();
50 if batch_query.iter().any(|row| row.len() != row_len) {
51 return Err(Error::BatchEvaluateSizeMismatch);
52 }
53
54 stackalloc_with_default(batch_query.len(), |query| {
55 for (column, eval) in evals.iter_mut().enumerate() {
56 for (query_elem, batch_query_row) in query.iter_mut().zip(batch_query) {
57 *query_elem = batch_query_row[column];
58 }
59
60 *eval = self.evaluate(query)?;
61 }
62 Ok(())
63 })
64 }
65}