binius_core/protocols/sumcheck/prove/
common.rs1use binius_field::{
4 packed::{get_packed_slice, packed_from_fn_with_offset},
5 PackedField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::EvaluationOrder;
9use binius_maybe_rayon::prelude::*;
10use tracing::instrument;
11
12#[instrument(skip_all, level = "debug")]
13pub fn fold_partial_eq_ind<P, Backend>(
14 evaluation_order: EvaluationOrder,
15 n_vars: usize,
16 partial_eq_ind_evals: &mut Backend::Vec<P>,
17) where
18 P: PackedField,
19 Backend: ComputationBackend,
20{
21 debug_assert_eq!(1 << n_vars.saturating_sub(P::LOG_WIDTH), partial_eq_ind_evals.len());
22
23 if n_vars == 0 {
24 return;
25 }
26
27 if partial_eq_ind_evals.len() == 1 {
28 let only_packed = partial_eq_ind_evals.first().expect("len == 1");
29
30 let mut folded = P::zero();
31 for i in 0..1 << (n_vars - 1) {
32 folded.set(
33 i,
34 match evaluation_order {
35 EvaluationOrder::LowToHigh => {
36 only_packed.get(i << 1) + only_packed.get(i << 1 | 1)
37 }
38 EvaluationOrder::HighToLow => {
39 only_packed.get(i) + only_packed.get(i | 1 << (n_vars - 1))
40 }
41 },
42 );
43 }
44
45 *partial_eq_ind_evals.first_mut().expect("len == 1") = folded;
46 } else {
47 let new_packed_len = partial_eq_ind_evals.len() >> 1;
48 let updated_evals = match evaluation_order {
49 EvaluationOrder::LowToHigh => (0..new_packed_len)
50 .into_par_iter()
51 .map(|i| {
52 packed_from_fn_with_offset(i, |index| {
53 let eval0 = get_packed_slice(&*partial_eq_ind_evals, index << 1);
54 let eval1 = get_packed_slice(&*partial_eq_ind_evals, index << 1 | 1);
55 eval0 + eval1
56 })
57 })
58 .collect(),
59
60 EvaluationOrder::HighToLow => {
61 let (evals_0, evals_1) = partial_eq_ind_evals.split_at(new_packed_len);
63
64 (evals_0, evals_1)
65 .into_par_iter()
66 .map(|(&eval_0, &eval_1)| eval_0 + eval_1)
67 .collect()
68 }
69 };
70
71 *partial_eq_ind_evals = Backend::to_hal_slice(updated_evals);
72 }
73}