1use std::{
4 fmt::Debug,
5 ops::{Deref, DerefMut},
6};
7
8use binius_field::{Field, PackedExtension, PackedField};
9use binius_math::{
10 CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQuery,
11 MultilinearQueryRef,
12};
13use binius_maybe_rayon::iter::FromParallelIterator;
14use tracing::instrument;
15
16use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
17
18pub trait HalSlice<P: Debug + Send + Sync>:
20 Deref<Target = [P]>
21 + DerefMut<Target = [P]>
22 + Debug
23 + FromIterator<P>
24 + FromParallelIterator<P>
25 + Send
26 + Sync
27 + 'static
28{
29}
30
31impl<P: Send + Sync + Debug + 'static> HalSlice<P> for Vec<P> {}
32
33pub trait ComputationBackend: Send + Sync + Debug {
35 type Vec<P: Send + Sync + Debug + 'static>: HalSlice<P>;
36
37 fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P>;
39
40 fn tensor_product_full_query<P: PackedField>(
42 &self,
43 query: &[P::Scalar],
44 ) -> Result<Self::Vec<P>, Error>;
45
46 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
48 &self,
49 evaluation_order: EvaluationOrder,
50 n_vars: usize,
51 tensor_query: Option<MultilinearQueryRef<P>>,
52 multilinears: &[SumcheckMultilinear<P, M>],
53 evaluators: &[Evaluator],
54 nontrivial_evaluation_points: &[FDomain],
55 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
56 where
57 FDomain: Field,
58 P: PackedExtension<FDomain>,
59 M: MultilinearPoly<P> + Send + Sync,
60 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
61 Composition: CompositionPoly<P>;
62
63 fn sumcheck_fold_multilinears<P, M>(
65 &self,
66 evaluation_order: EvaluationOrder,
67 n_vars: usize,
68 multilinears: &mut [SumcheckMultilinear<P, M>],
69 challenge: P::Scalar,
70 tensor_query: Option<MultilinearQueryRef<P>>,
71 ) -> Result<bool, Error>
72 where
73 P: PackedField,
74 M: MultilinearPoly<P> + Send + Sync;
75
76 fn evaluate_partial_high<P: PackedField>(
78 &self,
79 multilinear: &impl MultilinearPoly<P>,
80 query_expansion: MultilinearQueryRef<P>,
81 ) -> Result<MultilinearExtension<P>, Error>;
82}
83
84impl<'a, T: 'a + ComputationBackend> ComputationBackend for &'a T
87where
88 &'a T: Debug + Sync + Send,
89{
90 type Vec<P: Send + Sync + Debug + 'static> = T::Vec<P>;
91
92 fn to_hal_slice<P: Debug + Send + Sync>(v: Vec<P>) -> Self::Vec<P> {
93 T::to_hal_slice(v)
94 }
95
96 fn tensor_product_full_query<P: PackedField>(
97 &self,
98 query: &[P::Scalar],
99 ) -> Result<Self::Vec<P>, Error> {
100 T::tensor_product_full_query(self, query)
101 }
102
103 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
104 &self,
105 evaluation_order: EvaluationOrder,
106 n_vars: usize,
107 tensor_query: Option<MultilinearQueryRef<P>>,
108 multilinears: &[SumcheckMultilinear<P, M>],
109 evaluators: &[Evaluator],
110 nontrivial_evaluation_points: &[FDomain],
111 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
112 where
113 FDomain: Field,
114 P: PackedExtension<FDomain>,
115 M: MultilinearPoly<P> + Send + Sync,
116 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
117 Composition: CompositionPoly<P>,
118 {
119 T::sumcheck_compute_round_evals(
120 self,
121 evaluation_order,
122 n_vars,
123 tensor_query,
124 multilinears,
125 evaluators,
126 nontrivial_evaluation_points,
127 )
128 }
129
130 fn sumcheck_fold_multilinears<P, M>(
131 &self,
132 evaluation_order: EvaluationOrder,
133 n_vars: usize,
134 multilinears: &mut [SumcheckMultilinear<P, M>],
135 challenge: P::Scalar,
136 tensor_query: Option<MultilinearQueryRef<P>>,
137 ) -> Result<bool, Error>
138 where
139 P: PackedField,
140 M: MultilinearPoly<P> + Send + Sync,
141 {
142 T::sumcheck_fold_multilinears(
143 self,
144 evaluation_order,
145 n_vars,
146 multilinears,
147 challenge,
148 tensor_query,
149 )
150 }
151
152 fn evaluate_partial_high<P: PackedField>(
153 &self,
154 multilinear: &impl MultilinearPoly<P>,
155 query_expansion: MultilinearQueryRef<P>,
156 ) -> Result<MultilinearExtension<P>, Error> {
157 T::evaluate_partial_high(self, multilinear, query_expansion)
158 }
159}
160
161pub trait ComputationBackendExt: ComputationBackend {
162 #[instrument(skip_all, level = "trace")]
164 fn multilinear_query<P: PackedField>(
165 &self,
166 query: &[P::Scalar],
167 ) -> Result<MultilinearQuery<P, Self::Vec<P>>, Error> {
168 let tensor_product = self.tensor_product_full_query(query)?;
169 Ok(MultilinearQuery::with_expansion(query.len(), tensor_product)?)
170 }
171}
172
173impl<Backend> ComputationBackendExt for Backend where Backend: ComputationBackend {}