binius_hal/
backend.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::Debug,
5	ops::{Deref, DerefMut},
6};
7
8use binius_field::{Field, PackedExtension, PackedField};
9use binius_math::{
10	CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQuery,
11	MultilinearQueryRef,
12};
13use binius_maybe_rayon::iter::FromParallelIterator;
14use tracing::instrument;
15
16use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
17
18/// HAL-managed memory containing the result of its operations.
19pub trait HalSlice<P: Debug + Send + Sync>:
20	Deref<Target = [P]>
21	+ DerefMut<Target = [P]>
22	+ Debug
23	+ FromIterator<P>
24	+ FromParallelIterator<P>
25	+ Send
26	+ Sync
27	+ 'static
28{
29}
30
31impl<P: Send + Sync + Debug + 'static> HalSlice<P> for Vec<P> {}
32
/// An abstraction to interface with acceleration hardware to perform computation-intensive
/// operations.
///
/// Every implementor must be `Send + Sync + Debug`. All fallible operations report failures
/// through the crate-level [`Error`] type.
pub trait ComputationBackend: Send + Sync + Debug {
	/// Backend-managed buffer of `P` values; it must satisfy `HalSlice<P>`.
	type Vec<P: Send + Sync + Debug + 'static>: HalSlice<P>;

	/// Creates `Self::Vec<P>` from the given `Vec<P>`.
	///
	/// This is an associated function: it takes no `self` receiver and consumes `v`.
	fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P>;

	/// Computes tensor product expansion.
	///
	/// Takes the scalar `query` and returns the expansion as a backend-managed `Self::Vec<P>`,
	/// or an [`Error`] if the backend fails.
	fn tensor_product_full_query<P: PackedField>(
		&self,
		query: &[P::Scalar],
	) -> Result<Self::Vec<P>, Error>;

	// NOTE(review): the method name says "sumcheck" while the description below says
	// "zerocheck" — confirm which wording is intended.
	// NOTE(review): presumably the returned vector holds one `RoundEvals` per element of
	// `evaluators` — confirm against the implementations.
	/// Calculate the accumulated evaluations for an arbitrary round of zerocheck.
	///
	/// Returns a vector of [`RoundEvals`] over `P::Scalar`, or an [`Error`] if the backend
	/// fails.
	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
		&self,
		evaluation_order: EvaluationOrder,
		n_vars: usize,
		tensor_query: Option<MultilinearQueryRef<P>>,
		multilinears: &[SumcheckMultilinear<P, M>],
		evaluators: &[Evaluator],
		nontrivial_evaluation_points: &[FDomain],
	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
	where
		FDomain: Field,
		P: PackedExtension<FDomain>,
		M: MultilinearPoly<P> + Send + Sync,
		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
		Composition: CompositionPoly<P>;

	// NOTE(review): the meaning of the returned `bool` is not visible from this declaration —
	// confirm against the implementations and document it here.
	/// Sumcheck round: folds `multilinears` in place using the given `challenge`.
	///
	/// `multilinears` is taken by mutable reference, so the caller observes the folded state
	/// after the call. Returns an [`Error`] if the backend fails.
	fn sumcheck_fold_multilinears<P, M>(
		&self,
		evaluation_order: EvaluationOrder,
		n_vars: usize,
		multilinears: &mut [SumcheckMultilinear<P, M>],
		challenge: P::Scalar,
		tensor_query: Option<MultilinearQueryRef<P>>,
	) -> Result<bool, Error>
	where
		P: PackedField,
		M: MultilinearPoly<P> + Send + Sync;

	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
	///
	/// The assignment is supplied as `query_expansion`; the result is returned as a new
	/// [`MultilinearExtension`], or an [`Error`] if the backend fails.
	fn evaluate_partial_high<P: PackedField>(
		&self,
		multilinear: &impl MultilinearPoly<P>,
		query_expansion: MultilinearQueryRef<P>,
	) -> Result<MultilinearExtension<P>, Error>;
}
83
84/// Makes it unnecessary to clone backends.
85/// Can't use `auto_impl` because of the complex associated type.
86impl<'a, T: 'a + ComputationBackend> ComputationBackend for &'a T
87where
88	&'a T: Debug + Sync + Send,
89{
90	type Vec<P: Send + Sync + Debug + 'static> = T::Vec<P>;
91
92	fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P> {
93		T::to_hal_slice(v)
94	}
95
96	fn tensor_product_full_query<P: PackedField>(
97		&self,
98		query: &[P::Scalar],
99	) -> Result<Self::Vec<P>, Error> {
100		T::tensor_product_full_query(self, query)
101	}
102
103	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
104		&self,
105		evaluation_order: EvaluationOrder,
106		n_vars: usize,
107		tensor_query: Option<MultilinearQueryRef<P>>,
108		multilinears: &[SumcheckMultilinear<P, M>],
109		evaluators: &[Evaluator],
110		nontrivial_evaluation_points: &[FDomain],
111	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
112	where
113		FDomain: Field,
114		P: PackedExtension<FDomain>,
115		M: MultilinearPoly<P> + Send + Sync,
116		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
117		Composition: CompositionPoly<P>,
118	{
119		T::sumcheck_compute_round_evals(
120			self,
121			evaluation_order,
122			n_vars,
123			tensor_query,
124			multilinears,
125			evaluators,
126			nontrivial_evaluation_points,
127		)
128	}
129
130	fn sumcheck_fold_multilinears<P, M>(
131		&self,
132		evaluation_order: EvaluationOrder,
133		n_vars: usize,
134		multilinears: &mut [SumcheckMultilinear<P, M>],
135		challenge: P::Scalar,
136		tensor_query: Option<MultilinearQueryRef<P>>,
137	) -> Result<bool, Error>
138	where
139		P: PackedField,
140		M: MultilinearPoly<P> + Send + Sync,
141	{
142		T::sumcheck_fold_multilinears(
143			self,
144			evaluation_order,
145			n_vars,
146			multilinears,
147			challenge,
148			tensor_query,
149		)
150	}
151
152	fn evaluate_partial_high<P: PackedField>(
153		&self,
154		multilinear: &impl MultilinearPoly<P>,
155		query_expansion: MultilinearQueryRef<P>,
156	) -> Result<MultilinearExtension<P>, Error> {
157		T::evaluate_partial_high(self, multilinear, query_expansion)
158	}
159}
160
161pub trait ComputationBackendExt: ComputationBackend {
162	/// Constructs a `MultilinearQuery` by performing tensor product expansion on the given `query`.
163	#[instrument(skip_all, level = "trace")]
164	fn multilinear_query<P: PackedField>(
165		&self,
166		query: &[P::Scalar],
167	) -> Result<MultilinearQuery<P, Self::Vec<P>>, Error> {
168		let tensor_product = self.tensor_product_full_query(query)?;
169		Ok(MultilinearQuery::with_expansion(query.len(), tensor_product)?)
170	}
171}
172
173impl<Backend> ComputationBackendExt for Backend where Backend: ComputationBackend {}