binius_hal/
cpu.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::{Field, PackedExtension, PackedField};
6use binius_math::{
7	eq_ind_partial_eval, CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly,
8	MultilinearQueryRef,
9};
10use tracing::instrument;
11
12use crate::{
13	sumcheck_round_calculator::calculate_round_evals, ComputationBackend, Error, RoundEvals,
14	SumcheckEvaluator, SumcheckMultilinear,
15};
16
/// Implementation of [`ComputationBackend`] for the default backend, which uses the CPU for all
/// computations.
///
/// This is a stateless unit struct; obtain one with `CpuBackend` directly or via
/// `make_portable_backend`.
#[derive(Clone, Debug)]
pub struct CpuBackend;

21pub const fn make_portable_backend() -> CpuBackend {
22	CpuBackend
23}
24
25impl ComputationBackend for CpuBackend {
26	type Vec<P: Send + Sync + Debug + 'static> = Vec<P>;
27
28	fn to_hal_slice<P: Debug + Send + Sync + 'static>(v: Vec<P>) -> Self::Vec<P> {
29		v
30	}
31
32	#[instrument(skip_all, level = "trace")]
33	fn tensor_product_full_query<P: PackedField>(
34		&self,
35		query: &[P::Scalar],
36	) -> Result<Self::Vec<P>, Error> {
37		Ok(eq_ind_partial_eval(query))
38	}
39
40	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
41		&self,
42		evaluation_order: EvaluationOrder,
43		n_vars: usize,
44		tensor_query: Option<MultilinearQueryRef<P>>,
45		multilinears: &[SumcheckMultilinear<P, M>],
46		evaluators: &[Evaluator],
47		nontrivial_evaluation_points: &[FDomain],
48	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
49	where
50		FDomain: Field,
51		P: PackedExtension<FDomain>,
52		M: MultilinearPoly<P> + Send + Sync,
53		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
54		Composition: CompositionPoly<P>,
55	{
56		calculate_round_evals(
57			evaluation_order,
58			n_vars,
59			tensor_query,
60			multilinears,
61			evaluators,
62			nontrivial_evaluation_points,
63		)
64	}
65
66	#[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")]
67	fn evaluate_partial_high<P: PackedField>(
68		&self,
69		multilinear: &impl MultilinearPoly<P>,
70		query_expansion: MultilinearQueryRef<P>,
71	) -> Result<MultilinearExtension<P>, Error> {
72		Ok(multilinear.evaluate_partial_high(query_expansion)?)
73	}
74}