1use std::fmt::Debug;
4
5use binius_field::{Field, PackedExtension, PackedField};
6use binius_math::{
7 CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQueryRef,
8 eq_ind_partial_eval,
9};
10use tracing::instrument;
11
12use crate::{
13 ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
14 sumcheck_folding::fold_multilinears, sumcheck_round_calculation::calculate_round_evals,
15};
16
/// A [`ComputationBackend`] that runs every operation on the CPU.
///
/// Backend vectors are plain `Vec`s, and each operation is forwarded to the portable
/// implementations in this crate and in `binius_math`.
#[derive(Debug, Clone)]
pub struct CpuBackend;
21
22pub const fn make_portable_backend() -> CpuBackend {
23 CpuBackend
24}
25
26impl ComputationBackend for CpuBackend {
27 type Vec<P: Send + Sync + Debug + 'static> = Vec<P>;
28
29 fn to_hal_slice<P: Debug + Send + Sync + 'static>(v: Vec<P>) -> Self::Vec<P> {
30 v
31 }
32
33 #[instrument(skip_all, level = "trace")]
34 fn tensor_product_full_query<P: PackedField>(
35 &self,
36 query: &[P::Scalar],
37 ) -> Result<Self::Vec<P>, Error> {
38 Ok(eq_ind_partial_eval(query))
39 }
40
41 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
42 &self,
43 evaluation_order: EvaluationOrder,
44 n_vars: usize,
45 tensor_query: Option<MultilinearQueryRef<P>>,
46 multilinears: &[SumcheckMultilinear<P, M>],
47 evaluators: &[Evaluator],
48 nontrivial_evaluation_points: &[FDomain],
49 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
50 where
51 FDomain: Field,
52 P: PackedExtension<FDomain>,
53 M: MultilinearPoly<P> + Send + Sync,
54 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
55 Composition: CompositionPoly<P>,
56 {
57 calculate_round_evals(
58 evaluation_order,
59 n_vars,
60 tensor_query,
61 multilinears,
62 evaluators,
63 nontrivial_evaluation_points,
64 )
65 }
66
67 fn sumcheck_fold_multilinears<P, M>(
68 &self,
69 evaluation_order: EvaluationOrder,
70 n_vars: usize,
71 multilinears: &mut [SumcheckMultilinear<P, M>],
72 challenge: P::Scalar,
73 tensor_query: Option<MultilinearQueryRef<P>>,
74 ) -> Result<bool, Error>
75 where
76 P: PackedField,
77 M: MultilinearPoly<P> + Send + Sync,
78 {
79 fold_multilinears(evaluation_order, n_vars, multilinears, challenge, tensor_query)
80 }
81
82 #[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")]
83 fn evaluate_partial_high<P: PackedField>(
84 &self,
85 multilinear: &impl MultilinearPoly<P>,
86 query_expansion: MultilinearQueryRef<P>,
87 ) -> Result<MultilinearExtension<P>, Error> {
88 Ok(multilinear.evaluate_partial_high(query_expansion)?)
89 }
90}