1use std::{iter, sync::Arc};
4
5use binius_field::{
6 PackedExtension, PackedField, PackedFieldIndexable, TowerField,
7 tower::{PackedTop, TowerFamily},
8};
9use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
10use binius_maybe_rayon::prelude::*;
11use binius_utils::checked_arithmetics::log2_ceil_usize;
12use tracing::instrument;
13
14use super::{
15 common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
16 eq_ind::RowBatchCoeffs,
17 error::Error,
18 logging::MLEFoldHisgDimensionsData,
19 tower_tensor_algebra::TowerTensorAlgebra,
20};
21use crate::{
22 fiat_shamir::{CanSample, Challenger},
23 piop::PIOPSumcheckClaim,
24 protocols::evalcheck::subclaims::MemoizedData,
25 ring_switch::{
26 common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd, logging::CalculateRingSwitchEqIndData,
27 },
28 transcript::ProverTranscript,
29 witness::MultilinearWitness,
30};
31
/// Shorthand for the `B128` field of a tower family.
type FExt<Tower> = <Tower as TowerFamily>::B128;
33
/// Output of [`prove`]: the sumcheck claims produced by the reduction together with the
/// transparent multilinears those claims refer to.
#[derive(Debug)]
pub struct ReducedWitness<P: PackedField> {
    /// Ring-switch equality-indicator witnesses. [`prove`] builds one per sumcheck claim
    /// descriptor, in descriptor order.
    pub transparents: Vec<MultilinearWitness<'static, P>>,
    /// Reduced sumcheck claims, in the same order as `transparents`; [`prove`] sets each
    /// claim's `transparent` field to its own position in this list.
    pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
}
39
40pub fn prove<F, P, M, Tower, Challenger_>(
41 system: &EvalClaimSystem<F>,
42 witnesses: &[M],
43 transcript: &mut ProverTranscript<Challenger_>,
44 memoized_data: MemoizedData<P>,
45) -> Result<ReducedWitness<P>, Error>
46where
47 F: TowerField + PackedTop<Tower>,
48 P: PackedFieldIndexable<Scalar = F>
49 + PackedExtension<Tower::B1>
50 + PackedExtension<Tower::B8>
51 + PackedExtension<Tower::B16>
52 + PackedExtension<Tower::B32>
53 + PackedExtension<Tower::B64>
54 + PackedExtension<Tower::B128>,
55 M: MultilinearPoly<P> + Sync,
56 Tower: TowerFamily<B128 = F>,
57 Challenger_: Challenger,
58{
59 if witnesses.len() != system.commit_meta.total_multilins() {
60 return Err(Error::InvalidWitness(
61 "witness length does not match the number of multilinears".into(),
62 ));
63 }
64
65 let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
68 let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
69 let dimensions_data = MLEFoldHisgDimensionsData::new(witnesses);
70 let mle_fold_high_span = tracing::debug_span!(
71 "[task] (Ring Switch) MLE Fold High",
72 phase = "ring_switch",
73 perfetto_category = "task.main",
74 ?dimensions_data,
75 )
76 .entered();
77
78 let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
79
80 let tensor_elems = compute_partial_evals::<_, _, _, Tower>(system, witnesses, memoized_data)?;
82 let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
83 let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
84 &scaled_tensor_elems,
85 &system.prefix_descs,
86 &system.eval_claim_to_prefix_desc_index,
87 )?;
88 drop(mle_fold_high_span);
89 let mut writer = transcript.message();
90 for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
91 debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
92 writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
93 }
94
95 let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
97 let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
98 MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
99 ));
100
101 let row_batched_evals =
102 compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
103 transcript.message().write_scalar_slice(&row_batched_evals);
104
105 let dimensions_data = CalculateRingSwitchEqIndData::new(system.suffix_descs.iter());
107 let calculate_ring_switch_eq_ind_span = tracing::debug_span!(
108 "[task] Calculate Ring Switch Eq Ind",
109 phase = "ring_switch",
110 perfetto_category = "task.main",
111 ?dimensions_data,
112 )
113 .entered();
114
115 let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
116 &system.sumcheck_claim_descs,
117 &system.suffix_descs,
118 row_batch_coeffs,
119 &mixing_coeffs,
120 )?;
121 drop(calculate_ring_switch_eq_ind_span);
122
123 let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
124 .enumerate()
125 .map(|(idx, (claim_desc, eval))| {
126 let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
127 PIOPSumcheckClaim {
128 n_vars: suffix_desc.suffix.len(),
129 committed: claim_desc.committed_idx,
130 transparent: idx,
131 sum: eval,
132 }
133 })
134 .collect::<Vec<_>>();
135
136 Ok(ReducedWitness {
137 transparents: ring_switch_eq_inds,
138 sumcheck_claims,
139 })
140}
141
142#[instrument(skip_all)]
143fn compute_partial_evals<F, P, M, Tower>(
144 system: &EvalClaimSystem<F>,
145 witnesses: &[M],
146 mut memoized_data: MemoizedData<P>,
147) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
148where
149 F: TowerField,
150 P: PackedField<Scalar = F>,
151 M: MultilinearPoly<P> + Sync,
152 Tower: TowerFamily<B128 = F>,
153{
154 let suffixes = system
155 .suffix_descs
156 .iter()
157 .map(|desc| Arc::as_ref(&desc.suffix))
158 .collect::<Vec<_>>();
159
160 memoized_data.memoize_query_par(suffixes)?;
161
162 let tensor_elems = system
163 .sumcheck_claim_descs
164 .par_iter()
165 .map(
166 |PIOPSumcheckClaimDesc {
167 committed_idx,
168 suffix_desc_idx,
169 eval_claim,
170 }| {
171 let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
172
173 let mut elems = if let Some(partial_eval) =
174 memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
175 {
176 PackedField::iter_slice(
177 partial_eval.packed_evals().expect("packed_evals exist"),
178 )
179 .take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
180 .collect::<Vec<_>>()
181 } else {
182 let suffix_query = memoized_data
183 .full_query_readonly(&suffix_desc.suffix)
184 .expect("memoized above");
185 let partial_eval =
186 witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
187 PackedField::iter_slice(partial_eval.evals())
188 .take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
189 .collect::<Vec<_>>()
190 };
191
192 if elems.len() < (1 << suffix_desc.kappa) {
193 elems = elems
194 .into_iter()
195 .cycle()
196 .take(1 << suffix_desc.kappa)
197 .collect();
198 }
199 TowerTensorAlgebra::new(suffix_desc.kappa, elems)
200 },
201 )
202 .collect::<Result<Vec<_>, _>>()?;
203
204 Ok(tensor_elems)
205}
206
207fn scale_tensor_elems<F, Tower>(
208 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
209 mixing_coeffs: &[F],
210) -> Vec<TowerTensorAlgebra<Tower>>
211where
212 F: TowerField,
213 Tower: TowerFamily<B128 = F>,
214{
215 assert!(tensor_elems.len() <= mixing_coeffs.len());
217
218 tensor_elems
219 .into_par_iter()
220 .zip(mixing_coeffs)
221 .map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
222 .collect()
223}
224
225fn mix_tensor_elems_for_prefixes<F, Tower>(
226 scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
227 prefix_descs: &[EvalClaimPrefixDesc<F>],
228 eval_claim_to_prefix_desc_index: &[usize],
229) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
230where
231 F: TowerField,
232 Tower: TowerFamily<B128 = F>,
233{
234 assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
236
237 let mut batched_tensor_elems = prefix_descs
238 .iter()
239 .map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
240 .collect::<Result<Vec<_>, _>>()?;
241 for (tensor_elem, &desc_index) in
242 iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
243 {
244 let mixed_val = &mut batched_tensor_elems[desc_index];
245 debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
246 mixed_val.add_assign(tensor_elem)?;
247 }
248 Ok(batched_tensor_elems)
249}
250
251#[instrument(skip_all)]
252fn compute_row_batched_sumcheck_evals<F, Tower>(
253 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
254 row_batch_coeffs: &[F],
255) -> Vec<F>
256where
257 F: TowerField,
258 Tower: TowerFamily<B128 = F>,
259 F: PackedTop<Tower>,
260{
261 tensor_elems
262 .into_par_iter()
263 .map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
264 .collect()
265}
266
267fn make_ring_switch_eq_inds<F, P, Tower>(
268 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
269 suffix_descs: &[EvalClaimSuffixDesc<F>],
270 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
271 mixing_coeffs: &[F],
272) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
273where
274 F: TowerField + PackedTop<Tower>,
275 P: PackedFieldIndexable<Scalar = F>
276 + PackedExtension<Tower::B1>
277 + PackedExtension<Tower::B8>
278 + PackedExtension<Tower::B16>
279 + PackedExtension<Tower::B32>
280 + PackedExtension<Tower::B64>
281 + PackedExtension<Tower::B128>,
282 Tower: TowerFamily<B128 = F>,
283{
284 sumcheck_claim_descs
285 .par_iter()
286 .zip(mixing_coeffs)
287 .map(|(claim_desc, &mixing_coeff)| {
288 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
289 make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
290 })
291 .collect()
292}
293
294fn make_ring_switch_eq_ind<P, Tower>(
295 suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
296 row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
297 mixing_coeff: FExt<Tower>,
298) -> Result<MultilinearWitness<'static, P>, Error>
299where
300 P: PackedFieldIndexable<Scalar = FExt<Tower>> + PackedTop<Tower>,
301 Tower: TowerFamily,
302{
303 let eq_ind = match suffix_desc.kappa {
304 7 => RingSwitchEqInd::<Tower::B1, _>::new(
305 suffix_desc.suffix.clone(),
306 row_batch_coeffs,
307 mixing_coeff,
308 )?
309 .multilinear_extension::<P>(),
310 4 => RingSwitchEqInd::<Tower::B8, _>::new(
311 suffix_desc.suffix.clone(),
312 row_batch_coeffs,
313 mixing_coeff,
314 )?
315 .multilinear_extension(),
316 3 => RingSwitchEqInd::<Tower::B16, _>::new(
317 suffix_desc.suffix.clone(),
318 row_batch_coeffs,
319 mixing_coeff,
320 )?
321 .multilinear_extension(),
322 2 => RingSwitchEqInd::<Tower::B32, _>::new(
323 suffix_desc.suffix.clone(),
324 row_batch_coeffs,
325 mixing_coeff,
326 )?
327 .multilinear_extension(),
328 1 => RingSwitchEqInd::<Tower::B64, _>::new(
329 suffix_desc.suffix.clone(),
330 row_batch_coeffs,
331 mixing_coeff,
332 )?
333 .multilinear_extension(),
334 0 => RingSwitchEqInd::<Tower::B128, _>::new(
335 suffix_desc.suffix.clone(),
336 row_batch_coeffs,
337 mixing_coeff,
338 )?
339 .multilinear_extension(),
340 _ => Err(Error::PackingDegreeNotSupported {
341 kappa: suffix_desc.kappa,
342 }),
343 }?;
344 Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
345}