1use std::fmt::Debug;
4
5use binius_field::{Field, PackedExtension, PackedField};
6use binius_math::{
7 eq_ind_partial_eval, CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly,
8 MultilinearQueryRef,
9};
10use tracing::instrument;
11
12use crate::{
13 sumcheck_folding::fold_multilinears, sumcheck_round_calculation::calculate_round_evals,
14 ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
15};
16
/// CPU implementation of [`ComputationBackend`].
///
/// This is a stateless, zero-sized unit struct. Copying or default-constructing it costs
/// nothing, so it derives `Copy` and `Default` in addition to `Clone` and `Debug`.
#[derive(Clone, Copy, Debug, Default)]
pub struct CpuBackend;
20
21pub const fn make_portable_backend() -> CpuBackend {
22 CpuBackend
23}
24
25impl ComputationBackend for CpuBackend {
26 type Vec<P: Send + Sync + Debug + 'static> = Vec<P>;
27
28 fn to_hal_slice<P: Debug + Send + Sync + 'static>(v: Vec<P>) -> Self::Vec<P> {
29 v
30 }
31
32 #[instrument(skip_all, level = "trace")]
33 fn tensor_product_full_query<P: PackedField>(
34 &self,
35 query: &[P::Scalar],
36 ) -> Result<Self::Vec<P>, Error> {
37 Ok(eq_ind_partial_eval(query))
38 }
39
40 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
41 &self,
42 evaluation_order: EvaluationOrder,
43 n_vars: usize,
44 tensor_query: Option<MultilinearQueryRef<P>>,
45 multilinears: &[SumcheckMultilinear<P, M>],
46 evaluators: &[Evaluator],
47 nontrivial_evaluation_points: &[FDomain],
48 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
49 where
50 FDomain: Field,
51 P: PackedExtension<FDomain>,
52 M: MultilinearPoly<P> + Send + Sync,
53 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
54 Composition: CompositionPoly<P>,
55 {
56 calculate_round_evals(
57 evaluation_order,
58 n_vars,
59 tensor_query,
60 multilinears,
61 evaluators,
62 nontrivial_evaluation_points,
63 )
64 }
65
66 fn sumcheck_fold_multilinears<P, M>(
67 &self,
68 evaluation_order: EvaluationOrder,
69 n_vars: usize,
70 multilinears: &mut [SumcheckMultilinear<P, M>],
71 challenge: P::Scalar,
72 tensor_query: Option<MultilinearQueryRef<P>>,
73 ) -> Result<bool, Error>
74 where
75 P: PackedField,
76 M: MultilinearPoly<P> + Send + Sync,
77 {
78 fold_multilinears(evaluation_order, n_vars, multilinears, challenge, tensor_query)
79 }
80
81 #[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")]
82 fn evaluate_partial_high<P: PackedField>(
83 &self,
84 multilinear: &impl MultilinearPoly<P>,
85 query_expansion: MultilinearQueryRef<P>,
86 ) -> Result<MultilinearExtension<P>, Error> {
87 Ok(multilinear.evaluate_partial_high(query_expansion)?)
88 }
89}