1use std::iter;
4
5use binius_field::{BinaryField, ExtensionField, PackedField};
6use binius_math::{MultilinearQuery, extrapolate_line_scalar};
7use binius_maybe_rayon::prelude::*;
8use bytemuck::zeroed_vec;
9use tracing::instrument;
10
11use crate::AdditiveNTT;
12
13#[instrument(skip_all, level = "debug")]
27pub fn fold_interleaved_allocated<F, FS, NTT, P>(
28 ntt: &NTT,
29 codeword: &[P],
30 challenges: &[F],
31 log_len: usize,
32 log_batch_size: usize,
33 out: &mut [F],
34) where
35 F: BinaryField + ExtensionField<FS>,
36 FS: BinaryField,
37 NTT: AdditiveNTT<FS> + Sync,
38 P: PackedField<Scalar = F>,
39{
40 assert_eq!(codeword.len(), 1 << (log_len + log_batch_size).saturating_sub(P::LOG_WIDTH));
41 assert!(challenges.len() >= log_batch_size);
42 assert_eq!(out.len(), 1 << (log_len - (challenges.len() - log_batch_size)));
43
44 let (interleave_challenges, fold_challenges) = challenges.split_at(log_batch_size);
45 let tensor = MultilinearQuery::expand(interleave_challenges);
46
47 let fold_chunk_size = 1 << fold_challenges.len();
49 let chunk_size = 1 << challenges.len().saturating_sub(P::LOG_WIDTH);
50 codeword
51 .par_chunks(chunk_size)
52 .enumerate()
53 .zip(out)
54 .for_each_init(
55 || vec![F::default(); fold_chunk_size],
56 |scratch_buffer, ((i, chunk), out)| {
57 *out = fold_interleaved_chunk(
58 ntt,
59 log_len,
60 log_batch_size,
61 i,
62 chunk,
63 tensor.expansion(),
64 fold_challenges,
65 scratch_buffer,
66 )
67 },
68 )
69}
70
71pub fn fold_interleaved<F, FS, NTT, P>(
72 ntt: &NTT,
73 codeword: &[P],
74 challenges: &[F],
75 log_len: usize,
76 log_batch_size: usize,
77) -> Vec<F>
78where
79 F: BinaryField + ExtensionField<FS>,
80 FS: BinaryField,
81 NTT: AdditiveNTT<FS> + Sync,
82 P: PackedField<Scalar = F>,
83{
84 let mut result =
85 zeroed_vec(1 << log_len.saturating_sub(challenges.len().saturating_sub(log_batch_size)));
86 fold_interleaved_allocated(ntt, codeword, challenges, log_len, log_batch_size, &mut result);
87 result
88}
89
90#[inline]
96fn fold_pair<F, FS, NTT>(ntt: &NTT, round: usize, index: usize, values: (F, F), r: F) -> F
97where
98 F: BinaryField + ExtensionField<FS>,
99 FS: BinaryField,
100 NTT: AdditiveNTT<FS>,
101{
102 let t = ntt.get_subspace_eval(round, index);
104 let (mut u, mut v) = values;
105 v += u;
106 u += v * t;
107 extrapolate_line_scalar(u, v, r)
108}
109
110#[inline]
136pub fn fold_chunk<F, FS, NTT>(
137 ntt: &NTT,
138 mut log_len: usize,
139 chunk_index: usize,
140 values: &mut [F],
141 challenges: &[F],
142) -> F
143where
144 F: BinaryField + ExtensionField<FS>,
145 FS: BinaryField,
146 NTT: AdditiveNTT<FS>,
147{
148 let mut log_size = challenges.len();
149
150 debug_assert!(log_size <= log_len);
152 debug_assert!(log_len <= ntt.log_domain_size());
153 debug_assert_eq!(values.len(), 1 << log_size);
154
155 for &challenge in challenges {
157 let ntt_round = ntt.log_domain_size() - log_len;
159 for index_offset in 0..1 << (log_size - 1) {
160 let pair = (values[index_offset << 1], values[(index_offset << 1) | 1]);
161 values[index_offset] = fold_pair(
162 ntt,
163 ntt_round,
164 (chunk_index << (log_size - 1)) | index_offset,
165 pair,
166 challenge,
167 )
168 }
169
170 log_len -= 1;
171 log_size -= 1;
172 }
173
174 values[0]
175}
176
/// Fold one chunk of an interleaved codeword down to a single scalar.
///
/// First collapses the interleaved batch dimension: each group of
/// `2^log_batch_size` adjacent scalars is combined with the expanded query
/// `tensor` (lane-wise products, then a horizontal sum of up to
/// `2^log_batch_size` lanes), written into `scratch_buffer`. The de-interleaved
/// values are then folded with `fold_challenges` via [`fold_chunk`].
///
/// `scratch_buffer` must hold at least `2^fold_challenges.len()` elements; only
/// that prefix is used.
#[inline]
#[allow(clippy::too_many_arguments)]
pub fn fold_interleaved_chunk<F, FS, P, NTT>(
	ntt: &NTT,
	log_len: usize,
	log_batch_size: usize,
	chunk_index: usize,
	values: &[P],
	tensor: &[P],
	fold_challenges: &[F],
	scratch_buffer: &mut [F],
) -> F
where
	F: BinaryField + ExtensionField<FS>,
	FS: BinaryField,
	NTT: AdditiveNTT<FS>,
	P: PackedField<Scalar = F>,
{
	debug_assert!(fold_challenges.len() <= log_len);
	debug_assert!(log_len <= ntt.log_domain_size());
	debug_assert_eq!(
		values.len(),
		1 << (fold_challenges.len() + log_batch_size).saturating_sub(P::LOG_WIDTH)
	);
	debug_assert_eq!(tensor.len(), 1 << log_batch_size.saturating_sub(P::LOG_WIDTH));
	debug_assert!(scratch_buffer.len() >= 1 << fold_challenges.len());

	// Restrict to the prefix actually consumed by fold_chunk below.
	let scratch_buffer = &mut scratch_buffer[..1 << fold_challenges.len()];

	if log_batch_size == 0 {
		// No batch dimension: unpack the scalars straight into the scratch buffer.
		iter::zip(&mut *scratch_buffer, P::iter_slice(values)).for_each(|(dst, val)| *dst = val);
	} else {
		// NOTE(review): `log_batch_size - P::LOG_WIDTH` underflows (panics in debug
		// builds) when 0 < log_batch_size < P::LOG_WIDTH, unlike the saturating_sub
		// used in the debug_asserts above — presumably callers guarantee
		// log_batch_size >= P::LOG_WIDTH whenever it is nonzero; confirm.
		let folded_values = values
			.chunks(1 << (log_batch_size - P::LOG_WIDTH))
			.map(|chunk| {
				// Lane-wise inner product of the batch with the tensor expansion,
				// then a horizontal sum over up to 2^log_batch_size scalar lanes
				// of the accumulated packed value.
				iter::zip(chunk, tensor)
					.map(|(&a_i, &b_i)| a_i * b_i)
					.sum::<P>()
					.into_iter()
					.take(1 << log_batch_size)
					.sum()
			});
		iter::zip(&mut *scratch_buffer, folded_values).for_each(|(dst, val)| *dst = val);
	};

	fold_chunk(ntt, log_len, chunk_index, scratch_buffer, fold_challenges)
}