binius_hal/
cpu.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::fmt::Debug;
4
5use binius_field::{Field, PackedExtension, PackedField};
6use binius_math::{
7	eq_ind_partial_eval, CompositionPoly, EvaluationOrder, MultilinearExtension, MultilinearPoly,
8	MultilinearQueryRef,
9};
10use tracing::instrument;
11
12use crate::{
13	sumcheck_folding::fold_multilinears, sumcheck_round_calculation::calculate_round_evals,
14	ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear,
15};
16
/// The default [`ComputationBackend`]: a stateless, CPU-only backend.
///
/// Every backend operation is carried out on the host CPU; no accelerator is involved.
#[derive(Debug, Clone)]
pub struct CpuBackend;
20
21pub const fn make_portable_backend() -> CpuBackend {
22	CpuBackend
23}
24
25impl ComputationBackend for CpuBackend {
26	type Vec<P: Send + Sync + Debug + 'static> = Vec<P>;
27
28	fn to_hal_slice<P: Debug + Send + Sync + 'static>(v: Vec<P>) -> Self::Vec<P> {
29		v
30	}
31
32	#[instrument(skip_all, level = "trace")]
33	fn tensor_product_full_query<P: PackedField>(
34		&self,
35		query: &[P::Scalar],
36	) -> Result<Self::Vec<P>, Error> {
37		Ok(eq_ind_partial_eval(query))
38	}
39
40	fn sumcheck_compute_round_evals<FDomain, P, M, Evaluator, Composition>(
41		&self,
42		evaluation_order: EvaluationOrder,
43		n_vars: usize,
44		tensor_query: Option<MultilinearQueryRef<P>>,
45		multilinears: &[SumcheckMultilinear<P, M>],
46		evaluators: &[Evaluator],
47		nontrivial_evaluation_points: &[FDomain],
48	) -> Result<Vec<RoundEvals<P::Scalar>>, Error>
49	where
50		FDomain: Field,
51		P: PackedExtension<FDomain>,
52		M: MultilinearPoly<P> + Send + Sync,
53		Evaluator: SumcheckEvaluator<P, Composition> + Sync,
54		Composition: CompositionPoly<P>,
55	{
56		calculate_round_evals(
57			evaluation_order,
58			n_vars,
59			tensor_query,
60			multilinears,
61			evaluators,
62			nontrivial_evaluation_points,
63		)
64	}
65
66	fn sumcheck_fold_multilinears<P, M>(
67		&self,
68		evaluation_order: EvaluationOrder,
69		n_vars: usize,
70		multilinears: &mut [SumcheckMultilinear<P, M>],
71		challenge: P::Scalar,
72		tensor_query: Option<MultilinearQueryRef<P>>,
73	) -> Result<bool, Error>
74	where
75		P: PackedField,
76		M: MultilinearPoly<P> + Send + Sync,
77	{
78		fold_multilinears(evaluation_order, n_vars, multilinears, challenge, tensor_query)
79	}
80
81	#[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")]
82	fn evaluate_partial_high<P: PackedField>(
83		&self,
84		multilinear: &impl MultilinearPoly<P>,
85		query_expansion: MultilinearQueryRef<P>,
86	) -> Result<MultilinearExtension<P>, Error> {
87		Ok(multilinear.evaluate_partial_high(query_expansion)?)
88	}
89}