// binius_core/protocols/sumcheck/prove/common.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	packed::{get_packed_slice, packed_from_fn_with_offset},
5	PackedField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::EvaluationOrder;
9use binius_maybe_rayon::prelude::*;
10use tracing::instrument;
11
12#[instrument(skip_all, level = "debug")]
13pub fn fold_partial_eq_ind<P, Backend>(
14	evaluation_order: EvaluationOrder,
15	n_vars: usize,
16	partial_eq_ind_evals: &mut Backend::Vec<P>,
17) where
18	P: PackedField,
19	Backend: ComputationBackend,
20{
21	debug_assert_eq!(1 << n_vars.saturating_sub(P::LOG_WIDTH), partial_eq_ind_evals.len());
22
23	if n_vars == 0 {
24		return;
25	}
26
27	if partial_eq_ind_evals.len() == 1 {
28		let only_packed = partial_eq_ind_evals.first().expect("len == 1");
29
30		let mut folded = P::zero();
31		for i in 0..1 << (n_vars - 1) {
32			folded.set(
33				i,
34				match evaluation_order {
35					EvaluationOrder::LowToHigh => {
36						only_packed.get(i << 1) + only_packed.get(i << 1 | 1)
37					}
38					EvaluationOrder::HighToLow => {
39						only_packed.get(i) + only_packed.get(i | 1 << (n_vars - 1))
40					}
41				},
42			);
43		}
44
45		*partial_eq_ind_evals.first_mut().expect("len == 1") = folded;
46	} else {
47		let new_packed_len = partial_eq_ind_evals.len() >> 1;
48		let updated_evals = match evaluation_order {
49			EvaluationOrder::LowToHigh => (0..new_packed_len)
50				.into_par_iter()
51				.map(|i| {
52					packed_from_fn_with_offset(i, |index| {
53						let eval0 = get_packed_slice(&*partial_eq_ind_evals, index << 1);
54						let eval1 = get_packed_slice(&*partial_eq_ind_evals, index << 1 | 1);
55						eval0 + eval1
56					})
57				})
58				.collect(),
59
60			EvaluationOrder::HighToLow => {
61				// REVIEW: make this inplace, by enabling truncation in Backend::Vec
62				let (evals_0, evals_1) = partial_eq_ind_evals.split_at(new_packed_len);
63
64				(evals_0, evals_1)
65					.into_par_iter()
66					.map(|(&eval_0, &eval_1)| eval_0 + eval_1)
67					.collect()
68			}
69		};
70
71		*partial_eq_ind_evals = Backend::to_hal_slice(updated_evals);
72	}
73}