1use std::fmt::Debug;
4
5use binius_field::{Field, PackedExtension, PackedField};
6use binius_math::{
7 eq_ind_partial_eval, CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly,
8 MultilinearQueryRef,
9};
10use tracing::instrument;
11
12use crate::{
13 sumcheck_round_calculator::calculate_round_evals, ComputationBackend, Error, RoundEvals,
14 SumcheckEvaluator, SumcheckMultilinear,
15};
16
/// Portable [`ComputationBackend`] that performs every operation on the host CPU.
///
/// The type carries no state, so any instance is interchangeable with any other;
/// construct it with [`make_portable_backend`], `CpuBackend`, or `CpuBackend::default()`.
#[derive(Clone, Debug, Default)]
pub struct CpuBackend;
20
21pub const fn make_portable_backend() -> CpuBackend {
22 CpuBackend
23}
24
25impl ComputationBackend for CpuBackend {
26 type Vec<P: Send + Sync + Debug + 'static> = Vec<P>;
27
28 fn to_hal_slice<P: Debug + Send + Sync + 'static>(v: Vec<P>) -> Self::Vec<P> {
29 v
30 }
31
32 #[instrument(skip_all, level = "trace")]
33 fn tensor_product_full_query<P: PackedField>(
34 &self,
35 query: &[P::Scalar],
36 ) -> Result<Self::Vec<P>, Error> {
37 Ok(eq_ind_partial_eval(query))
38 }
39
40 fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
41 &self,
42 evaluation_order: EvaluationOrder,
43 n_vars: usize,
44 tensor_query: Option<MultilinearQueryRef<P>>,
45 multilinears: &[SumcheckMultilinear<P, M>],
46 evaluators: &[Evaluator],
47 nontrivial_evaluation_points: &[FDomain],
48 ) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
49 where
50 FDomain: Field,
51 P: PackedExtension<FDomain>,
52 M: MultilinearPoly<P> + Send + Sync,
53 Evaluator: SumcheckEvaluator<P, Composition> + Sync,
54 Composition: CompositionPoly<P>,
55 {
56 calculate_round_evals(
57 evaluation_order,
58 n_vars,
59 tensor_query,
60 multilinears,
61 evaluators,
62 nontrivial_evaluation_points,
63 )
64 }
65
66 #[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")]
67 fn evaluate_partial_high<P: PackedField>(
68 &self,
69 multilinear: &impl MultilinearPoly<P>,
70 query_expansion: MultilinearQueryRef<P>,
71 ) -> Result<MultilinearExtension<P>, Error> {
72 Ok(multilinear.evaluate_partial_high(query_expansion)?)
73 }
74}