binius_math/
composition_poly.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// Copyright 2024 Irreducible Inc.

use crate::Error;
use auto_impl::auto_impl;
use binius_field::{ExtensionField, Field, PackedField};
use stackalloc::stackalloc_with_default;
use std::fmt::Debug;

/// A multivariate polynomial that defines a composition of `MultilinearComposite`.
/// This is an object-safe version of the trait.
#[auto_impl(Arc, &)]
pub trait CompositionPolyOS<P>: Debug + Send + Sync
where
	P: PackedField,
{
	/// The number of variables.
	fn n_vars(&self) -> usize;

	/// Total degree of the polynomial.
	fn degree(&self) -> usize;

	/// Evaluates the polynomial using packed values, where each packed value may contain multiple scalar values.
	/// The evaluation follows SIMD semantics, meaning that operations are performed
	/// element-wise across corresponding scalar values in the packed values.
	///
	/// For example, given a polynomial represented as `query[0] + query[1]`:
	/// - The addition operation is applied element-wise between `query[0]` and `query[1]`.
	/// - Each scalar value in `query[0]` is added to the corresponding scalar value in `query[1]`.
	/// - There are no operations performed between scalar values within the same packed value.
	fn evaluate(&self, query: &[P]) -> Result<P, Error>;

	/// Returns the maximum binary tower level of all constants in the arithmetic expression.
	fn binary_tower_level(&self) -> usize;

	/// Batch evaluation that admits non-strided argument layout.
	///
	/// `batch_query` is a slice of slice references (one per variable), each of which must have
	/// the same length as `evals`. When `batch_query` is empty (a composition with no
	/// variables), `evals` may have any length and every entry is evaluated on an empty query.
	///
	/// Evaluation follows SIMD semantics as in `evaluate`:
	/// - `evals[j] := composition([batch_query[i][j] forall i]) forall j`
	/// - no crosstalk between evaluations
	///
	/// This method has a default implementation.
	///
	/// # Errors
	///
	/// - [`Error::BatchEvaluateSizeMismatch`] if any row of `batch_query` has a length different
	///   from `evals.len()`.
	/// - Any error returned by [`Self::evaluate`] is propagated. In that case, the entries of
	///   `evals` preceding the failing column have already been overwritten.
	fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> {
		// `evals` is the authoritative row length. Deriving it from the first query row would
		// reject nullary compositions (empty `batch_query`) whenever `evals` is non-empty.
		let row_len = evals.len();
		if batch_query.iter().any(|row| row.len() != row_len) {
			return Err(Error::BatchEvaluateSizeMismatch);
		}

		// One scratch buffer of `batch_query.len()` default values is obtained once and reused
		// for every column, so no allocation happens per evaluation.
		stackalloc_with_default(batch_query.len(), |query| {
			for (column, eval) in evals.iter_mut().enumerate() {
				// Gather column `column` of the row-major `batch_query` into `query`.
				for (query_elem, batch_query_row) in query.iter_mut().zip(batch_query) {
					*query_elem = batch_query_row[column];
				}

				*eval = self.evaluate(query)?;
			}
			Ok(())
		})
	}
}

/// A generic version of the `CompositionPolyOS` trait that is not object-safe.
#[auto_impl(&)]
pub trait CompositionPoly<F: Field>: Debug + Send + Sync {
	fn n_vars(&self) -> usize;

	fn degree(&self) -> usize;

	fn evaluate<P: PackedField<Scalar: ExtensionField<F>>>(&self, query: &[P]) -> Result<P, Error>;

	fn binary_tower_level(&self) -> usize;

	fn batch_evaluate<P: PackedField<Scalar: ExtensionField<F>>>(
		&self,
		batch_query: &[&[P]],
		evals: &mut [P],
	) -> Result<(), Error>;
}