1use std::{iter, sync::Arc};
4
5use binius_compute::{
6 ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, alloc::ComputeAllocator,
7 cpu::CpuMemory, layer,
8};
9use binius_field::{Field, PackedField, PackedFieldIndexable};
10use binius_math::{
11 B1, B8, B16, B32, B64, B128, MultilinearPoly, MultilinearQuery, PackedTop, TowerTop,
12};
13use binius_maybe_rayon::prelude::*;
14use binius_utils::checked_arithmetics::log2_ceil_usize;
15use itertools::izip;
16use tracing::instrument;
17
18use super::{
19 common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
20 eq_ind::RowBatchCoeffs,
21 error::Error,
22 logging::MLEFoldHisgDimensionsData,
23 tower_tensor_algebra::TowerTensorAlgebra,
24};
25use crate::{
26 fiat_shamir::{CanSample, Challenger},
27 piop::PIOPSumcheckClaim,
28 protocols::evalcheck::subclaims::MemoizedData,
29 ring_switch::{
30 common::EvalClaimSuffixDesc,
31 eq_ind::{RingSwitchEqInd, RingSwitchEqIndPrecompute},
32 logging::CalculateRingSwitchEqIndData,
33 },
34 transcript::ProverTranscript,
35};
36
/// Output of the ring-switching reduction produced by [`prove`].
///
/// Pairs the transparent multilinears (ring-switching eq indicators, held as
/// device-memory slices allocated from the prover's device allocator) with the
/// PIOP sumcheck claims that relate each committed multilinear to one of them.
pub struct ReducedWitness<'a, F: Field, Hal: ComputeLayer<F>> {
	// Device slices of the ring-switching eq indicator multilinears, one per
	// sumcheck claim; a claim's `transparent` field indexes into this vector.
	pub transparents: Vec<FSlice<'a, F, Hal>>,
	// Sumcheck claims handed to the PIOP; `sum` is the row-batched evaluation.
	pub sumcheck_claims: Vec<PIOPSumcheckClaim<F>>,
}
41
/// Proves the ring-switching reduction, reducing tower-field evaluation claims
/// to PIOP sumcheck claims over the committed multilinears.
///
/// The steps below happen in Fiat-Shamir order — sampling and transcript
/// writes must not be reordered:
///
/// 1. sample mixing challenges and expand them into mixing coefficients;
/// 2. compute partial (high) evaluations of the witnesses as tensor-algebra
///    elements, scale them by the mixing coefficients, mix them per prefix
///    descriptor, and write the mixed vertical elements to the transcript;
/// 3. sample row-batching challenges, fold the scaled tensor elements into
///    row-batched sumcheck evaluations, and write those to the transcript;
/// 4. build the ring-switching eq indicator multilinears on the compute layer;
///    these become the transparent polynomials referenced by the output claims.
///
/// ## Arguments
///
/// * `system` - description of the evaluation claims being reduced.
/// * `witnesses` - committed multilinears; count must equal
///   `system.commit_meta.total_multilins()`.
/// * `transcript` - the prover's Fiat-Shamir transcript.
/// * `memoized_data` - memoized queries/partial evaluations reused across claims.
/// * `hal` - compute layer used to evaluate the eq indicators.
/// * `dev_alloc` - device allocator; the returned transparents borrow from it.
/// * `host_alloc` - host allocator for staging during precomputation.
///
/// ## Errors
///
/// Returns [`Error::InvalidWitness`] when the witness count does not match the
/// commit metadata, and propagates errors from partial evaluation, mixing, and
/// eq-indicator construction.
pub fn prove<'a, F, P, M, Challenger_, Hal, HostAllocatorType, DeviceAllocatorType>(
	system: &EvalClaimSystem<F>,
	witnesses: &[M],
	transcript: &mut ProverTranscript<Challenger_>,
	memoized_data: MemoizedData<P>,
	hal: &Hal,
	dev_alloc: &'a DeviceAllocatorType,
	host_alloc: &HostAllocatorType,
) -> Result<ReducedWitness<'a, F, Hal>, Error>
where
	F: TowerTop + PackedTop<Scalar = F>,
	P: PackedFieldIndexable<Scalar = F> + PackedTop,
	M: MultilinearPoly<P> + Sync,
	Challenger_: Challenger,
	Hal: ComputeLayer<F>,
	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
{
	if witnesses.len() != system.commit_meta.total_multilins() {
		return Err(Error::InvalidWitness(
			"witness length does not match the number of multilinears".into(),
		));
	}

	// Step 1: one mixing challenge per bit needed to index the claims.
	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
	let dimensions_data = MLEFoldHisgDimensionsData::new(witnesses);
	let mle_fold_high_span = tracing::debug_span!(
		"[task] (Ring Switch) MLE Fold High",
		phase = "ring_switch",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	// Expanding the challenges yields 2^n_mixing_challenges coefficients, at
	// least one per sumcheck claim.
	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();

	// Step 2: per-claim tensor-algebra elements, scaled and mixed per prefix.
	let tensor_elems = compute_partial_evals(system, witnesses, memoized_data)?;
	let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
	let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
		&scaled_tensor_elems,
		&system.prefix_descs,
		&system.eval_claim_to_prefix_desc_index,
	)?;
	drop(mle_fold_high_span);
	let mut writer = transcript.message();
	for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
		// Each mixed element carries exactly 2^kappa vertical scalars.
		debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
		writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
	}

	// Step 3: row-batching challenges cover the largest claim's kappa.
	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
	));

	let row_batched_evals =
		compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
	transcript.message().write_scalar_slice(&row_batched_evals);

	// Step 4: construct the transparent eq indicator multilinears on-device.
	let dimensions_data = CalculateRingSwitchEqIndData::new(system.suffix_descs.iter());
	let calculate_ring_switch_eq_ind_span = tracing::debug_span!(
		"[task] Calculate Ring Switch Eq Ind",
		phase = "ring_switch",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	let ring_switch_eq_inds = make_ring_switch_eq_inds(
		&system.sumcheck_claim_descs,
		&system.suffix_descs,
		row_batch_coeffs,
		&mixing_coeffs,
		hal,
		dev_alloc,
		host_alloc,
	)?;
	drop(calculate_ring_switch_eq_ind_span);

	// Each claim pairs its committed multilinear with the eq indicator at the
	// same index (`transparent: idx`), summing to its row-batched evaluation.
	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
		.enumerate()
		.map(|(idx, (claim_desc, eval))| {
			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
			PIOPSumcheckClaim {
				n_vars: suffix_desc.suffix.len(),
				committed: claim_desc.committed_idx,
				transparent: idx,
				sum: eval,
			}
		})
		.collect::<Vec<_>>();

	Ok(ReducedWitness {
		transparents: ring_switch_eq_inds,
		sumcheck_claims,
	})
}
145
/// Computes, for every sumcheck claim, the partial (high) evaluation of its
/// committed multilinear at the claim's suffix point, as a tensor-algebra
/// element of `2^kappa` vertical scalars.
///
/// Memoized partial evaluations are used when available; otherwise the witness
/// is evaluated directly against the (pre-memoized) suffix query.
#[instrument(skip_all)]
fn compute_partial_evals<F, P, M>(
	system: &EvalClaimSystem<F>,
	witnesses: &[M],
	mut memoized_data: MemoizedData<P>,
) -> Result<Vec<TowerTensorAlgebra<F>>, Error>
where
	F: TowerTop,
	P: PackedField<Scalar = F>,
	M: MultilinearPoly<P> + Sync,
{
	// Expand all suffix queries up front, in parallel, so the per-claim work
	// below can fetch them read-only.
	let suffixes = system
		.suffix_descs
		.iter()
		.map(|desc| Arc::as_ref(&desc.suffix))
		.collect::<Vec<_>>();

	memoized_data.memoize_query_par(suffixes)?;

	let tensor_elems = system
		.sumcheck_claim_descs
		.par_iter()
		.map(
			|PIOPSumcheckClaimDesc {
				committed_idx,
				suffix_desc_idx,
				eval_claim,
			}| {
				let suffix_desc = &system.suffix_descs[*suffix_desc_idx];

				// Prefer a memoized partial evaluation for this claim id;
				// otherwise evaluate the witness at the suffix query.
				let mut elems = if let Some(partial_eval) =
					memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
				{
					PackedField::iter_slice(
						partial_eval.packed_evals().expect("packed_evals exist"),
					)
					// Take at most 2^kappa scalars, and never more than the
					// partial evaluation actually holds (2^n_vars).
					.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
					.collect::<Vec<_>>()
				} else {
					let suffix_query = memoized_data
						.full_query_readonly(&suffix_desc.suffix)
						.expect("memoized above");
					let partial_eval =
						witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
					PackedField::iter_slice(partial_eval.evals())
						.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
						.collect::<Vec<_>>()
				};

				// Pad short evaluations up to 2^kappa by cycling the values.
				// NOTE(review): this presumes an undersized partial evaluation
				// represents a periodic repetition — confirm against the
				// tensor-algebra construction.
				if elems.len() < (1 << suffix_desc.kappa) {
					elems = elems
						.into_iter()
						.cycle()
						.take(1 << suffix_desc.kappa)
						.collect();
				}
				TowerTensorAlgebra::new(suffix_desc.kappa, elems)
			},
		)
		.collect::<Result<Vec<_>, _>>()?;

	Ok(tensor_elems)
}
209
210fn scale_tensor_elems<F>(
211 tensor_elems: Vec<TowerTensorAlgebra<F>>,
212 mixing_coeffs: &[F],
213) -> Vec<TowerTensorAlgebra<F>>
214where
215 F: TowerTop,
216{
217 assert!(tensor_elems.len() <= mixing_coeffs.len());
219
220 tensor_elems
221 .into_par_iter()
222 .zip(mixing_coeffs)
223 .map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
224 .collect()
225}
226
227fn mix_tensor_elems_for_prefixes<F>(
228 scaled_tensor_elems: &[TowerTensorAlgebra<F>],
229 prefix_descs: &[EvalClaimPrefixDesc<F>],
230 eval_claim_to_prefix_desc_index: &[usize],
231) -> Result<Vec<TowerTensorAlgebra<F>>, Error>
232where
233 F: TowerTop,
234{
235 assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
237
238 let mut batched_tensor_elems = prefix_descs
239 .iter()
240 .map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
241 .collect::<Result<Vec<_>, _>>()?;
242 for (tensor_elem, &desc_index) in
243 iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
244 {
245 let mixed_val = &mut batched_tensor_elems[desc_index];
246 debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
247 mixed_val.add_assign(tensor_elem)?;
248 }
249 Ok(batched_tensor_elems)
250}
251
252#[instrument(skip_all)]
253fn compute_row_batched_sumcheck_evals<F>(
254 tensor_elems: Vec<TowerTensorAlgebra<F>>,
255 row_batch_coeffs: &[F],
256) -> Vec<F>
257where
258 F: TowerTop + PackedTop,
259{
260 tensor_elems
261 .into_par_iter()
262 .map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
263 .collect()
264}
265
/// Builds the ring-switching eq indicator multilinear for every sumcheck claim.
///
/// First precomputes per-claim values on the host/device allocators, then
/// evaluates all indicators' multilinear extensions inside a single compute
/// layer `execute` call. The returned device slices are in the same order as
/// `sumcheck_claim_descs`.
fn make_ring_switch_eq_inds<'a, F, Hal, HostAllocatorType, DeviceAllocatorType>(
	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
	suffix_descs: &[EvalClaimSuffixDesc<F>],
	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
	mixing_coeffs: &[F],
	hal: &Hal,
	dev_alloc: &'a DeviceAllocatorType,
	host_alloc: &HostAllocatorType,
) -> Result<Vec<FSlice<'a, F, Hal>>, Error>
where
	F: TowerTop,
	Hal: ComputeLayer<F>,
	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
{
	let mut eq_inds = Vec::with_capacity(sumcheck_claim_descs.len());

	let precompute = sumcheck_claim_descs
		.iter()
		.zip(mixing_coeffs)
		// NOTE: the closure parameter `mixing_coeffs` shadows the outer slice
		// of the same name; inside the closure it is a single `&F` coefficient.
		.map(|(claim_desc, mixing_coeffs)| {
			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
			RingSwitchEqInd::<F, F>::precompute_values(
				suffix_desc.suffix.clone(),
				row_batch_coeffs.clone(),
				*mixing_coeffs,
				suffix_desc.kappa,
				hal,
				dev_alloc,
				host_alloc,
			)
		})
		.collect::<Result<Vec<_>, Error>>()?;

	// Evaluate all indicators in one `execute` call. Our `Error` values are
	// wrapped as `layer::Error::CoreLibError` to cross the HAL boundary and
	// surfaced by the trailing `?`; the executor's own Ok payload is unused,
	// hence `let _ =` and the empty `Ok(vec![])`.
	let _ = hal.execute(|exec| {
		let res = exec
			.map(
				izip!(sumcheck_claim_descs, mixing_coeffs, precompute),
				|exec, (claim_desc, &mixing_coeff, precompute)| {
					let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];

					make_ring_switch_eq_ind(
						suffix_desc,
						row_batch_coeffs.clone(),
						mixing_coeff,
						exec,
						precompute,
					)
					.map_err(|e| layer::Error::CoreLibError(Box::new(e)))
				},
			)
			.map_err(|e| layer::Error::CoreLibError(Box::new(e)))?;

		eq_inds.extend(res);
		Ok(vec![])
	})?;

	Ok(eq_inds)
}
325
326fn make_ring_switch_eq_ind<'a, F, Mem, Exec>(
327 suffix_desc: &EvalClaimSuffixDesc<F>,
328 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
329 mixing_coeff: F,
330 exec: &mut Exec,
331 precompute: RingSwitchEqIndPrecompute<'a, F, Mem>,
332) -> Result<Mem::FSlice<'a>, Error>
333where
334 F: TowerTop,
335 Mem: ComputeMemory<F>,
336 Exec: ComputeLayerExecutor<F, DevMem = Mem>,
337{
338 let eq_ind = match F::TOWER_LEVEL - suffix_desc.kappa {
339 0 => RingSwitchEqInd::<B1, _>::new(
340 suffix_desc.suffix.clone(),
341 row_batch_coeffs,
342 mixing_coeff,
343 )?
344 .multilinear_extension(precompute, exec, 0),
345 3 => RingSwitchEqInd::<B8, _>::new(
346 suffix_desc.suffix.clone(),
347 row_batch_coeffs,
348 mixing_coeff,
349 )?
350 .multilinear_extension(precompute, exec, 3),
351 4 => RingSwitchEqInd::<B16, _>::new(
352 suffix_desc.suffix.clone(),
353 row_batch_coeffs,
354 mixing_coeff,
355 )?
356 .multilinear_extension(precompute, exec, 4),
357 5 => RingSwitchEqInd::<B32, _>::new(
358 suffix_desc.suffix.clone(),
359 row_batch_coeffs,
360 mixing_coeff,
361 )?
362 .multilinear_extension(precompute, exec, 5),
363 6 => RingSwitchEqInd::<B64, _>::new(
364 suffix_desc.suffix.clone(),
365 row_batch_coeffs,
366 mixing_coeff,
367 )?
368 .multilinear_extension(precompute, exec, 6),
369 7 => RingSwitchEqInd::<B128, _>::new(
370 suffix_desc.suffix.clone(),
371 row_batch_coeffs,
372 mixing_coeff,
373 )?
374 .multilinear_extension(precompute, exec, 7),
375 _ => Err(Error::PackingDegreeNotSupported {
376 kappa: suffix_desc.kappa,
377 }),
378 }?;
379 Ok(eq_ind)
380}