binius_hal/
backend.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::Debug,
5	ops::{Deref, DerefMut},
6};
7
8use binius_field::{Field, PackedExtension, PackedField};
9use binius_math::{
10	CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQuery,
11	MultilinearQueryRef,
12};
13use binius_maybe_rayon::iter::FromParallelIterator;
14use tracing::instrument;
15
16use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
17
18/// HAL-managed memory containing the result of its operations.
19pub trait HalSlice<P: Debug + Send + Sync>:
20	Deref<Target = [P]>
21	+ DerefMut<Target = [P]>
22	+ Debug
23	+ FromIterator<P>
24	+ FromParallelIterator<P>
25	+ Send
26	+ Sync
27	+ 'static
28{
29}
30
31impl<P: Send + Sync + Debug + 'static> HalSlice<P> for Vec<P> {}
32
33/// An abstraction to interface with acceleration hardware to perform computation intensive operations.
34pub trait ComputationBackend: Send + Sync + Debug {
35	type Vec<P: Send + Sync + Debug + 'static>: HalSlice<P>;
36
37	/// Creates `Self::Vec<P>` from the given `Vec<P>`.
38	fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P>;
39
40	/// Computes tensor product expansion.
41	fn tensor_product_full_query<P: PackedField>(
42		&self,
43		query: &[P::Scalar],
44	) -> Result<Self::Vec<P>, Error>;
45
46	/// Calculate the accumulated evaluations for an arbitrary round of zerocheck.
47	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
48		&self,
49		evaluation_order: EvaluationOrder,
50		n_vars: usize,
51		tensor_query: Option<MultilinearQueryRef<P>>,
52		multilinears: &[SumcheckMultilinear<P, M>],
53		evaluators: &[Evaluator],
54		nontrivial_evaluation_points: &[FDomain],
55	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
56	where
57		FDomain: Field,
58		P: PackedExtension<FDomain>,
59		M: MultilinearPoly<P> + Send + Sync,
60		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
61		Composition: CompositionPoly<P>;
62
63	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
64	fn evaluate_partial_high<P: PackedField>(
65		&self,
66		multilinear: &impl MultilinearPoly<P>,
67		query_expansion: MultilinearQueryRef<P>,
68	) -> Result<MultilinearExtension<P>, Error>;
69}
70
71/// Makes it unnecessary to clone backends.
72/// Can't use `auto_impl` because of the complex associated type.
73impl<'a, T: 'a + ComputationBackend> ComputationBackend for &'a T
74where
75	&'a T: Debug + Sync + Send,
76{
77	type Vec<P: Send + Sync + Debug + 'static> = T::Vec<P>;
78
79	fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P> {
80		T::to_hal_slice(v)
81	}
82
83	fn tensor_product_full_query<P: PackedField>(
84		&self,
85		query: &[P::Scalar],
86	) -> Result<Self::Vec<P>, Error> {
87		T::tensor_product_full_query(self, query)
88	}
89
90	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
91		&self,
92		evaluation_order: EvaluationOrder,
93		n_vars: usize,
94		tensor_query: Option<MultilinearQueryRef<P>>,
95		multilinears: &[SumcheckMultilinear<P, M>],
96		evaluators: &[Evaluator],
97		nontrivial_evaluation_points: &[FDomain],
98	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
99	where
100		FDomain: Field,
101		P: PackedExtension<FDomain>,
102		M: MultilinearPoly<P> + Send + Sync,
103		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
104		Composition: CompositionPoly<P>,
105	{
106		T::sumcheck_compute_round_evals(
107			self,
108			evaluation_order,
109			n_vars,
110			tensor_query,
111			multilinears,
112			evaluators,
113			nontrivial_evaluation_points,
114		)
115	}
116
117	fn evaluate_partial_high<P: PackedField>(
118		&self,
119		multilinear: &impl MultilinearPoly<P>,
120		query_expansion: MultilinearQueryRef<P>,
121	) -> Result<MultilinearExtension<P>, Error> {
122		T::evaluate_partial_high(self, multilinear, query_expansion)
123	}
124}
125
126pub trait ComputationBackendExt: ComputationBackend {
127	/// Constructs a `MultilinearQuery` by performing tensor product expansion on the given `query`.
128	#[instrument(skip_all, level = "trace")]
129	fn multilinear_query<P: PackedField>(
130		&self,
131		query: &[P::Scalar],
132	) -> Result<MultilinearQuery<P, Self::Vec<P>>, Error> {
133		let tensor_product = self.tensor_product_full_query(query)?;
134		Ok(MultilinearQuery::with_expansion(query.len(), tensor_product)?)
135	}
136}
137
138impl<Backend> ComputationBackendExt for Backend where Backend: ComputationBackend {}