binius_hal/
cpu.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::{Field, PackedExtension, PackedField};
6use binius_math::{
7	CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly, MultilinearQueryRef,
8	eq_ind_partial_eval,
9};
10use tracing::instrument;
11
12use crate::{
13	ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
14	sumcheck_folding::fold_multilinears, sumcheck_round_calculation::calculate_round_evals,
15};
16
/// Implementation of [`ComputationBackend`] for the default backend, which uses the CPU for all
/// computations.
///
/// The type is a stateless unit struct: [`make_portable_backend`] and [`Default::default`] both
/// yield the same (zero-sized) value.
#[derive(Clone, Debug, Default)]
pub struct CpuBackend;
21
22pub const fn make_portable_backend() -> CpuBackend {
23	CpuBackend
24}
25
26impl ComputationBackend for CpuBackend {
27	type Vec<P: Send + Sync + Debug + 'static> = Vec<P>;
28
29	fn to_hal_slice<P: Debug + Send + Sync + 'static>(v: Vec<P>) -> Self::Vec<P> {
30		v
31	}
32
33	#[instrument(skip_all, level = "trace")]
34	fn tensor_product_full_query<P: PackedField>(
35		&self,
36		query: &[P::Scalar],
37	) -> Result<Self::Vec<P>, Error> {
38		Ok(eq_ind_partial_eval(query))
39	}
40
41	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
42		&self,
43		evaluation_order: EvaluationOrder,
44		n_vars: usize,
45		tensor_query: Option<MultilinearQueryRef<P>>,
46		multilinears: &[SumcheckMultilinear<P, M>],
47		evaluators: &[Evaluator],
48		nontrivial_evaluation_points: &[FDomain],
49	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
50	where
51		FDomain: Field,
52		P: PackedExtension<FDomain>,
53		M: MultilinearPoly<P> + Send + Sync,
54		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
55		Composition: CompositionPoly<P>,
56	{
57		calculate_round_evals(
58			evaluation_order,
59			n_vars,
60			tensor_query,
61			multilinears,
62			evaluators,
63			nontrivial_evaluation_points,
64		)
65	}
66
67	fn sumcheck_fold_multilinears<P, M>(
68		&self,
69		evaluation_order: EvaluationOrder,
70		n_vars: usize,
71		multilinears: &mut [SumcheckMultilinear<P, M>],
72		challenge: P::Scalar,
73		tensor_query: Option<MultilinearQueryRef<P>>,
74	) -> Result<bool, Error>
75	where
76		P: PackedField,
77		M: MultilinearPoly<P> + Send + Sync,
78	{
79		fold_multilinears(evaluation_order, n_vars, multilinears, challenge, tensor_query)
80	}
81
82	#[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")]
83	fn evaluate_partial_high<P: PackedField>(
84		&self,
85		multilinear: &impl MultilinearPoly<P>,
86		query_expansion: MultilinearQueryRef<P>,
87	) -> Result<MultilinearExtension<P>, Error> {
88		Ok(multilinear.evaluate_partial_high(query_expansion)?)
89	}
90}