// binius_hal/sumcheck_folding.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::PackedField;
4use binius_math::{
5	fold_left_lerp_inplace, fold_right_lerp, EvaluationOrder, MultilinearPoly, MultilinearQueryRef,
6};
7use binius_maybe_rayon::prelude::*;
8use binius_utils::checked_arithmetics::log2_ceil_usize;
9use bytemuck::zeroed_vec;
10
11use crate::{
12	common::{subcube_vars_for_bits, MAX_SRC_SUBCUBE_LOG_BITS},
13	Error, SumcheckMultilinear,
14};
15
16pub(crate) fn fold_multilinears<P, M>(
17	evaluation_order: EvaluationOrder,
18	n_vars: usize,
19	multilinears: &mut [SumcheckMultilinear<P, M>],
20	challenge: P::Scalar,
21	tensor_query: Option<MultilinearQueryRef<P>>,
22) -> Result<bool, Error>
23where
24	P: PackedField,
25	M: MultilinearPoly<P> + Send + Sync,
26{
27	match evaluation_order {
28		EvaluationOrder::LowToHigh => {
29			fold_multilinears_low_to_high(n_vars, multilinears, challenge, tensor_query)
30		}
31		EvaluationOrder::HighToLow => {
32			fold_multilinears_high_to_low(n_vars, multilinears, challenge, tensor_query)
33		}
34	}
35}
36
/// Folds each multilinear by one variable in low-to-high evaluation order.
///
/// Pre-switchover (`Transparent`) multilinears either decrement their switchover
/// counter, or — once it reaches zero — are partially evaluated at the expanded
/// `tensor_query` and replaced by a `Folded` variant. Post-switchover (`Folded`)
/// multilinears are lerp-folded by a single variable with `challenge`.
///
/// Returns `Ok(true)` if any multilinear is still `Transparent` after this round.
fn fold_multilinears_low_to_high<P, M>(
	n_vars: usize,
	multilinears: &mut [SumcheckMultilinear<P, M>],
	challenge: P::Scalar,
	tensor_query: Option<MultilinearQueryRef<P>>,
) -> Result<bool, Error>
where
	P: PackedField,
	M: MultilinearPoly<P> + Send + Sync,
{
	assert!(n_vars > 0);
	parallel_map(multilinears, |sumcheck_multilinear| -> Result<_, Error> {
		match sumcheck_multilinear {
			SumcheckMultilinear::Transparent {
				multilinear,
				switchover_round,
				zero_scalars_suffix,
			} => {
				if *switchover_round == 0 {
					// At switchover we partially evaluate the multilinear at an expanded tensor query.
					let tensor_query = tensor_query
						.as_ref()
						.expect("guaranteed to be Some while there is still a transparent");

					assert!(tensor_query.n_vars() > 0);

					// Per `zero_scalars_suffix`, only this many leading scalars can be nonzero.
					let nonzero_scalars_prefix = (1 << n_vars) - *zero_scalars_suffix;

					let large_field_folded_evals = if nonzero_scalars_prefix < 1 << n_vars {
						// Sparse path: evaluate subcube by subcube, sized so only the
						// potentially nonzero prefix is visited.
						let subcube_vars = subcube_vars_for_bits::<P>(
							MAX_SRC_SUBCUBE_LOG_BITS,
							log2_ceil_usize(nonzero_scalars_prefix),
							tensor_query.n_vars(),
							n_vars - 1,
						);

						// Packed field elements per subcube (at least one).
						let packed_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);

						// Scalar count remaining after evaluating out the low
						// tensor_query.n_vars() variables of the nonzero prefix.
						let folded_scalars =
							nonzero_scalars_prefix.div_ceil(1 << tensor_query.n_vars());

						let mut folded =
							zeroed_vec(folded_scalars.div_ceil(1 << subcube_vars) * packed_len);

						// REVIEW: no lerp optimization in subcube_partial_low_evals currently
						for (subcube_index, subcube_evals) in
							folded.chunks_exact_mut(packed_len).enumerate()
						{
							multilinear.subcube_partial_low_evals(
								*tensor_query,
								subcube_vars,
								subcube_index,
								subcube_evals,
							)?;
						}

						// Drop trailing all-zero packed elements so later folds operate on
						// a tightly truncated evaluation vector.
						folded.truncate(folded_scalars.div_ceil(P::WIDTH));
						folded
					} else {
						// Dense path: no zero suffix to exploit, fold in a single call.
						multilinear
							.evaluate_partial_low(*tensor_query)?
							.into_evals()
					};

					*sumcheck_multilinear = SumcheckMultilinear::Folded {
						large_field_folded_evals,
					};

					Ok(false)
				} else {
					// Not yet at switchover: stay transparent for another round.
					*switchover_round -= 1;
					Ok(true)
				}
			}

			SumcheckMultilinear::Folded {
				large_field_folded_evals: evals,
			} => {
				// Post-switchover, we perform single variable folding (linear interpolation).
				// NB: Lerp folding in low-to-high evaluation order can be made inplace, but not
				// easily so if multithreading is desired.
				let mut new_evals = zeroed_vec(evals.len().div_ceil(2));

				fold_right_lerp(
					evals.as_slice(),
					// evals is optimally truncated, upper bound on nonzero scalars is quite tight
					evals.len() * P::WIDTH,
					challenge,
					&mut new_evals,
				)?;

				*evals = new_evals;
				Ok(false)
			}
		}
	})
}
134
/// Folds each multilinear by one variable in high-to-low evaluation order.
///
/// Mirrors `fold_multilinears_low_to_high`, except that partial evaluation and
/// lerp folding act on the highest-indexed variables instead of the lowest.
/// Returns `Ok(true)` if any multilinear is still `Transparent` after this round.
///
/// NOTE(review): unlike the low-to-high twin, there is no `assert!(n_vars > 0)`
/// here; `n_vars - 1` below will underflow-panic if `n_vars == 0` on the sparse
/// path — confirm callers guarantee `n_vars > 0`.
fn fold_multilinears_high_to_low<P, M>(
	n_vars: usize,
	multilinears: &mut [SumcheckMultilinear<P, M>],
	challenge: P::Scalar,
	tensor_query: Option<MultilinearQueryRef<P>>,
) -> Result<bool, Error>
where
	P: PackedField,
	M: MultilinearPoly<P> + Send + Sync,
{
	parallel_map(multilinears, |sumcheck_multilinear| -> Result<_, Error> {
		match sumcheck_multilinear {
			SumcheckMultilinear::Transparent {
				multilinear,
				switchover_round,
				zero_scalars_suffix,
			} => {
				if *switchover_round == 0 {
					// At switchover we partially evaluate the multilinear at an expanded tensor query.
					let tensor_query = tensor_query
						.as_ref()
						.expect("guaranteed to be Some while there is still a transparent");

					// Per `zero_scalars_suffix`, only this many leading scalars can be nonzero.
					let nonzero_scalars_prefix = (1 << n_vars) - *zero_scalars_suffix;

					let large_field_folded_evals = if nonzero_scalars_prefix < 1 << n_vars {
						// Sparse path: evaluate subcube by subcube, sized so only the
						// potentially nonzero prefix is visited.
						let subcube_vars = subcube_vars_for_bits::<P>(
							MAX_SRC_SUBCUBE_LOG_BITS,
							log2_ceil_usize(nonzero_scalars_prefix),
							tensor_query.n_vars(),
							n_vars - 1,
						);

						// Packed field elements per subcube (at least one).
						let packed_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);

						// Evaluating out the high variables leaves at most
						// 2^(n_vars - tensor_query.n_vars()) scalars; the nonzero prefix
						// can be smaller still, so take the minimum.
						let folded_scalars =
							nonzero_scalars_prefix.min(1 << (n_vars - tensor_query.n_vars()));

						let mut folded =
							zeroed_vec(folded_scalars.div_ceil(1 << subcube_vars) * packed_len);

						// REVIEW: no lerp optimization in subcube_partial_high_evals currently
						for (subcube_index, subcube_evals) in
							folded.chunks_exact_mut(packed_len).enumerate()
						{
							multilinear.subcube_partial_high_evals(
								*tensor_query,
								subcube_vars,
								subcube_index,
								subcube_evals,
							)?;
						}

						// Drop trailing all-zero packed elements so later folds operate on
						// a tightly truncated evaluation vector.
						folded.truncate(folded_scalars.div_ceil(P::WIDTH));
						folded
					} else {
						// Dense path: no zero suffix to exploit, fold in a single call.
						multilinear
							.evaluate_partial_high(*tensor_query)?
							.into_evals()
					};

					*sumcheck_multilinear = SumcheckMultilinear::Folded {
						large_field_folded_evals,
					};

					Ok(false)
				} else {
					// Not yet at switchover: stay transparent for another round.
					*switchover_round -= 1;
					Ok(true)
				}
			}

			SumcheckMultilinear::Folded {
				large_field_folded_evals,
			} => {
				// REVIEW: note that this method is currently _not_ multithreaded, as
				//         traces are usually sufficiently wide
				fold_left_lerp_inplace(
					large_field_folded_evals,
					// Clamp the nonzero scalar count to the hypercube size.
					(large_field_folded_evals.len() * P::WIDTH).min(1 << n_vars),
					n_vars,
					challenge,
				)?;
				Ok(false)
			}
		}
	})
}
223
224fn parallel_map<P, M>(
225	multilinears: &mut [SumcheckMultilinear<P, M>],
226	map_multilinear: impl Fn(&mut SumcheckMultilinear<P, M>) -> Result<bool, Error> + Sync,
227) -> Result<bool, Error>
228where
229	P: PackedField,
230	M: MultilinearPoly<P> + Send + Sync,
231{
232	let any_transparent_left = multilinears
233		.par_iter_mut()
234		.try_fold(
235			|| false,
236			|any_transparent_left, sumcheck_multilinear| -> Result<bool, Error> {
237				let is_still_transparent = map_multilinear(sumcheck_multilinear)?;
238				Ok(any_transparent_left || is_still_transparent)
239			},
240		)
241		.try_reduce(|| false, |lhs, rhs| Ok(lhs || rhs))?;
242
243	Ok(any_transparent_left)
244}