1use std::{iter, sync::Arc};
4
5use binius_field::{
6 tower::{PackedTop, TowerFamily},
7 PackedField, PackedFieldIndexable, TowerField,
8};
9use binius_hal::ComputationBackend;
10use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::checked_arithmetics::log2_ceil_usize;
13use tracing::instrument;
14
15use super::{
16 common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
17 eq_ind::RowBatchCoeffs,
18 error::Error,
19 logging::MLEFoldHisgDimensionsData,
20 tower_tensor_algebra::TowerTensorAlgebra,
21};
22use crate::{
23 fiat_shamir::{CanSample, Challenger},
24 piop::PIOPSumcheckClaim,
25 protocols::evalcheck::subclaims::MemoizedData,
26 ring_switch::{
27 common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd, logging::CalculateRingSwitchEqIndData,
28 },
29 transcript::ProverTranscript,
30 witness::MultilinearWitness,
31};
32
/// Shorthand for the `B128` field of the tower family `Tower`; this module uses it as the
/// large field in which the eq-indicators and their coefficients live.
type FExt<Tower> = <Tower as TowerFamily>::B128;
34
/// Output of the ring-switching prover: the transparent multilinears it constructed together
/// with the sumcheck claims that refer to them.
#[derive(Debug)]
pub struct ReducedWitness<P: PackedField> {
    /// Transparent multilinear witnesses; `prove` fills this with one ring-switch
    /// eq-indicator per sumcheck claim.
    pub transparents: Vec<MultilinearWitness<'static, P>>,
    /// Sumcheck claims over the scalar field; in `prove`, claim `i` has `transparent == i`,
    /// i.e. it points at `transparents[i]`.
    pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
}
40
41pub fn prove<F, P, M, Tower, Challenger_, Backend>(
42 system: &EvalClaimSystem<F>,
43 witnesses: &[M],
44 transcript: &mut ProverTranscript<Challenger_>,
45 memoized_data: MemoizedData<P, Backend>,
46 backend: &Backend,
47) -> Result<ReducedWitness<P>, Error>
48where
49 F: TowerField + PackedTop<Tower>,
50 P: PackedFieldIndexable<Scalar = F>,
51 M: MultilinearPoly<P> + Sync,
52 Tower: TowerFamily<B128 = F>,
53 Challenger_: Challenger,
54 Backend: ComputationBackend,
55{
56 if witnesses.len() != system.commit_meta.total_multilins() {
57 return Err(Error::InvalidWitness(
58 "witness length does not match the number of multilinears".into(),
59 ));
60 }
61
62 let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
65 let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
66 let dimensions_data = MLEFoldHisgDimensionsData::new(witnesses);
67 let mle_fold_high_span = tracing::debug_span!(
68 "[task] (Ring Switch) MLE Fold High",
69 phase = "ring_switch",
70 perfetto_category = "task.main",
71 dimensions_data = ?dimensions_data,
72 )
73 .entered();
74
75 let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
76
77 let tensor_elems =
79 compute_partial_evals::<_, _, _, Tower, _>(system, witnesses, memoized_data, backend)?;
80 let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
81 let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
82 &scaled_tensor_elems,
83 &system.prefix_descs,
84 &system.eval_claim_to_prefix_desc_index,
85 )?;
86 drop(mle_fold_high_span);
87 let mut writer = transcript.message();
88 for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
89 debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
90 writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
91 }
92
93 let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
95 let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
96 MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
97 ));
98
99 let row_batched_evals =
100 compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
101 transcript.message().write_scalar_slice(&row_batched_evals);
102
103 let dimensions_data = CalculateRingSwitchEqIndData::new(system.suffix_descs.iter());
105 let calculate_ring_switch_eq_ind_span = tracing::debug_span!(
106 "[task] Calculate Ring Switch Eq Ind",
107 phase = "ring_switch",
108 perfetto_category = "task.main",
109 dimensions_data = ?dimensions_data,
110 )
111 .entered();
112
113 let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
114 &system.sumcheck_claim_descs,
115 &system.suffix_descs,
116 row_batch_coeffs,
117 &mixing_coeffs,
118 )?;
119 drop(calculate_ring_switch_eq_ind_span);
120
121 let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
122 .enumerate()
123 .map(|(idx, (claim_desc, eval))| {
124 let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
125 PIOPSumcheckClaim {
126 n_vars: suffix_desc.suffix.len(),
127 committed: claim_desc.committed_idx,
128 transparent: idx,
129 sum: eval,
130 }
131 })
132 .collect::<Vec<_>>();
133
134 Ok(ReducedWitness {
135 transparents: ring_switch_eq_inds,
136 sumcheck_claims,
137 })
138}
139
140#[instrument(skip_all)]
141fn compute_partial_evals<F, P, M, Tower, Backend>(
142 system: &EvalClaimSystem<F>,
143 witnesses: &[M],
144 mut memoized_data: MemoizedData<P, Backend>,
145 backend: &Backend,
146) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
147where
148 F: TowerField,
149 P: PackedField<Scalar = F>,
150 M: MultilinearPoly<P> + Sync,
151 Tower: TowerFamily<B128 = F>,
152 Backend: ComputationBackend,
153{
154 let suffixes = system
155 .suffix_descs
156 .iter()
157 .map(|desc| Arc::as_ref(&desc.suffix))
158 .collect::<Vec<_>>();
159
160 memoized_data.memoize_query_par(suffixes, backend)?;
161
162 let tensor_elems = system
163 .sumcheck_claim_descs
164 .par_iter()
165 .map(
166 |PIOPSumcheckClaimDesc {
167 committed_idx,
168 suffix_desc_idx,
169 eval_claim,
170 }| {
171 let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
172
173 let mut elems = if let Some(partial_eval) =
174 memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
175 {
176 PackedField::iter_slice(
177 partial_eval.packed_evals().expect("packed_evals exist"),
178 )
179 .take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
180 .collect::<Vec<_>>()
181 } else {
182 let suffix_query = memoized_data
183 .full_query_readonly(&suffix_desc.suffix)
184 .expect("memoized above");
185 let partial_eval =
186 witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
187 PackedField::iter_slice(partial_eval.evals())
188 .take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
189 .collect::<Vec<_>>()
190 };
191
192 if elems.len() < (1 << suffix_desc.kappa) {
193 elems = elems
194 .into_iter()
195 .cycle()
196 .take(1 << suffix_desc.kappa)
197 .collect();
198 }
199 TowerTensorAlgebra::new(suffix_desc.kappa, elems)
200 },
201 )
202 .collect::<Result<Vec<_>, _>>()?;
203
204 Ok(tensor_elems)
205}
206
207fn scale_tensor_elems<F, Tower>(
208 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
209 mixing_coeffs: &[F],
210) -> Vec<TowerTensorAlgebra<Tower>>
211where
212 F: TowerField,
213 Tower: TowerFamily<B128 = F>,
214{
215 assert!(tensor_elems.len() <= mixing_coeffs.len());
217
218 tensor_elems
219 .into_par_iter()
220 .zip(mixing_coeffs)
221 .map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
222 .collect()
223}
224
225fn mix_tensor_elems_for_prefixes<F, Tower>(
226 scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
227 prefix_descs: &[EvalClaimPrefixDesc<F>],
228 eval_claim_to_prefix_desc_index: &[usize],
229) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
230where
231 F: TowerField,
232 Tower: TowerFamily<B128 = F>,
233{
234 assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
236
237 let mut batched_tensor_elems = prefix_descs
238 .iter()
239 .map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
240 .collect::<Result<Vec<_>, _>>()?;
241 for (tensor_elem, &desc_index) in
242 iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
243 {
244 let mixed_val = &mut batched_tensor_elems[desc_index];
245 debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
246 mixed_val.add_assign(tensor_elem)?;
247 }
248 Ok(batched_tensor_elems)
249}
250
251#[instrument(skip_all)]
252fn compute_row_batched_sumcheck_evals<F, Tower>(
253 tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
254 row_batch_coeffs: &[F],
255) -> Vec<F>
256where
257 F: TowerField,
258 Tower: TowerFamily<B128 = F>,
259 F: PackedTop<Tower>,
260{
261 tensor_elems
262 .into_par_iter()
263 .map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
264 .collect()
265}
266
267fn make_ring_switch_eq_inds<F, P, Tower>(
268 sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
269 suffix_descs: &[EvalClaimSuffixDesc<F>],
270 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
271 mixing_coeffs: &[F],
272) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
273where
274 F: TowerField + PackedTop<Tower>,
275 P: PackedFieldIndexable<Scalar = F>,
276 Tower: TowerFamily<B128 = F>,
277{
278 sumcheck_claim_descs
279 .par_iter()
280 .zip(mixing_coeffs)
281 .map(|(claim_desc, &mixing_coeff)| {
282 let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
283 make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
284 })
285 .collect()
286}
287
288fn make_ring_switch_eq_ind<P, Tower>(
289 suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
290 row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
291 mixing_coeff: FExt<Tower>,
292) -> Result<MultilinearWitness<'static, P>, Error>
293where
294 P: PackedFieldIndexable<Scalar = FExt<Tower>>,
295 Tower: TowerFamily,
296{
297 let eq_ind = match suffix_desc.kappa {
298 7 => RingSwitchEqInd::<Tower::B1, _>::new(
299 suffix_desc.suffix.clone(),
300 row_batch_coeffs,
301 mixing_coeff,
302 )?
303 .multilinear_extension::<P>(),
304 4 => RingSwitchEqInd::<Tower::B8, _>::new(
305 suffix_desc.suffix.clone(),
306 row_batch_coeffs,
307 mixing_coeff,
308 )?
309 .multilinear_extension(),
310 3 => RingSwitchEqInd::<Tower::B16, _>::new(
311 suffix_desc.suffix.clone(),
312 row_batch_coeffs,
313 mixing_coeff,
314 )?
315 .multilinear_extension(),
316 2 => RingSwitchEqInd::<Tower::B32, _>::new(
317 suffix_desc.suffix.clone(),
318 row_batch_coeffs,
319 mixing_coeff,
320 )?
321 .multilinear_extension(),
322 1 => RingSwitchEqInd::<Tower::B64, _>::new(
323 suffix_desc.suffix.clone(),
324 row_batch_coeffs,
325 mixing_coeff,
326 )?
327 .multilinear_extension(),
328 0 => RingSwitchEqInd::<Tower::B128, _>::new(
329 suffix_desc.suffix.clone(),
330 row_batch_coeffs,
331 mixing_coeff,
332 )?
333 .multilinear_extension(),
334 _ => Err(Error::PackingDegreeNotSupported {
335 kappa: suffix_desc.kappa,
336 }),
337 }?;
338 Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
339}