1use std::{
4 fmt::Debug,
5 ops::{Deref, DerefMut},
6};
7
8use binius_field::{Field, PackedExtension, PackedField};
9use binius_math::{
10 CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQuery,
11 MultilinearQueryRef,
12};
13use binius_maybe_rayon::iter::FromParallelIterator;
14use tracing::instrument;
15
16use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
17
18pub trait HalSlice<P: Debug + Send + Sync>:
20 Deref<Target = [P]>
21 + DerefMut<Target = [P]>
22 + Debug
23 + FromIterator<P>
24 + FromParallelIterator<P>
25 + Send
26 + Sync
27 + 'static
28{
29}
30
31impl<P: Send + Sync + Debug + 'static> HalSlice<P> for Vec<P> {}
32
33pub trait ComputationBackend: Send + Sync + Debug {
36 type Vec<P: Send + Sync + Debug + 'static>: HalSlice<P>;
37
38 fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P>;
40
41 fn tensor_product_full_query<P: PackedField>(
43 &self,
44 query: &[P::Scalar],
45 ) -> Result<Self::Vec<P>, Error>;
46
47 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
49 &self,
50 evaluation_order: EvaluationOrder,
51 n_vars: usize,
52 tensor_query: Option<MultilinearQueryRef<P>>,
53 multilinears: &[SumcheckMultilinear<P, M>],
54 evaluators: &[Evaluator],
55 nontrivial_evaluation_points: &[FDomain],
56 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
57 where
58 FDomain: Field,
59 P: PackedExtension<FDomain>,
60 M: MultilinearPoly<P> + Send + Sync,
61 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
62 Composition: CompositionPoly<P>;
63
64 fn sumcheck_fold_multilinears<P, M>(
66 &self,
67 evaluation_order: EvaluationOrder,
68 n_vars: usize,
69 multilinears: &mut [SumcheckMultilinear<P, M>],
70 challenge: P::Scalar,
71 tensor_query: Option<MultilinearQueryRef<P>>,
72 ) -> Result<bool, Error>
73 where
74 P: PackedField,
75 M: MultilinearPoly<P> + Send + Sync;
76
77 fn evaluate_partial_high<P: PackedField>(
79 &self,
80 multilinear: &impl MultilinearPoly<P>,
81 query_expansion: MultilinearQueryRef<P>,
82 ) -> Result<MultilinearExtension<P>, Error>;
83}
84
85impl<'a, T: 'a + ComputationBackend> ComputationBackend for &'a T
88where
89 &'a T: Debug + Sync + Send,
90{
91 type Vec<P: Send + Sync + Debug + 'static> = T::Vec<P>;
92
93 fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P> {
94 T::to_hal_slice(v)
95 }
96
97 fn tensor_product_full_query<P: PackedField>(
98 &self,
99 query: &[P::Scalar],
100 ) -> Result<Self::Vec<P>, Error> {
101 T::tensor_product_full_query(self, query)
102 }
103
104 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
105 &self,
106 evaluation_order: EvaluationOrder,
107 n_vars: usize,
108 tensor_query: Option<MultilinearQueryRef<P>>,
109 multilinears: &[SumcheckMultilinear<P, M>],
110 evaluators: &[Evaluator],
111 nontrivial_evaluation_points: &[FDomain],
112 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
113 where
114 FDomain: Field,
115 P: PackedExtension<FDomain>,
116 M: MultilinearPoly<P> + Send + Sync,
117 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
118 Composition: CompositionPoly<P>,
119 {
120 T::sumcheck_compute_round_evals(
121 self,
122 evaluation_order,
123 n_vars,
124 tensor_query,
125 multilinears,
126 evaluators,
127 nontrivial_evaluation_points,
128 )
129 }
130
131 fn sumcheck_fold_multilinears<P, M>(
132 &self,
133 evaluation_order: EvaluationOrder,
134 n_vars: usize,
135 multilinears: &mut [SumcheckMultilinear<P, M>],
136 challenge: P::Scalar,
137 tensor_query: Option<MultilinearQueryRef<P>>,
138 ) -> Result<bool, Error>
139 where
140 P: PackedField,
141 M: MultilinearPoly<P> + Send + Sync,
142 {
143 T::sumcheck_fold_multilinears(
144 self,
145 evaluation_order,
146 n_vars,
147 multilinears,
148 challenge,
149 tensor_query,
150 )
151 }
152
153 fn evaluate_partial_high<P: PackedField>(
154 &self,
155 multilinear: &impl MultilinearPoly<P>,
156 query_expansion: MultilinearQueryRef<P>,
157 ) -> Result<MultilinearExtension<P>, Error> {
158 T::evaluate_partial_high(self, multilinear, query_expansion)
159 }
160}
161
162pub trait ComputationBackendExt: ComputationBackend {
163 #[instrument(skip_all, level = "trace")]
165 fn multilinear_query<P: PackedField>(
166 &self,
167 query: &[P::Scalar],
168 ) -> Result<MultilinearQuery<P, Self::Vec<P>>, Error> {
169 let tensor_product = self.tensor_product_full_query(query)?;
170 Ok(MultilinearQuery::with_expansion(query.len(), tensor_product)?)
171 }
172}
173
174impl<Backend> ComputationBackendExt for Backend where Backend: ComputationBackend {}