1use binius_field::PackedField;
4use binius_math::{
5 fold_left_lerp_inplace, fold_right_lerp, EvaluationOrder, MultilinearPoly, MultilinearQueryRef,
6};
7use binius_maybe_rayon::prelude::*;
8use binius_utils::checked_arithmetics::log2_ceil_usize;
9use bytemuck::zeroed_vec;
10
11use crate::{
12 common::{subcube_vars_for_bits, MAX_SRC_SUBCUBE_LOG_BITS},
13 Error, SumcheckMultilinear,
14};
15
16pub(crate) fn fold_multilinears<P, M>(
17 evaluation_order: EvaluationOrder,
18 n_vars: usize,
19 multilinears: &mut [SumcheckMultilinear<P, M>],
20 challenge: P::Scalar,
21 tensor_query: Option<MultilinearQueryRef<P>>,
22) -> Result<bool, Error>
23where
24 P: PackedField,
25 M: MultilinearPoly<P> + Send + Sync,
26{
27 match evaluation_order {
28 EvaluationOrder::LowToHigh => {
29 fold_multilinears_low_to_high(n_vars, multilinears, challenge, tensor_query)
30 }
31 EvaluationOrder::HighToLow => {
32 fold_multilinears_high_to_low(n_vars, multilinears, challenge, tensor_query)
33 }
34 }
35}
36
37fn fold_multilinears_low_to_high<P, M>(
38 n_vars: usize,
39 multilinears: &mut [SumcheckMultilinear<P, M>],
40 challenge: P::Scalar,
41 tensor_query: Option<MultilinearQueryRef<P>>,
42) -> Result<bool, Error>
43where
44 P: PackedField,
45 M: MultilinearPoly<P> + Send + Sync,
46{
47 assert!(n_vars > 0);
48 parallel_map(multilinears, |sumcheck_multilinear| -> Result<_, Error> {
49 match sumcheck_multilinear {
50 SumcheckMultilinear::Transparent {
51 multilinear,
52 switchover_round,
53 zero_scalars_suffix,
54 } => {
55 if *switchover_round == 0 {
56 let tensor_query = tensor_query
58 .as_ref()
59 .expect("guaranteed to be Some while there is still a transparent");
60
61 assert!(tensor_query.n_vars() > 0);
62
63 let nonzero_scalars_prefix = (1 << n_vars) - *zero_scalars_suffix;
64
65 let large_field_folded_evals = if nonzero_scalars_prefix < 1 << n_vars {
66 let subcube_vars = subcube_vars_for_bits::<P>(
67 MAX_SRC_SUBCUBE_LOG_BITS,
68 log2_ceil_usize(nonzero_scalars_prefix),
69 tensor_query.n_vars(),
70 n_vars - 1,
71 );
72
73 let packed_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
74
75 let folded_scalars =
76 nonzero_scalars_prefix.div_ceil(1 << tensor_query.n_vars());
77
78 let mut folded =
79 zeroed_vec(folded_scalars.div_ceil(1 << subcube_vars) * packed_len);
80
81 for (subcube_index, subcube_evals) in
83 folded.chunks_exact_mut(packed_len).enumerate()
84 {
85 multilinear.subcube_partial_low_evals(
86 *tensor_query,
87 subcube_vars,
88 subcube_index,
89 subcube_evals,
90 )?;
91 }
92
93 folded.truncate(folded_scalars.div_ceil(P::WIDTH));
94 folded
95 } else {
96 multilinear
97 .evaluate_partial_low(*tensor_query)?
98 .into_evals()
99 };
100
101 *sumcheck_multilinear = SumcheckMultilinear::Folded {
102 large_field_folded_evals,
103 };
104
105 Ok(false)
106 } else {
107 *switchover_round -= 1;
108 Ok(true)
109 }
110 }
111
112 SumcheckMultilinear::Folded {
113 large_field_folded_evals: evals,
114 } => {
115 let mut new_evals = zeroed_vec(evals.len().div_ceil(2));
119
120 fold_right_lerp(
121 evals.as_slice(),
122 evals.len() * P::WIDTH,
124 challenge,
125 &mut new_evals,
126 )?;
127
128 *evals = new_evals;
129 Ok(false)
130 }
131 }
132 })
133}
134
135fn fold_multilinears_high_to_low<P, M>(
136 n_vars: usize,
137 multilinears: &mut [SumcheckMultilinear<P, M>],
138 challenge: P::Scalar,
139 tensor_query: Option<MultilinearQueryRef<P>>,
140) -> Result<bool, Error>
141where
142 P: PackedField,
143 M: MultilinearPoly<P> + Send + Sync,
144{
145 parallel_map(multilinears, |sumcheck_multilinear| -> Result<_, Error> {
146 match sumcheck_multilinear {
147 SumcheckMultilinear::Transparent {
148 multilinear,
149 switchover_round,
150 zero_scalars_suffix,
151 } => {
152 if *switchover_round == 0 {
153 let tensor_query = tensor_query
155 .as_ref()
156 .expect("guaranteed to be Some while there is still a transparent");
157
158 let nonzero_scalars_prefix = (1 << n_vars) - *zero_scalars_suffix;
159
160 let large_field_folded_evals = if nonzero_scalars_prefix < 1 << n_vars {
161 let subcube_vars = subcube_vars_for_bits::<P>(
162 MAX_SRC_SUBCUBE_LOG_BITS,
163 log2_ceil_usize(nonzero_scalars_prefix),
164 tensor_query.n_vars(),
165 n_vars - 1,
166 );
167
168 let packed_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
169
170 let folded_scalars =
171 nonzero_scalars_prefix.min(1 << (n_vars - tensor_query.n_vars()));
172
173 let mut folded =
174 zeroed_vec(folded_scalars.div_ceil(1 << subcube_vars) * packed_len);
175
176 for (subcube_index, subcube_evals) in
178 folded.chunks_exact_mut(packed_len).enumerate()
179 {
180 multilinear.subcube_partial_high_evals(
181 *tensor_query,
182 subcube_vars,
183 subcube_index,
184 subcube_evals,
185 )?;
186 }
187
188 folded.truncate(folded_scalars.div_ceil(P::WIDTH));
189 folded
190 } else {
191 multilinear
192 .evaluate_partial_high(*tensor_query)?
193 .into_evals()
194 };
195
196 *sumcheck_multilinear = SumcheckMultilinear::Folded {
197 large_field_folded_evals,
198 };
199
200 Ok(false)
201 } else {
202 *switchover_round -= 1;
203 Ok(true)
204 }
205 }
206
207 SumcheckMultilinear::Folded {
208 large_field_folded_evals,
209 } => {
210 fold_left_lerp_inplace(
213 large_field_folded_evals,
214 (large_field_folded_evals.len() * P::WIDTH).min(1 << n_vars),
215 n_vars,
216 challenge,
217 )?;
218 Ok(false)
219 }
220 }
221 })
222}
223
224fn parallel_map<P, M>(
225 multilinears: &mut [SumcheckMultilinear<P, M>],
226 map_multilinear: impl Fn(&mut SumcheckMultilinear<P, M>) -> Result<bool, Error> + Sync,
227) -> Result<bool, Error>
228where
229 P: PackedField,
230 M: MultilinearPoly<P> + Send + Sync,
231{
232 let any_transparent_left = multilinears
233 .par_iter_mut()
234 .try_fold(
235 || false,
236 |any_transparent_left, sumcheck_multilinear| -> Result<bool, Error> {
237 let is_still_transparent = map_multilinear(sumcheck_multilinear)?;
238 Ok(any_transparent_left || is_still_transparent)
239 },
240 )
241 .try_reduce(|| false, |lhs, rhs| Ok(lhs || rhs))?;
242
243 Ok(any_transparent_left)
244}