use std::iter;

use binius_field::{BinaryField, ExtensionField, PackedField};
use binius_math::{MultilinearQuery, extrapolate_line_scalar};
use binius_maybe_rayon::prelude::*;
use bytemuck::zeroed_vec;
use tracing::instrument;

use crate::AdditiveNTT;

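/// FRI-fold an interleaved codeword into a preallocated output buffer.
///
/// The first `log_batch_size` challenges combine each batch of `2^log_batch_size` adjacent
/// scalars into a single value via a tensor expansion; the remaining challenges drive the FRI
/// folding rounds proper.
///
/// ## Preconditions
///
/// * `codeword` must hold `2^(log_len + log_batch_size)` scalars, in packed form
/// * `challenges.len()` must be at least `log_batch_size`
/// * `out.len()` must equal `2^(log_len - (challenges.len() - log_batch_size))`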
#[instrument(skip_all, level = "debug")]
pub fn fold_interleaved_allocated<F, FS, NTT, P>(
    ntt: &NTT,
    codeword: &[P],
    challenges: &[F],
    log_len: usize,
    log_batch_size: usize,
    out: &mut [F],
) where
    F: BinaryField + ExtensionField<FS>,
    FS: BinaryField,
    NTT: AdditiveNTT<FS> + Sync,
    P: PackedField<Scalar = F>,
{
    assert_eq!(codeword.len(), 1 << (log_len + log_batch_size).saturating_sub(P::LOG_WIDTH));
    assert!(challenges.len() >= log_batch_size);
    assert_eq!(out.len(), 1 << (log_len - (challenges.len() - log_batch_size)));

    let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
    let tensor = MultilinearQuery::expand(interleave_challenges);

    let fold_chunk_size = 1 << fold_challenges.len();
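    // Each chunk of 2^challenges.len() scalars of the codeword folds down to one output value.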
    let chunk_size = 1 << challenges.len().saturating_sub(P::LOG_WIDTH);
    codeword
        .par_chunks(chunk_size)
        .enumerate()
        .zip(out)
        .for_each_init(
            || vec![F::default(); fold_chunk_size],
            |scratch_buffer, ((i, chunk), out)| {
                *out = fold_interleaved_chunk(
                    ntt,
                    log_len,
                    log_batch_size,
                    i,
                    chunk,
                    tensor.expansion(),
                    fold_challenges,
                    scratch_buffer,
                )
            },
        )
}

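/// FRI-fold an interleaved codeword, returning the folded values in a new vector.
///
/// Allocating convenience wrapper around [`fold_interleaved_allocated`].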
pub fn fold_interleaved<F, FS, NTT, P>(
    ntt: &NTT,
    codeword: &[P],
    challenges: &[F],
    log_len: usize,
    log_batch_size: usize,
) -> Vec<F>
where
    F: BinaryField + ExtensionField<FS>,
    FS: BinaryField,
    NTT: AdditiveNTT<FS> + Sync,
    P: PackedField<Scalar = F>,
{
    let mut result =
        zeroed_vec(1 << log_len.saturating_sub(challenges.len().saturating_sub(log_batch_size)));
    fold_interleaved_allocated(ntt, codeword, challenges, log_len, log_batch_size, &mut result);
    result
}

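/// Fold a pair of adjacent codeword values into one using the challenge `r`.
///
/// `round` and `index` select the subspace evaluation (twiddle) used by the inverse NTT
/// butterfly.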
#[inline]
fn fold_pair<F, FS, NTT>(ntt: &NTT, round: usize, index: usize, values: (F, F), r: F) -> F
where
    F: BinaryField + ExtensionField<FS>,
    FS: BinaryField,
    NTT: AdditiveNTT<FS>,
{
    let t = ntt.get_subspace_eval(round, index);
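    // Apply the inverse additive NTT butterfly to the pair, then extrapolate the resulting line
    // at the challenge point.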
    let (mut u, mut v) = values;
    v += u;
    u += v * t;
    extrapolate_line_scalar(u, v, r)
}
#[inline]
pub fn fold_chunk<F, FS, NTT>(
    ntt: &NTT,
    mut log_len: usize,
    chunk_index: usize,
    values: &mut [F],
    challenges: &[F],
) -> F
where
    F: BinaryField + ExtensionField<FS>,
    FS: BinaryField,
    NTT: AdditiveNTT<FS>,
{
    let mut log_size = challenges.len();

    debug_assert!(log_size <= log_len);
    debug_assert!(log_len <= ntt.log_domain_size());
    debug_assert_eq!(values.len(), 1 << log_size);

    for &challenge in challenges {
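        // Fold the values at (2 * index_offset) and (2 * index_offset + 1) into a single value
        // at index_offset.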
        for index_offset in 0..1 << (log_size - 1) {
            let pair = (values[index_offset << 1], values[(index_offset << 1) | 1]);
            values[index_offset] = fold_pair(
                ntt,
                log_len,
                (chunk_index << (log_size - 1)) | index_offset,
                pair,
                challenge,
            )
        }

        log_len -= 1;
        log_size -= 1;
    }

    values[0]
}

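/// FRI-fold a chunk of an interleaved codeword down to a single value.
///
/// Each batch of `2^log_batch_size` adjacent values is first reduced to one value by an inner
/// product with `tensor`, the multilinear tensor expansion of the interleaving challenges; the
/// reduced chunk is then folded with [`fold_chunk`] using `fold_challenges`.
///
/// `scratch_buffer` must hold at least `2^fold_challenges.len()` values.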
#[inline]
#[allow(clippy::too_many_arguments)]
pub fn fold_interleaved_chunk<F, FS, P, NTT>(
    ntt: &NTT,
    log_len: usize,
    log_batch_size: usize,
    chunk_index: usize,
    values: &[P],
    tensor: &[P],
    fold_challenges: &[F],
    scratch_buffer: &mut [F],
) -> F
where
    F: BinaryField + ExtensionField<FS>,
    FS: BinaryField,
    NTT: AdditiveNTT<FS>,
    P: PackedField<Scalar = F>,
{
    debug_assert!(fold_challenges.len() <= log_len);
    debug_assert!(log_len <= ntt.log_domain_size());
    debug_assert_eq!(
        values.len(),
        1 << (fold_challenges.len() + log_batch_size).saturating_sub(P::LOG_WIDTH)
    );
    debug_assert_eq!(tensor.len(), 1 << log_batch_size.saturating_sub(P::LOG_WIDTH));
    debug_assert!(scratch_buffer.len() >= 1 << fold_challenges.len());

    let scratch_buffer = &mut scratch_buffer[..1 << fold_challenges.len()];

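    // With no interleaving, copy the scalars straight into the scratch buffer. Otherwise, reduce
    // each batch of 2^log_batch_size adjacent scalars to one value via an inner product with the
    // expanded tensor of the interleaving challenges.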
    if log_batch_size == 0 {
        iter::zip(&mut *scratch_buffer, P::iter_slice(values)).for_each(|(dst, val)| *dst = val);
    } else {
        let folded_values = values
            .chunks(1 << (log_batch_size - P::LOG_WIDTH))
            .map(|chunk| {
                iter::zip(chunk, tensor)
                    .map(|(&a_i, &b_i)| a_i * b_i)
                    .sum::<P>()
                    .into_iter()
                    .take(1 << log_batch_size)
                    .sum()
            });
        iter::zip(&mut *scratch_buffer, folded_values).for_each(|(dst, val)| *dst = val);
    }

    fold_chunk(ntt, log_len, chunk_index, scratch_buffer, fold_challenges)
}
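
#[cfg(test)]
mod tests {
    use binius_field::BinaryField128b;

    use super::*;
    use crate::SingleThreadedNTT;

    // A minimal usage sketch added for illustration, not part of the original module. It assumes
    // `SingleThreadedNTT` is exported by this crate with a `new(log_domain_size)` constructor and
    // that `BinaryField128b::new(u128)` constructs field elements. It exercises only the length
    // contract of `fold_interleaved`, not FRI soundness.
    #[test]
    fn fold_interleaved_folds_to_single_value() {
        type F = BinaryField128b;

        let log_len = 4;
        let log_batch_size = 2;
        let ntt = SingleThreadedNTT::<F>::new(log_len).unwrap();

        // Interleaved codeword of 2^(log_len + log_batch_size) scalars, using the scalar field
        // itself as the width-1 packed field.
        let codeword = (0..1u128 << (log_len + log_batch_size))
            .map(F::new)
            .collect::<Vec<F>>();

        // log_batch_size interleaving challenges plus log_len folding challenges fold the
        // codeword all the way down to a single value.
        let challenges = (1..=(log_batch_size + log_len) as u128)
            .map(F::new)
            .collect::<Vec<F>>();

        let folded = fold_interleaved(&ntt, &codeword, &challenges, log_len, log_batch_size);
        assert_eq!(folded.len(), 1);
    }
}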