1use std::{
4 fmt::Debug,
5 ops::{Deref, DerefMut},
6};
7
8use binius_field::{Field, PackedExtension, PackedField};
9use binius_math::{
10 CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQuery,
11 MultilinearQueryRef,
12};
13use binius_maybe_rayon::iter::FromParallelIterator;
14use tracing::instrument;
15
16use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
17
18pub trait HalSlice<P: Debug + Send + Sync>:
20 Deref<Target = [P]>
21 + DerefMut<Target = [P]>
22 + Debug
23 + FromIterator<P>
24 + FromParallelIterator<P>
25 + Send
26 + Sync
27 + 'static
28{
29}
30
31impl<P: Send + Sync + Debug + 'static> HalSlice<P> for Vec<P> {}
32
33pub trait ComputationBackend: Send + Sync + Debug {
35 type Vec<P: Send + Sync + Debug + 'static>: HalSlice<P>;
36
37 fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P>;
39
40 fn tensor_product_full_query<P: PackedField>(
42 &self,
43 query: &[P::Scalar],
44 ) -> Result<Self::Vec<P>, Error>;
45
46 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
48 &self,
49 evaluation_order: EvaluationOrder,
50 n_vars: usize,
51 tensor_query: Option<MultilinearQueryRef<P>>,
52 multilinears: &[SumcheckMultilinear<P, M>],
53 evaluators: &[Evaluator],
54 nontrivial_evaluation_points: &[FDomain],
55 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
56 where
57 FDomain: Field,
58 P: PackedExtension<FDomain>,
59 M: MultilinearPoly<P> + Send + Sync,
60 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
61 Composition: CompositionPoly<P>;
62
63 fn evaluate_partial_high<P: PackedField>(
65 &self,
66 multilinear: &impl MultilinearPoly<P>,
67 query_expansion: MultilinearQueryRef<P>,
68 ) -> Result<MultilinearExtension<P>, Error>;
69}
70
71impl<'a, T: 'a + ComputationBackend> ComputationBackend for &'a T
74where
75 &'a T: Debug + Sync + Send,
76{
77 type Vec<P: Send + Sync + Debug + 'static> = T::Vec<P>;
78
79 fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P> {
80 T::to_hal_slice(v)
81 }
82
83 fn tensor_product_full_query<P: PackedField>(
84 &self,
85 query: &[P::Scalar],
86 ) -> Result<Self::Vec<P>, Error> {
87 T::tensor_product_full_query(self, query)
88 }
89
90 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
91 &self,
92 evaluation_order: EvaluationOrder,
93 n_vars: usize,
94 tensor_query: Option<MultilinearQueryRef<P>>,
95 multilinears: &[SumcheckMultilinear<P, M>],
96 evaluators: &[Evaluator],
97 nontrivial_evaluation_points: &[FDomain],
98 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
99 where
100 FDomain: Field,
101 P: PackedExtension<FDomain>,
102 M: MultilinearPoly<P> + Send + Sync,
103 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
104 Composition: CompositionPoly<P>,
105 {
106 T::sumcheck_compute_round_evals(
107 self,
108 evaluation_order,
109 n_vars,
110 tensor_query,
111 multilinears,
112 evaluators,
113 nontrivial_evaluation_points,
114 )
115 }
116
117 fn evaluate_partial_high<P: PackedField>(
118 &self,
119 multilinear: &impl MultilinearPoly<P>,
120 query_expansion: MultilinearQueryRef<P>,
121 ) -> Result<MultilinearExtension<P>, Error> {
122 T::evaluate_partial_high(self, multilinear, query_expansion)
123 }
124}
125
126pub trait ComputationBackendExt: ComputationBackend {
127 #[instrument(skip_all, level = "trace")]
129 fn multilinear_query<P: PackedField>(
130 &self,
131 query: &[P::Scalar],
132 ) -> Result<MultilinearQuery<P, Self::Vec<P>>, Error> {
133 let tensor_product = self.tensor_product_full_query(query)?;
134 Ok(MultilinearQuery::with_expansion(query.len(), tensor_product)?)
135 }
136}
137
138impl<Backend> ComputationBackendExt for Backend where Backend: ComputationBackend {}