binius_hal/
backend.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::Debug,
5	ops::{Deref, DerefMut},
6};
7
8use binius_field::{Field, PackedExtension, PackedField};
9use binius_math::{
10	CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQuery,
11	MultilinearQueryRef,
12};
13use binius_maybe_rayon::iter::FromParallelIterator;
14use tracing::instrument;
15
16use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
17
18/// HAL-managed memory containing the result of its operations.
19pub trait HalSlice<P: Debug + Send + Sync>:
20	Deref<Target = [P]>
21	+ DerefMut<Target = [P]>
22	+ Debug
23	+ FromIterator<P>
24	+ FromParallelIterator<P>
25	+ Send
26	+ Sync
27	+ 'static
28{
29}
30
31impl<P: Send + Sync + Debug + 'static> HalSlice<P> for Vec<P> {}
32
33/// An abstraction to interface with acceleration hardware to perform computation intensive
34/// operations.
35pub trait ComputationBackend: Send + Sync + Debug {
36	type Vec<P: Send + Sync + Debug + 'static>: HalSlice<P>;
37
38	/// Creates `Self::Vec<P>` from the given `Vec<P>`.
39	fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P>;
40
41	/// Computes tensor product expansion.
42	fn tensor_product_full_query<P: PackedField>(
43		&self,
44		query: &[P::Scalar],
45	) -> Result<Self::Vec<P>, Error>;
46
47	/// Calculate the accumulated evaluations for an arbitrary round of zerocheck.
48	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
49		&self,
50		evaluation_order: EvaluationOrder,
51		n_vars: usize,
52		tensor_query: Option<MultilinearQueryRef<P>>,
53		multilinears: &[SumcheckMultilinear<P, M>],
54		evaluators: &[Evaluator],
55		nontrivial_evaluation_points: &[FDomain],
56	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
57	where
58		FDomain: Field,
59		P: PackedExtension<FDomain>,
60		M: MultilinearPoly<P> + Send + Sync,
61		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
62		Composition: CompositionPoly<P>;
63
64	/// Sumcheck round
65	fn sumcheck_fold_multilinears<P, M>(
66		&self,
67		evaluation_order: EvaluationOrder,
68		n_vars: usize,
69		multilinears: &mut [SumcheckMultilinear<P, M>],
70		challenge: P::Scalar,
71		tensor_query: Option<MultilinearQueryRef<P>>,
72	) -> Result<bool, Error>
73	where
74		P: PackedField,
75		M: MultilinearPoly<P> + Send + Sync;
76
77	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
78	fn evaluate_partial_high<P: PackedField>(
79		&self,
80		multilinear: &impl MultilinearPoly<P>,
81		query_expansion: MultilinearQueryRef<P>,
82	) -> Result<MultilinearExtension<P>, Error>;
83}
84
85/// Makes it unnecessary to clone backends.
86/// Can't use `auto_impl` because of the complex associated type.
87impl<'a, T: 'a + ComputationBackend> ComputationBackend for &'a T
88where
89	&'a T: Debug + Sync + Send,
90{
91	type Vec<P: Send + Sync + Debug + 'static> = T::Vec<P>;
92
93	fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P> {
94		T::to_hal_slice(v)
95	}
96
97	fn tensor_product_full_query<P: PackedField>(
98		&self,
99		query: &[P::Scalar],
100	) -> Result<Self::Vec<P>, Error> {
101		T::tensor_product_full_query(self, query)
102	}
103
104	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
105		&self,
106		evaluation_order: EvaluationOrder,
107		n_vars: usize,
108		tensor_query: Option<MultilinearQueryRef<P>>,
109		multilinears: &[SumcheckMultilinear<P, M>],
110		evaluators: &[Evaluator],
111		nontrivial_evaluation_points: &[FDomain],
112	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
113	where
114		FDomain: Field,
115		P: PackedExtension<FDomain>,
116		M: MultilinearPoly<P> + Send + Sync,
117		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
118		Composition: CompositionPoly<P>,
119	{
120		T::sumcheck_compute_round_evals(
121			self,
122			evaluation_order,
123			n_vars,
124			tensor_query,
125			multilinears,
126			evaluators,
127			nontrivial_evaluation_points,
128		)
129	}
130
131	fn sumcheck_fold_multilinears<P, M>(
132		&self,
133		evaluation_order: EvaluationOrder,
134		n_vars: usize,
135		multilinears: &mut [SumcheckMultilinear<P, M>],
136		challenge: P::Scalar,
137		tensor_query: Option<MultilinearQueryRef<P>>,
138	) -> Result<bool, Error>
139	where
140		P: PackedField,
141		M: MultilinearPoly<P> + Send + Sync,
142	{
143		T::sumcheck_fold_multilinears(
144			self,
145			evaluation_order,
146			n_vars,
147			multilinears,
148			challenge,
149			tensor_query,
150		)
151	}
152
153	fn evaluate_partial_high<P: PackedField>(
154		&self,
155		multilinear: &impl MultilinearPoly<P>,
156		query_expansion: MultilinearQueryRef<P>,
157	) -> Result<MultilinearExtension<P>, Error> {
158		T::evaluate_partial_high(self, multilinear, query_expansion)
159	}
160}
161
162pub trait ComputationBackendExt: ComputationBackend {
163	/// Constructs a `MultilinearQuery` by performing tensor product expansion on the given `query`.
164	#[instrument(skip_all, level = "trace")]
165	fn multilinear_query<P: PackedField>(
166		&self,
167		query: &[P::Scalar],
168	) -> Result<MultilinearQuery<P, Self::Vec<P>>, Error> {
169		let tensor_product = self.tensor_product_full_query(query)?;
170		Ok(MultilinearQuery::with_expansion(query.len(), tensor_product)?)
171	}
172}
173
174impl<Backend> ComputationBackendExt for Backend where Backend: ComputationBackend {}