1use std::{cmp::Reverse, env, marker::PhantomData, slice::from_mut};
4
5use binius_field::{
6 as_packed_field::{PackScalar, PackedType},
7 linear_transformation::{PackedTransformationFactory, Transformation},
8 underlier::WithUnderlier,
9 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
10 RepackedExtension, TowerField,
11};
12use binius_hal::ComputationBackend;
13use binius_hash::PseudoCompressionFunction;
14use binius_math::{
15 DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
16 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
17};
18use binius_maybe_rayon::prelude::*;
19use binius_utils::bail;
20use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, Output};
21use either::Either;
22use itertools::{chain, izip};
23use tracing::instrument;
24
25use super::{
26 channel::Boundary,
27 error::Error,
28 verify::{
29 get_post_flush_sumcheck_eval_claims_without_eq, make_flush_oracles,
30 max_n_vars_and_skip_rounds, reorder_for_flushing_by_n_vars,
31 },
32 ConstraintSystem, Proof,
33};
34use crate::{
35 constraint_system::{
36 common::{FDomain, FEncode, FExt, FFastExt},
37 exp,
38 verify::{get_flush_dedup_sumcheck_metas, FlushSumcheckMeta},
39 },
40 fiat_shamir::{CanSample, Challenger},
41 merkle_tree::BinaryMerkleTreeProver,
42 oracle::{Constraint, MultilinearOracleSet, MultilinearPolyVariant, OracleId},
43 piop,
44 protocols::{
45 fri::CommitOutput,
46 gkr_exp,
47 gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness, LayerClaim},
48 greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
49 sumcheck::{
50 self, constraint_set_zerocheck_claim, immediate_switchover_heuristic,
51 prove::{
52 eq_ind::EqIndSumcheckProverBuilder, SumcheckProver, UnivariateZerocheckProver,
53 },
54 standard_switchover_heuristic, zerocheck,
55 },
56 },
57 ring_switch,
58 tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
59 transcript::ProverTranscript,
60 witness::{MultilinearExtensionIndex, MultilinearWitness},
61};
62
/// Generates a proof that `witness` satisfies `constraint_system` for the given channel
/// `boundaries`.
///
/// The returned [`Proof`] is the finalized Fiat–Shamir transcript. The phases run in this
/// order, and every transcript write/sample below is order-sensitive:
///
/// 1. Observe `boundaries`.
/// 2. Build exponentiation witnesses, commit all committed oracles (`piop::commit`), and write
///    the commitment.
/// 3. Evaluate the exponentiation witnesses at a sampled point, write the evaluations, and run
///    the batched GKR exponentiation prover (`gkr_exp::batch_prove`).
/// 4. Compute grand products of the `non_zero_oracle_ids` witnesses (bailing with
///    [`Error::Zeros`] if any is zero) and of the selector-masked flush oracles. Write both,
///    then run one batched GKR grand-product prover over all of them.
/// 5. Reduce the flush grand-product claims with batched eq-indicator sumchecks.
/// 6. Zerocheck all table constraint sets: a batched univariate skip round, the remaining
///    rounds, then the univariatizing reduction.
/// 7. Reduce all evaluation claims with greedy evalcheck, apply ring switching, and finish with
///    the PIOP opening proof (`piop::prove`).
///
/// `log_inv_rate` and `security_bits` parameterize the commitment via
/// `piop::make_commit_params_with_optimal_arity`.
///
/// # Errors
///
/// Returns [`Error::Zeros`] if some non-zero oracle's grand product is zero. Otherwise it
/// propagates any error returned by the helpers and sub-protocol provers it calls.
#[instrument("constraint_system::prove", skip_all, level = "debug")]
pub fn prove<U, Tower, Hash, Compress, Challenger_, Backend>(
    constraint_system: &ConstraintSystem<FExt<Tower>>,
    log_inv_rate: usize,
    security_bits: usize,
    boundaries: &[Boundary<FExt<Tower>>],
    mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
    backend: &Backend,
) -> Result<Proof, Error>
where
    U: ProverTowerUnderlier<Tower>,
    Tower: ProverTowerFamily,
    Tower::B128: PackedTop<Tower>,
    Hash: Digest + BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
    Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
    Challenger_: Challenger + Default,
    Backend: ComputationBackend,
    PackedType<U, Tower::B128>: PackedTop<Tower>
        + PackedFieldIndexable
        + RepackedExtension<PackedType<U, Tower::B8>>
        + RepackedExtension<PackedType<U, Tower::B16>>
        + RepackedExtension<PackedType<U, Tower::B32>>
        + RepackedExtension<PackedType<U, Tower::B64>>
        + RepackedExtension<PackedType<U, Tower::B128>>
        + PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
    PackedType<U, Tower::FastB128>:
        PackedFieldIndexable + PackedTransformationFactory<PackedType<U, Tower::B128>>,
    PackedType<U, Tower::B8>: PackedFieldIndexable,
    PackedType<U, Tower::B16>: PackedFieldIndexable,
    PackedType<U, Tower::B32>: PackedFieldIndexable,
    PackedType<U, Tower::B64>: PackedFieldIndexable,
{
    tracing::debug!(
        arch = env::consts::ARCH,
        rayon_threads = binius_maybe_rayon::current_num_threads(),
        "using computation backend: {backend:?}"
    );

    let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
    let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();

    // Observe the boundaries first so that every challenge sampled later depends on them.
    let mut transcript = ProverTranscript::<Challenger_>::new();
    transcript.observe().write_slice(boundaries);

    // Work on a private copy: oracles are added below (flush oracles) and the
    // constraint/flush/exponent lists are re-sorted.
    let ConstraintSystem {
        mut oracles,
        mut table_constraints,
        mut flushes,
        mut exponents,
        non_zero_oracle_ids,
        max_channel_id,
    } = constraint_system.clone();

    // Exponentiations in decreasing order of n_vars.
    // NOTE(review): presumably the verifier applies the same ordering — TODO confirm.
    exponents.sort_by_key(|b| std::cmp::Reverse(b.n_vars(&oracles)));

    let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;

    // Constraint sets in decreasing order of n_vars; the zerocheck loop below relies on the
    // per-set challenge offsets computed from this order.
    table_constraints.sort_by_key(|constraint_set| Reverse(constraint_set.n_vars));

    let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
    let merkle_scheme = merkle_prover.scheme();

    // Commit phase: gather the committed oracles' witnesses and commit to them.
    let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
    let committed_multilins = piop::collect_committed_witnesses::<U, _>(
        &commit_meta,
        &oracle_to_commit_index,
        &oracles,
        &witness,
    )?;

    let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
        &commit_meta,
        merkle_scheme,
        security_bits,
        log_inv_rate,
    )?;
    let CommitOutput {
        commitment,
        committed,
        codeword,
    } = piop::commit(&fri_params, &merkle_prover, &committed_multilins)?;

    let mut writer = transcript.message();
    writer.write(&commitment);

    // Exponentiation phase: sample an evaluation point sized for the largest exponentiation,
    // send the witness evaluations there, then prove them with batched GKR.
    let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));

    // `.into()` targets are inferred from the later uses of these vectors.
    let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
        .into_iter()
        .map(|x| x.into())
        .collect::<Vec<_>>();

    let mut writer = transcript.message();
    writer.write_scalar_slice(&exp_evals);

    let exp_challenge = exp_challenge
        .into_iter()
        .map(|x| x.into())
        .collect::<Vec<_>>();

    // Claims are mapped into the fast (isomorphic) field for proving.
    let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
        .into_iter()
        .map(|claim| claim.isomorphic())
        .collect::<Vec<_>>();

    // The output is mapped back from the fast field with `.isomorphic()`.
    let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
        EvaluationOrder::HighToLow,
        exp_witnesses,
        &exp_claims,
        fast_domain_factory.clone(),
        &mut transcript,
        backend,
    )?
    .isomorphic();

    let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;

    // Non-zero check: a product of zero means some oracle has a zero entry, so bail out before
    // writing anything further.
    let non_zero_fast_witnesses =
        make_fast_masked_flush_witnesses::<U, _>(&oracles, &witness, &non_zero_oracle_ids, None)?;
    let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
        .into_par_iter()
        .map(GrandProductWitness::new)
        .collect::<Result<Vec<_>, _>>()?;

    let non_zero_products =
        gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
    if non_zero_products
        .iter()
        .any(|count| *count == Tower::B128::zero())
    {
        bail!(Error::Zeros);
    }

    let mut writer = transcript.message();

    writer.write_scalar_slice(&non_zero_products);

    let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
        &non_zero_oracle_ids,
        &oracles,
        &non_zero_products,
    )?;

    // Flush phase: sample the mixing challenge and one permutation challenge per channel, then
    // build the flush oracles (flushes ordered by channel id).
    let mixing_challenge = transcript.sample();
    let permutation_challenges = transcript.sample_vec(max_channel_id + 1);

    flushes.sort_by_key(|flush| flush.channel_id);
    let flush_oracle_ids =
        make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
    let flush_selectors = flushes
        .iter()
        .map(|flush| flush.selector)
        .collect::<Vec<_>>();

    // Materialise the flush oracles' witnesses, then build selector-masked fast copies for the
    // grand products.
    make_unmasked_flush_witnesses::<U, _>(&oracles, &mut witness, &flush_oracle_ids)?;
    let flush_witnesses = make_fast_masked_flush_witnesses::<U, _>(
        &oracles,
        &witness,
        &flush_oracle_ids,
        Some(&flush_selectors),
    )?;

    let flush_prodcheck_witnesses = flush_witnesses
        .into_par_iter()
        .map(GrandProductWitness::new)
        .collect::<Result<Vec<_>, _>>()?;
    let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);

    transcript.message().write_scalar_slice(&flush_products);

    let flush_prodcheck_claims =
        gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;

    // One batched GPA over flush witnesses followed by non-zero witnesses. The `split_off`
    // below depends on this concatenation order.
    let all_gpa_witnesses = [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat();
    let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
        .map(|claim| claim.isomorphic())
        .collect::<Vec<_>>();

    let GrandProductBatchProveOutput { final_layer_claims } =
        gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
            EvaluationOrder::LowToHigh,
            all_gpa_witnesses,
            &all_gpa_claims,
            &fast_domain_factory,
            &mut transcript,
            backend,
        )?;

    let mut final_layer_claims = final_layer_claims
        .into_iter()
        .map(|layer_claim| layer_claim.isomorphic())
        .collect::<Vec<_>>();

    // First `flush_oracle_ids.len()` claims are the flush ones; the rest are non-zero ones.
    let non_zero_final_layer_claims = final_layer_claims.split_off(flush_oracle_ids.len());
    let flush_final_layer_claims = final_layer_claims;

    let non_zero_prodcheck_eval_claims =
        gkr_gpa::make_eval_claims(non_zero_oracle_ids, non_zero_final_layer_claims)?;

    // Reorder flushes by n_vars before building the flush sumchecks.
    let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) =
        reorder_for_flushing_by_n_vars(
            &oracles,
            &flush_oracle_ids,
            flush_selectors,
            flush_final_layer_claims,
        );

    let FlushSumcheckProvers {
        provers,
        flush_selectors_unique_by_claim,
        flush_oracle_ids_by_claim,
    } = get_flush_sumcheck_provers::<U, _, FDomain<Tower>, _, _>(
        &mut oracles,
        &flush_oracle_ids,
        &flush_selectors,
        &flush_final_layer_claims,
        &mut witness,
        &domain_factory,
        backend,
    )?;

    let flush_sumcheck_output = sumcheck::prove::batch_prove(provers, &mut transcript)?;

    let flush_eval_claims = get_post_flush_sumcheck_eval_claims_without_eq(
        &oracles,
        &flush_selectors_unique_by_claim,
        &flush_oracle_ids_by_claim,
        &flush_sumcheck_output,
    )?;

    // Zerocheck phase: one claim (plus oracle metadata) per table constraint set.
    let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
        .iter()
        .cloned()
        .map(constraint_set_zerocheck_claim)
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .unzip::<_, _, Vec<_>, Vec<_>>();

    let eq_ind_sumcheck_claims = zerocheck::reduce_to_eq_ind_sumchecks(&zerocheck_claims)?;

    let (max_n_vars, skip_rounds) =
        max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);

    // The first `skip_rounds` variables are handled by the univariate round, so only the
    // remaining ones get zerocheck challenges.
    let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);

    let switchover_fn = standard_switchover_heuristic(-2);

    let mut univariate_provers = Vec::new();
    let mut tail_regular_zerocheck_provers = Vec::new();
    let mut univariatized_multilinears = Vec::new();

    for constraint_set in table_constraints {
        // Smaller sets use a suffix of the shared challenge vector.
        let skip_challenges = (max_n_vars - constraint_set.n_vars).saturating_sub(skip_rounds);
        // Sets with more than `max_n_vars - skip_rounds` variables take part in the univariate
        // skip round; the others run as regular zerocheck provers after it.
        let univariate_decider = |n_vars| n_vars > max_n_vars - skip_rounds;

        let (constraints, multilinears) =
            sumcheck::prove::split_constraint_set(constraint_set, &witness)?;

        // Highest tower level needed by any multilinear or composition in this set. It selects
        // the base field `FBase` below.
        // NOTE(review): presumably `log_extension_degree` is relative to FExt (tower level 7),
        // making `7 - degree` the multilinear's own tower level — TODO confirm.
        let base_tower_level = chain!(
            multilinears
                .iter()
                .map(|multilinear| 7 - multilinear.log_extension_degree()),
            constraints
                .iter()
                .map(|constraint| constraint.composition.binary_tower_level())
        )
        .max()
        .unwrap_or(0);

        // Kept for the univariatizing reduction after the zerocheck.
        univariatized_multilinears.push(multilinears.clone());

        let constructor =
            ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _, _> {
                constraints,
                multilinears,
                domain_factory: domain_factory.clone(),
                switchover_fn,
                zerocheck_challenges: &zerocheck_challenges[skip_challenges..],
                backend,
                _fdomain_marker: PhantomData,
            };

        let either_prover = match base_tower_level {
            0..=3 => constructor.create::<Tower::B8>(univariate_decider)?,
            4 => constructor.create::<Tower::B16>(univariate_decider)?,
            5 => constructor.create::<Tower::B32>(univariate_decider)?,
            6 => constructor.create::<Tower::B64>(univariate_decider)?,
            7 => constructor.create::<Tower::B128>(univariate_decider)?,
            _ => unreachable!(),
        };

        match either_prover {
            Either::Left(univariate_prover) => univariate_provers.push(univariate_prover),
            Either::Right(zerocheck_prover) => {
                tail_regular_zerocheck_provers.push(zerocheck_prover)
            }
        }
    }

    let univariate_cnt = univariate_provers.len();

    let univariate_output = sumcheck::prove::batch_prove_zerocheck_univariate_round(
        univariate_provers,
        skip_rounds,
        &mut transcript,
    )?;

    let univariate_challenge = univariate_output.univariate_challenge;

    // The regular zerocheck provers join the batch where the univariate round left off.
    let sumcheck_output = sumcheck::prove::batch_prove_with_start(
        univariate_output.batch_prove_start,
        tail_regular_zerocheck_provers,
        &mut transcript,
    )?;

    let zerocheck_output = sumcheck::eq_ind::verify_sumcheck_outputs(
        &eq_ind_sumcheck_claims,
        &zerocheck_challenges,
        sumcheck_output,
    )?;

    // Univariatizing reduction: one claim/prover per constraint set, paired with its
    // multilinear evaluations from the zerocheck.
    let mut reduction_claims = Vec::with_capacity(univariate_cnt);
    let mut reduction_provers = Vec::with_capacity(univariate_cnt);

    for (univariatized_multilinear_evals, multilinears) in
        izip!(&zerocheck_output.multilinear_evals, univariatized_multilinears)
    {
        let claim_n_vars = multilinears
            .first()
            .map_or(0, |multilinear| multilinear.n_vars());

        let skip_challenges = (max_n_vars - claim_n_vars).saturating_sub(skip_rounds);
        let challenges = &zerocheck_output.challenges[skip_challenges..];
        let reduced_multilinears =
            sumcheck::prove::reduce_to_skipped_projection(multilinears, challenges, backend)?;

        // Variables of this set not covered by `challenges` were skipped by the univariate
        // round.
        let claim_skip_rounds = claim_n_vars - challenges.len();
        let reduction_claim = sumcheck::univariate::univariatizing_reduction_claim(
            claim_skip_rounds,
            univariatized_multilinear_evals,
        )?;

        let reduction_prover =
            sumcheck::prove::univariatizing_reduction_prover::<_, FDomain<Tower>, _, _>(
                reduced_multilinears,
                univariatized_multilinear_evals,
                univariate_challenge,
                backend,
            )?;

        reduction_claims.push(reduction_claim);
        reduction_provers.push(reduction_prover);
    }

    let univariatizing_output = sumcheck::prove::batch_prove(reduction_provers, &mut transcript)?;

    let multilinear_zerocheck_output = sumcheck::univariate::verify_sumcheck_outputs(
        &reduction_claims,
        univariate_challenge,
        &zerocheck_output.challenges,
        univariatizing_output,
    )?;

    let zerocheck_eval_claims = sumcheck::make_eval_claims(
        EvaluationOrder::LowToHigh,
        zerocheck_oracle_metas,
        multilinear_zerocheck_output,
    )?;

    // Evalcheck: reduce every evaluation claim gathered above.
    let GreedyEvalcheckProveOutput {
        eval_claims,
        memoized_data,
    } = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
        &mut oracles,
        &mut witness,
        [non_zero_prodcheck_eval_claims, flush_eval_claims]
            .concat()
            .into_iter()
            .chain(zerocheck_eval_claims)
            .chain(exp_eval_claims),
        switchover_fn,
        &mut transcript,
        &domain_factory,
        backend,
    )?;

    // Ring switching on the remaining claims against the committed oracles.
    let system = ring_switch::EvalClaimSystem::new(
        &oracles,
        &commit_meta,
        &oracle_to_commit_index,
        &eval_claims,
    )?;

    let ring_switch::ReducedWitness {
        transparents: transparent_multilins,
        sumcheck_claims: piop_sumcheck_claims,
    } = ring_switch::prove::<_, _, _, Tower, _, _>(
        &system,
        &committed_multilins,
        &mut transcript,
        memoized_data,
        backend,
    )?;

    // Final PIOP opening proof for the committed data.
    // NOTE(review): `backend` is already `&Backend`, so `&backend` passes `&&Backend`, unlike
    // every other call above — confirm this is intended.
    piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _>(
        &fri_params,
        &merkle_prover,
        domain_factory,
        &commit_meta,
        committed,
        &codeword,
        &committed_multilins,
        &transparent_multilins,
        &piop_sumcheck_claims,
        &mut transcript,
        &backend,
    )?;

    Ok(Proof {
        transcript: transcript.finalize(),
    })
}
506
507type TypeErasedUnivariateZerocheck<'a, F> = Box<dyn UnivariateZerocheckProver<'a, F> + 'a>;
508type TypeErasedSumcheck<'a, F> = Box<dyn SumcheckProver<F> + 'a>;
509type TypeErasedProver<'a, F> =
510 Either<TypeErasedUnivariateZerocheck<'a, F>, TypeErasedSumcheck<'a, F>>;
511
512struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
513where
514 P: PackedField,
515{
516 constraints: Vec<Constraint<P::Scalar>>,
517 multilinears: Vec<MultilinearWitness<'a, P>>,
518 domain_factory: DomainFactory,
519 switchover_fn: SwitchoverFn,
520 zerocheck_challenges: &'a [P::Scalar],
521 backend: &'a Backend,
522 _fdomain_marker: PhantomData<FDomain>,
523}
524
525impl<'a, P, F, FDomain, DomainFactory, SwitchoverFn, Backend>
526 ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
527where
528 F: Field,
529 P: PackedFieldIndexable<Scalar = F>,
530 FDomain: TowerField,
531 DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
532 SwitchoverFn: Fn(usize) -> usize + Clone + 'a,
533 Backend: ComputationBackend,
534{
535 fn create<FBase>(
536 self,
537 is_univariate: impl FnOnce(usize) -> bool,
538 ) -> Result<TypeErasedProver<'a, F>, Error>
539 where
540 FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
541 P: PackedExtension<F, PackedSubfield = P>
542 + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
543 + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
544 F: TowerField,
545 {
546 let univariate_prover =
547 sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _, _>(
548 self.constraints,
549 self.multilinears,
550 self.domain_factory,
551 self.switchover_fn,
552 self.zerocheck_challenges,
553 self.backend,
554 )?;
555
556 let type_erased_prover = if is_univariate(univariate_prover.n_vars()) {
557 let type_erased_univariate_prover =
558 Box::new(univariate_prover) as TypeErasedUnivariateZerocheck<'a, P::Scalar>;
559
560 Either::Left(type_erased_univariate_prover)
561 } else {
562 let zerocheck_prover = univariate_prover.into_regular_zerocheck()?;
563 let type_erased_zerocheck_prover =
564 Box::new(zerocheck_prover) as TypeErasedSumcheck<'a, P::Scalar>;
565
566 Either::Right(type_erased_zerocheck_prover)
567 };
568
569 Ok(type_erased_prover)
570 }
571}
572
573#[instrument(skip_all, level = "debug")]
574fn make_unmasked_flush_witnesses<'a, U, Tower>(
575 oracles: &MultilinearOracleSet<FExt<Tower>>,
576 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
577 flush_oracle_ids: &[OracleId],
578) -> Result<(), Error>
579where
580 U: ProverTowerUnderlier<Tower>,
581 Tower: ProverTowerFamily,
582{
583 let flush_witnesses: Result<Vec<MultilinearWitness<'a, _>>, Error> = flush_oracle_ids
585 .par_iter()
586 .map(|&oracle_id| {
587 let MultilinearPolyVariant::LinearCombination(lincom) =
588 oracles.oracle(oracle_id).variant
589 else {
590 unreachable!("make_flush_oracles adds linear combination oracles");
591 };
592 let polys = lincom
593 .polys()
594 .map(|id| witness.get_multilin_poly(id))
595 .collect::<Result<Vec<_>, _>>()?;
596
597 let packed_len = 1
598 << lincom
599 .n_vars()
600 .saturating_sub(<PackedType<U, FExt<Tower>>>::LOG_WIDTH);
601 let data = (0..packed_len)
602 .into_par_iter()
603 .map(|i| {
604 <PackedType<U, FExt<Tower>>>::from_fn(|j| {
605 let index = i << <PackedType<U, FExt<Tower>>>::LOG_WIDTH | j;
606 polys.iter().zip(lincom.coefficients()).fold(
607 lincom.offset(),
608 |sum, (poly, coeff)| {
609 sum + poly
610 .evaluate_on_hypercube_and_scale(index, coeff)
611 .unwrap_or(<FExt<Tower>>::ZERO)
612 },
613 )
614 })
615 })
616 .collect::<Vec<_>>();
617 let lincom_poly = MultilinearExtension::new(lincom.n_vars(), data)
618 .expect("data is constructed with the correct length with respect to n_vars");
619
620 Ok(MLEDirectAdapter::from(lincom_poly).upcast_arc_dyn())
621 })
622 .collect();
623
624 witness.update_multilin_poly(izip!(flush_oracle_ids.iter().copied(), flush_witnesses?))?;
625 Ok(())
626}
627
628#[allow(clippy::type_complexity)]
629#[instrument(skip_all, level = "debug")]
630fn make_fast_masked_flush_witnesses<'a, U, Tower>(
631 oracles: &MultilinearOracleSet<FExt<Tower>>,
632 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
633 flush_oracles: &[OracleId],
634 flush_selectors: Option<&[OracleId]>,
635) -> Result<Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>, Error>
636where
637 U: ProverTowerUnderlier<Tower>,
638 Tower: ProverTowerFamily,
639 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
640{
641 let to_fast = Tower::packed_transformation_to_fast();
642
643 flush_oracles
645 .par_iter()
646 .enumerate()
647 .map(|(i, &flush_oracle_id)| {
648 let n_vars = oracles.n_vars(flush_oracle_id);
649
650 let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
651 let width = 1 << log_width;
652
653 let packed_len = 1 << n_vars.saturating_sub(log_width);
654 let mut fast_ext_result = vec![PackedType::<U, FFastExt<Tower>>::one(); packed_len];
655
656 let poly = witness.get_multilin_poly(flush_oracle_id)?;
657 let selector = flush_selectors
658 .map(|flush_selectors| witness.get_multilin_poly(flush_selectors[i]))
659 .transpose()?;
660
661 const MAX_SUBCUBE_VARS: usize = 8;
662 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
663 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
664
665 fast_ext_result
666 .par_chunks_mut(subcube_packed_size)
667 .enumerate()
668 .for_each(|(subcube_index, fast_subcube)| {
669 let underliers =
670 PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
671
672 let subcube_evals =
673 PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
674 poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
675 .expect("witness data populated by make_unmasked_flush_witnesses()");
676
677 for underlier in underliers.iter_mut() {
678 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
679 let dest = to_fast.transform(&src);
680 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
681 }
682
683 if let Some(selector) = &selector {
684 let fast_subcube =
685 PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
686
687 let mut ones_mask = PackedType::<U, FExt<Tower>>::default();
688 for (i, packed) in fast_subcube.iter_mut().enumerate() {
689 selector
690 .subcube_evals(
691 log_width,
692 (subcube_index << subcube_vars.saturating_sub(log_width)) | i,
693 0,
694 from_mut(&mut ones_mask),
695 )
696 .expect("selector n_vars equals flushed n_vars");
697
698 if ones_mask == PackedField::zero() {
699 *packed = PackedField::one();
700 } else if ones_mask != PackedField::one() {
701 for j in 0..width {
702 if ones_mask.get(j) == FExt::<Tower>::ZERO {
703 packed.set(j, FFastExt::<Tower>::ONE);
704 }
705 }
706 }
707 }
708 }
709 });
710
711 let masked_poly = MultilinearExtension::new(n_vars, fast_ext_result)
712 .expect("data is constructed with the correct length with respect to n_vars");
713 Ok(MLEDirectAdapter::from(masked_poly).upcast_arc_dyn())
714 })
715 .collect()
716}
717
/// Flush sumcheck provers, plus the oracle bookkeeping needed to turn their outputs back into
/// evaluation claims.
///
/// The three vectors are indexed in parallel. Entry `k` of each one comes from the `k`-th
/// `FlushSumcheckMeta` returned by `get_flush_dedup_sumcheck_metas` in
/// `get_flush_sumcheck_provers`.
pub struct FlushSumcheckProvers<Prover> {
    /// One eq-indicator sumcheck prover per flush sumcheck meta.
    provers: Vec<Prover>,
    /// The `flush_oracle_ids` of each meta.
    flush_oracle_ids_by_claim: Vec<Vec<OracleId>>,
    /// The `flush_selectors_unique` of each meta.
    flush_selectors_unique_by_claim: Vec<Vec<OracleId>>,
}
723
724#[instrument(skip_all, level = "debug")]
725fn get_flush_sumcheck_provers<'a, 'b, U, Tower, FDomain, DomainFactory, Backend>(
726 oracles: &mut MultilinearOracleSet<Tower::B128>,
727 flush_oracle_ids: &[OracleId],
728 flush_selectors: &[OracleId],
729 final_layer_claims: &[LayerClaim<Tower::B128>],
730 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, Tower::B128>>,
731 domain_factory: DomainFactory,
732 backend: &'b Backend,
733) -> Result<FlushSumcheckProvers<impl SumcheckProver<Tower::B128> + 'b>, Error>
734where
735 U: ProverTowerUnderlier<Tower> + PackScalar<FDomain>,
736 Tower: ProverTowerFamily,
737 Tower::B128: ExtensionField<FDomain>,
738 FDomain: Field,
739 DomainFactory: EvaluationDomainFactory<FDomain>,
740 Backend: ComputationBackend,
741 PackedType<U, Tower::B128>: PackedFieldIndexable,
742 'a: 'b,
743{
744 let flush_sumcheck_metas = get_flush_dedup_sumcheck_metas(
745 oracles,
746 flush_oracle_ids,
747 flush_selectors,
748 final_layer_claims,
749 )?;
750
751 let n_claims = flush_sumcheck_metas.len();
752 let mut provers = Vec::with_capacity(n_claims);
753 let mut flush_oracle_ids_by_claim = Vec::with_capacity(n_claims);
754 let mut flush_selectors_unique_by_claim = Vec::with_capacity(n_claims);
755 for flush_sumcheck_meta in flush_sumcheck_metas {
756 let FlushSumcheckMeta {
757 composite_sum_claims,
758 flush_selectors_unique,
759 flush_oracle_ids,
760 eval_point,
761 } = flush_sumcheck_meta;
762
763 let mut multilinears =
764 Vec::with_capacity(flush_selectors_unique.len() + flush_oracle_ids.len());
765
766 let mut nonzero_scalars_prefixes = Vec::with_capacity(multilinears.len());
767
768 for &oracle_id in chain!(&flush_selectors_unique, &flush_oracle_ids) {
769 let entry = witness.get_index_entry(oracle_id)?;
770 multilinears.push(entry.multilin_poly);
771 nonzero_scalars_prefixes.push(entry.nonzero_scalars_prefix);
772 }
773
774 let prover = EqIndSumcheckProverBuilder::new(backend)
775 .with_nonzero_scalars_prefixes(&nonzero_scalars_prefixes)
776 .build(
777 EvaluationOrder::LowToHigh,
778 multilinears,
779 &eval_point,
780 composite_sum_claims,
781 domain_factory.clone(),
782 immediate_switchover_heuristic,
783 )?;
784
785 provers.push(prover);
786 flush_oracle_ids_by_claim.push(flush_oracle_ids);
787 flush_selectors_unique_by_claim.push(flush_selectors_unique);
788 }
789
790 Ok(FlushSumcheckProvers {
791 provers,
792 flush_selectors_unique_by_claim,
793 flush_oracle_ids_by_claim,
794 })
795}