1use std::{cmp::Reverse, env, marker::PhantomData, slice::from_mut};
4
5use binius_field::{
6 as_packed_field::{PackScalar, PackedType},
7 linear_transformation::{PackedTransformationFactory, Transformation},
8 underlier::WithUnderlier,
9 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
10 RepackedExtension, TowerField,
11};
12use binius_hal::ComputationBackend;
13use binius_hash::PseudoCompressionFunction;
14use binius_math::{
15 EvaluationDomainFactory, EvaluationOrder, IsomorphicEvaluationDomainFactory, MLEDirectAdapter,
16 MultilinearExtension, MultilinearPoly,
17};
18use binius_maybe_rayon::prelude::*;
19use binius_utils::bail;
20use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, Output};
21use either::Either;
22use itertools::{chain, izip};
23use tracing::instrument;
24
25use super::{
26 channel::Boundary,
27 error::Error,
28 verify::{
29 get_post_flush_sumcheck_eval_claims_without_eq, make_flush_oracles,
30 max_n_vars_and_skip_rounds, reorder_for_flushing_by_n_vars,
31 },
32 ConstraintSystem, Proof,
33};
34use crate::{
35 constraint_system::{
36 common::{FDomain, FEncode, FExt, FFastExt},
37 verify::{get_flush_dedup_sumcheck_metas, FlushSumcheckMeta},
38 },
39 fiat_shamir::{CanSample, Challenger},
40 merkle_tree::BinaryMerkleTreeProver,
41 oracle::{Constraint, MultilinearOracleSet, MultilinearPolyVariant, OracleId},
42 piop,
43 protocols::{
44 fri::CommitOutput,
45 gkr_gpa::{
46 self, gpa_sumcheck::prove::GPAProver, GrandProductBatchProveOutput,
47 GrandProductWitness, LayerClaim,
48 },
49 greedy_evalcheck,
50 sumcheck::{
51 self, constraint_set_zerocheck_claim,
52 prove::{SumcheckProver, UnivariateZerocheckProver},
53 standard_switchover_heuristic, zerocheck,
54 },
55 },
56 ring_switch,
57 tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
58 transcript::ProverTranscript,
59 witness::{MultilinearExtensionIndex, MultilinearWitness},
60};
61
/// Proves that `witness` satisfies `constraint_system` with the given channel `boundaries`.
///
/// Every prover message goes through a single [`ProverTranscript`], and every challenge is
/// sampled from it, so the order of the steps below is significant.
/// NOTE(review): this order is assumed to mirror the verifier in `super::verify`; keep the two
/// in sync when editing.
///
/// Steps, in transcript order:
/// 1. Observe `boundaries`.
/// 2. Commit to the committed oracles (`piop::commit`) and write the commitment.
/// 3. Compute the grand products of the `non_zero_oracle_ids` witnesses, reject any zero
///    product, and write the products.
/// 4. Sample the mixing challenge and one permutation challenge per channel. Build the flush
///    oracles and their selector-masked witnesses, then write their grand products.
/// 5. Run the batched `gkr_gpa` grand-product protocol over the flush and non-zero witnesses
///    together.
/// 6. Run the flush sumchecks on the final-layer claims of the flush products.
/// 7. Run the zerocheck over all table constraint sets: univariate skip rounds first, then the
///    regular rounds. Follow it with the univariatizing reduction sumcheck.
/// 8. Reduce all evaluation claims with `greedy_evalcheck`, then `ring_switch`, and finish with
///    `piop::prove` against the commitment from step 2.
///
/// # Errors
/// Returns [`Error::Zeros`] if any oracle in `non_zero_oracle_ids` has a zero grand product.
/// Otherwise, propagates errors from the sub-protocols.
#[instrument("constraint_system::prove", skip_all, level = "debug")]
pub fn prove<U, Tower, DomainFactory, Hash, Compress, Challenger_, Backend>(
    constraint_system: &ConstraintSystem<FExt<Tower>>,
    log_inv_rate: usize,
    security_bits: usize,
    boundaries: &[Boundary<FExt<Tower>>],
    mut witness: MultilinearExtensionIndex<U, FExt<Tower>>,
    domain_factory: DomainFactory,
    backend: &Backend,
) -> Result<Proof, Error>
where
    U: ProverTowerUnderlier<Tower>,
    Tower: ProverTowerFamily,
    Tower::B128: PackedTop<Tower>,
    DomainFactory: EvaluationDomainFactory<FDomain<Tower>>,
    Hash: Digest + BlockSizeUser + FixedOutputReset,
    Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
    Challenger_: Challenger + Default,
    Backend: ComputationBackend,
    PackedType<U, Tower::B128>: PackedTop<Tower>
        + PackedFieldIndexable
        + RepackedExtension<PackedType<U, Tower::B8>>
        + RepackedExtension<PackedType<U, Tower::B16>>
        + RepackedExtension<PackedType<U, Tower::B32>>
        + RepackedExtension<PackedType<U, Tower::B64>>
        + RepackedExtension<PackedType<U, Tower::B128>>
        + PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
    PackedType<U, Tower::FastB128>:
        PackedFieldIndexable + PackedTransformationFactory<PackedType<U, Tower::B128>>,
    PackedType<U, Tower::B8>: PackedFieldIndexable,
    PackedType<U, Tower::B16>: PackedFieldIndexable,
    PackedType<U, Tower::B32>: PackedFieldIndexable,
    PackedType<U, Tower::B64>: PackedFieldIndexable,
{
    tracing::debug!(
        arch = env::consts::ARCH,
        rayon_threads = binius_maybe_rayon::current_num_threads(),
        "using computation backend: {backend:?}"
    );

    // Domain factory for the isomorphic "fast" field used by the grand-product (GKR) phase.
    let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();

    let mut transcript = ProverTranscript::<Challenger_>::new();
    // Step 1: bind the public boundary values into the transcript before anything else.
    transcript.observe().write_slice(boundaries);

    // The constraint system is cloned because the oracle set is extended below (flush oracles,
    // evalcheck) and the constraints/flushes are re-sorted.
    let ConstraintSystem {
        mut oracles,
        mut table_constraints,
        mut flushes,
        non_zero_oracle_ids,
        max_channel_id,
    } = constraint_system.clone();

    // Largest constraint sets first; the zerocheck claims and provers below follow this order.
    table_constraints.sort_by_key(|constraint_set| Reverse(constraint_set.n_vars));

    // Step 2: commit to all committed oracles.
    let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
    let merkle_scheme = merkle_prover.scheme();

    let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
    let committed_multilins = piop::collect_committed_witnesses(
        &commit_meta,
        &oracle_to_commit_index,
        &oracles,
        &witness,
    )?;

    let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
        &commit_meta,
        merkle_scheme,
        security_bits,
        log_inv_rate,
    )?;
    let CommitOutput {
        commitment,
        committed,
        codeword,
    } = piop::commit(&fri_params, &merkle_prover, &committed_multilins)?;

    // The commitment and the non-zero products share one message writer.
    let mut writer = transcript.message();
    writer.write(&commitment);

    // Step 3: grand products of the oracles declared non-zero. No selector masking applies.
    let non_zero_fast_witnesses =
        make_fast_masked_flush_witnesses(&oracles, &witness, &non_zero_oracle_ids, None)?;
    let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
        .into_par_iter()
        .map(GrandProductWitness::new)
        .collect::<Result<Vec<_>, _>>()?;

    let non_zero_products =
        gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
    // A zero product presumably means some "non-zero" oracle evaluates to zero somewhere, so
    // there is no valid proof to produce.
    if non_zero_products
        .iter()
        .any(|count| *count == Tower::B128::zero())
    {
        bail!(Error::Zeros);
    }

    writer.write_scalar_slice(&non_zero_products);

    let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
        &non_zero_oracle_ids,
        &oracles,
        &non_zero_products,
    )?;

    // Step 4: flush challenges are sampled only after the commitment has been written.
    let mixing_challenge = transcript.sample();
    let permutation_challenges = transcript.sample_vec(max_channel_id + 1);

    flushes.sort_by_key(|flush| flush.channel_id);
    let flush_oracle_ids =
        make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
    let flush_selectors = flushes
        .iter()
        .map(|flush| flush.selector)
        .collect::<Vec<_>>();

    // Fill the witness data for the newly added flush oracles, then build the masked,
    // fast-field copies used for the grand products.
    make_unmasked_flush_witnesses(&oracles, &mut witness, &flush_oracle_ids)?;
    let flush_witnesses = make_fast_masked_flush_witnesses(
        &oracles,
        &witness,
        &flush_oracle_ids,
        Some(&flush_selectors),
    )?;

    let flush_prodcheck_witnesses = flush_witnesses
        .into_par_iter()
        .map(GrandProductWitness::new)
        .collect::<Result<Vec<_>, _>>()?;
    let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);

    transcript.message().write_scalar_slice(&flush_products);

    let flush_prodcheck_claims =
        gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;

    // Step 5: one batched GKR run. Flush witnesses and claims come first and non-zero ones
    // second; the split below relies on that order.
    let all_gpa_witnesses = [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat();
    let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
        .map(|claim| claim.isomorphic())
        .collect::<Vec<_>>();

    let GrandProductBatchProveOutput { final_layer_claims } =
        gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
            EvaluationOrder::LowToHigh,
            all_gpa_witnesses,
            &all_gpa_claims,
            &fast_domain_factory,
            &mut transcript,
            backend,
        )?;

    // Map the final-layer claims back from the fast field to the tower field.
    let mut final_layer_claims = final_layer_claims
        .into_iter()
        .map(|layer_claim| layer_claim.isomorphic())
        .collect::<Vec<_>>();

    // The first `flush_oracle_ids.len()` claims belong to flushes; the rest to non-zero oracles.
    let non_zero_final_layer_claims = final_layer_claims.split_off(flush_oracle_ids.len());
    let flush_final_layer_claims = final_layer_claims;

    let non_zero_prodcheck_eval_claims =
        gkr_gpa::make_eval_claims(non_zero_oracle_ids, non_zero_final_layer_claims)?;

    // Step 6: flush sumchecks, with flushes regrouped by n_vars.
    let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) =
        reorder_for_flushing_by_n_vars(
            &oracles,
            &flush_oracle_ids,
            flush_selectors,
            flush_final_layer_claims,
        );

    let FlushSumcheckProvers {
        provers,
        flush_selectors_unique_by_claim,
        flush_oracle_ids_by_claim,
    } = get_flush_sumcheck_provers::<_, _, FDomain<Tower>, _, _>(
        &mut oracles,
        &flush_oracle_ids,
        &flush_selectors,
        &flush_final_layer_claims,
        &mut witness,
        &domain_factory,
        backend,
    )?;

    let flush_sumcheck_output = sumcheck::prove::batch_prove(provers, &mut transcript)?;

    let flush_eval_claims = get_post_flush_sumcheck_eval_claims_without_eq(
        &oracles,
        &flush_selectors_unique_by_claim,
        &flush_oracle_ids_by_claim,
        &flush_sumcheck_output,
    )?;

    // Step 7: zerocheck over every table constraint set.
    let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
        .iter()
        .cloned()
        .map(constraint_set_zerocheck_claim)
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .unzip::<_, _, Vec<_>, Vec<_>>();

    let (max_n_vars, skip_rounds) =
        max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);

    // One challenge per non-skipped round of the largest claim; smaller claims use a suffix.
    let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);

    let switchover_fn = standard_switchover_heuristic(-2);

    let mut univariate_provers = Vec::new();
    let mut tail_regular_zerocheck_provers = Vec::new();
    let mut univariatized_multilinears = Vec::new();

    for constraint_set in table_constraints {
        // Smaller constraint sets skip the leading zerocheck challenges.
        let skip_challenges = (max_n_vars - constraint_set.n_vars).saturating_sub(skip_rounds);
        // Sets with more than `max_n_vars - skip_rounds` variables join the univariate skip
        // rounds; the others become regular zerocheck provers.
        let univariate_decider = |n_vars| n_vars > max_n_vars - skip_rounds;

        let (constraints, multilinears) =
            sumcheck::prove::split_constraint_set(constraint_set, &witness)?;

        // Smallest tower level that holds every multilinear and every composition; level 7 is
        // the 128-bit field (see the `match` below).
        let base_tower_level = chain!(
            multilinears
                .iter()
                .map(|multilinear| 7 - multilinear.log_extension_degree()),
            constraints
                .iter()
                .map(|constraint| constraint.composition.binary_tower_level())
        )
        .max()
        .unwrap_or(0);

        // Kept for the univariatizing reduction that follows the zerocheck.
        univariatized_multilinears.push(multilinears.clone());

        let constructor =
            ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _, _> {
                constraints,
                multilinears,
                domain_factory: &domain_factory,
                switchover_fn,
                zerocheck_challenges: &zerocheck_challenges[skip_challenges..],
                backend,
                _fdomain_marker: PhantomData,
            };

        let either_prover = match base_tower_level {
            0..=3 => constructor.create::<Tower::B8>(univariate_decider)?,
            4 => constructor.create::<Tower::B16>(univariate_decider)?,
            5 => constructor.create::<Tower::B32>(univariate_decider)?,
            6 => constructor.create::<Tower::B64>(univariate_decider)?,
            7 => constructor.create::<Tower::B128>(univariate_decider)?,
            _ => unreachable!(),
        };

        match either_prover {
            Either::Left(univariate_prover) => univariate_provers.push(univariate_prover),
            Either::Right(zerocheck_prover) => {
                tail_regular_zerocheck_provers.push(zerocheck_prover)
            }
        }
    }

    let univariate_cnt = univariate_provers.len();

    let univariate_output = sumcheck::prove::batch_prove_zerocheck_univariate_round(
        univariate_provers,
        skip_rounds,
        &mut transcript,
    )?;

    let univariate_challenge = univariate_output.univariate_challenge;

    // The regular rounds continue from the state left by the univariate round and include
    // the tail provers that skipped it.
    let sumcheck_output = sumcheck::prove::batch_prove_with_start(
        univariate_output.batch_prove_start,
        tail_regular_zerocheck_provers,
        &mut transcript,
    )?;

    let zerocheck_output = zerocheck::verify_sumcheck_outputs(
        &zerocheck_claims,
        &zerocheck_challenges,
        sumcheck_output,
    )?;

    let mut reduction_claims = Vec::with_capacity(univariate_cnt);
    let mut reduction_provers = Vec::with_capacity(univariate_cnt);

    // Univariatizing reduction: one claim/prover per constraint set, from its multilinear
    // evaluations and the univariate challenge.
    for (univariatized_multilinear_evals, multilinears) in
        izip!(&zerocheck_output.multilinear_evals, univariatized_multilinears)
    {
        let claim_n_vars = multilinears
            .first()
            .map_or(0, |multilinear| multilinear.n_vars());

        let skip_challenges = (max_n_vars - claim_n_vars).saturating_sub(skip_rounds);
        let challenges = &zerocheck_output.challenges[skip_challenges..];
        let reduced_multilinears =
            sumcheck::prove::reduce_to_skipped_projection(multilinears, challenges, backend)?;

        let claim_skip_rounds = claim_n_vars - challenges.len();
        let reduction_claim = sumcheck::univariate::univariatizing_reduction_claim(
            claim_skip_rounds,
            univariatized_multilinear_evals,
        )?;

        let reduction_prover =
            sumcheck::prove::univariatizing_reduction_prover::<_, FDomain<Tower>, _, _>(
                reduced_multilinears,
                univariatized_multilinear_evals,
                univariate_challenge,
                &domain_factory,
                backend,
            )?;

        reduction_claims.push(reduction_claim);
        reduction_provers.push(reduction_prover);
    }

    let univariatizing_output = sumcheck::prove::batch_prove(reduction_provers, &mut transcript)?;

    let multilinear_zerocheck_output = sumcheck::univariate::verify_sumcheck_outputs(
        &reduction_claims,
        univariate_challenge,
        &zerocheck_output.challenges,
        univariatizing_output,
    )?;

    let zerocheck_eval_claims =
        sumcheck::make_eval_claims(zerocheck_oracle_metas, multilinear_zerocheck_output)?;

    // Step 8: reduce all evaluation claims, in the order non-zero, flush, zerocheck.
    let eval_claims = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
        &mut oracles,
        &mut witness,
        [non_zero_prodcheck_eval_claims, flush_eval_claims]
            .concat()
            .into_iter()
            .chain(zerocheck_eval_claims),
        switchover_fn,
        &mut transcript,
        &domain_factory,
        backend,
    )?;

    let system = ring_switch::EvalClaimSystem::new(
        &oracles,
        &commit_meta,
        &oracle_to_commit_index,
        &eval_claims,
    )?;

    let ring_switch::ReducedWitness {
        transparents: transparent_multilins,
        sumcheck_claims: piop_sumcheck_claims,
    } = ring_switch::prove::<_, _, _, Tower, _, _>(
        &system,
        &committed_multilins,
        &mut transcript,
        backend,
    )?;

    // Open the step-2 commitment against the ring-switched sumcheck claims.
    piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _>(
        &fri_params,
        &merkle_prover,
        domain_factory,
        &commit_meta,
        committed,
        &codeword,
        &committed_multilins,
        &transparent_multilins,
        &piop_sumcheck_claims,
        &mut transcript,
        &backend,
    )?;

    Ok(Proof {
        transcript: transcript.finalize(),
    })
}
455
456type TypeErasedUnivariateZerocheck<'a, F> = Box<dyn UnivariateZerocheckProver<'a, F> + 'a>;
457type TypeErasedSumcheck<'a, F> = Box<dyn SumcheckProver<F> + 'a>;
458type TypeErasedProver<'a, F> =
459 Either<TypeErasedUnivariateZerocheck<'a, F>, TypeErasedSumcheck<'a, F>>;
460
461struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
462where
463 P: PackedField,
464{
465 constraints: Vec<Constraint<P::Scalar>>,
466 multilinears: Vec<MultilinearWitness<'a, P>>,
467 domain_factory: DomainFactory,
468 switchover_fn: SwitchoverFn,
469 zerocheck_challenges: &'a [P::Scalar],
470 backend: &'a Backend,
471 _fdomain_marker: PhantomData<FDomain>,
472}
473
474impl<'a, P, F, FDomain, DomainFactory, SwitchoverFn, Backend>
475 ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
476where
477 F: Field,
478 P: PackedFieldIndexable<Scalar = F>,
479 FDomain: TowerField,
480 DomainFactory: EvaluationDomainFactory<FDomain>,
481 SwitchoverFn: Fn(usize) -> usize + Clone,
482 Backend: ComputationBackend,
483{
484 fn create<FBase>(
485 self,
486 is_univariate: impl FnOnce(usize) -> bool,
487 ) -> Result<TypeErasedProver<'a, F>, Error>
488 where
489 FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
490 P: PackedExtension<F, PackedSubfield = P>
491 + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
492 + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
493 F: TowerField,
494 {
495 let univariate_prover =
496 sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _>(
497 self.constraints,
498 self.multilinears,
499 self.domain_factory,
500 self.switchover_fn,
501 self.zerocheck_challenges,
502 self.backend,
503 )?;
504
505 let type_erased_prover = if is_univariate(univariate_prover.n_vars()) {
506 let type_erased_univariate_prover =
507 Box::new(univariate_prover) as TypeErasedUnivariateZerocheck<'a, P::Scalar>;
508
509 Either::Left(type_erased_univariate_prover)
510 } else {
511 let zerocheck_prover = univariate_prover.into_regular_zerocheck()?;
512 let type_erased_zerocheck_prover =
513 Box::new(zerocheck_prover) as TypeErasedSumcheck<'a, P::Scalar>;
514
515 Either::Right(type_erased_zerocheck_prover)
516 };
517
518 Ok(type_erased_prover)
519 }
520}
521
522#[instrument(skip_all, level = "debug")]
523fn make_unmasked_flush_witnesses<'a, U, Tower>(
524 oracles: &MultilinearOracleSet<FExt<Tower>>,
525 witness: &mut MultilinearExtensionIndex<'a, U, FExt<Tower>>,
526 flush_oracle_ids: &[OracleId],
527) -> Result<(), Error>
528where
529 U: ProverTowerUnderlier<Tower>,
530 Tower: ProverTowerFamily,
531{
532 let flush_witnesses: Result<Vec<MultilinearWitness<'a, _>>, Error> = flush_oracle_ids
534 .par_iter()
535 .map(|&oracle_id| {
536 let MultilinearPolyVariant::LinearCombination(lincom) =
537 oracles.oracle(oracle_id).variant
538 else {
539 unreachable!("make_flush_oracles adds linear combination oracles");
540 };
541 let polys = lincom
542 .polys()
543 .map(|id| witness.get_multilin_poly(id))
544 .collect::<Result<Vec<_>, _>>()?;
545
546 let packed_len = 1
547 << lincom
548 .n_vars()
549 .saturating_sub(<PackedType<U, FExt<Tower>>>::LOG_WIDTH);
550 let data = (0..packed_len)
551 .into_par_iter()
552 .map(|i| {
553 <PackedType<U, FExt<Tower>>>::from_fn(|j| {
554 let index = i << <PackedType<U, FExt<Tower>>>::LOG_WIDTH | j;
555 polys.iter().zip(lincom.coefficients()).fold(
556 lincom.offset(),
557 |sum, (poly, coeff)| {
558 sum + poly
559 .evaluate_on_hypercube_and_scale(index, coeff)
560 .unwrap_or(<FExt<Tower>>::ZERO)
561 },
562 )
563 })
564 })
565 .collect::<Vec<_>>();
566 let lincom_poly = MultilinearExtension::new(lincom.n_vars(), data)
567 .expect("data is constructed with the correct length with respect to n_vars");
568
569 Ok(MLEDirectAdapter::from(lincom_poly).upcast_arc_dyn())
570 })
571 .collect();
572
573 witness.update_multilin_poly(izip!(flush_oracle_ids.iter().copied(), flush_witnesses?))?;
574 Ok(())
575}
576
577#[allow(clippy::type_complexity)]
578#[instrument(skip_all, level = "debug")]
579fn make_fast_masked_flush_witnesses<'a, U, Tower>(
580 oracles: &MultilinearOracleSet<FExt<Tower>>,
581 witness: &MultilinearExtensionIndex<'a, U, FExt<Tower>>,
582 flush_oracles: &[OracleId],
583 flush_selectors: Option<&[OracleId]>,
584) -> Result<Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>, Error>
585where
586 U: ProverTowerUnderlier<Tower>,
587 Tower: ProverTowerFamily,
588 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
589{
590 let to_fast = Tower::packed_transformation_to_fast();
591
592 flush_oracles
594 .par_iter()
595 .enumerate()
596 .map(|(i, &flush_oracle_id)| {
597 let n_vars = oracles.n_vars(flush_oracle_id);
598
599 let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
600 let width = 1 << log_width;
601
602 let packed_len = 1 << n_vars.saturating_sub(log_width);
603 let mut fast_ext_result = vec![PackedType::<U, FFastExt<Tower>>::one(); packed_len];
604
605 let poly = witness.get_multilin_poly(flush_oracle_id)?;
606 let selector = flush_selectors
607 .map(|flush_selectors| witness.get_multilin_poly(flush_selectors[i]))
608 .transpose()?;
609
610 const MAX_SUBCUBE_VARS: usize = 8;
611 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
612 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
613
614 fast_ext_result
615 .par_chunks_mut(subcube_packed_size)
616 .enumerate()
617 .for_each(|(subcube_index, fast_subcube)| {
618 let underliers =
619 PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
620
621 let subcube_evals =
622 PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
623 poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
624 .expect("witness data populated by make_unmasked_flush_witnesses()");
625
626 for underlier in underliers.iter_mut() {
627 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
628 let dest = to_fast.transform(&src);
629 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
630 }
631
632 if let Some(selector) = &selector {
633 let fast_subcube =
634 PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
635
636 let mut ones_mask = PackedType::<U, FExt<Tower>>::default();
637 for (i, packed) in fast_subcube.iter_mut().enumerate() {
638 selector
639 .subcube_evals(
640 log_width,
641 (subcube_index << subcube_vars.saturating_sub(log_width)) | i,
642 0,
643 from_mut(&mut ones_mask),
644 )
645 .expect("selector n_vars equals flushed n_vars");
646
647 if ones_mask == PackedField::zero() {
648 *packed = PackedField::one();
649 } else if ones_mask != PackedField::one() {
650 for j in 0..width {
651 if ones_mask.get(j) == FExt::<Tower>::ZERO {
652 packed.set(j, FFastExt::<Tower>::ONE);
653 }
654 }
655 }
656 }
657 }
658 });
659
660 let masked_poly = MultilinearExtension::new(n_vars, fast_ext_result)
661 .expect("data is constructed with the correct length with respect to n_vars");
662 Ok(MLEDirectAdapter::from(masked_poly).upcast_arc_dyn())
663 })
664 .collect()
665}
666
667pub struct FlushSumcheckProvers<Prover> {
668 provers: Vec<Prover>,
669 flush_oracle_ids_by_claim: Vec<Vec<OracleId>>,
670 flush_selectors_unique_by_claim: Vec<Vec<OracleId>>,
671}
672
673#[instrument(skip_all, level = "debug")]
674fn get_flush_sumcheck_provers<'a, 'b, U, Tower, FDomain, DomainFactory, Backend>(
675 oracles: &mut MultilinearOracleSet<Tower::B128>,
676 flush_oracle_ids: &[OracleId],
677 flush_selectors: &[OracleId],
678 final_layer_claims: &[LayerClaim<Tower::B128>],
679 witness: &mut MultilinearExtensionIndex<'a, U, Tower::B128>,
680 domain_factory: DomainFactory,
681 backend: &'b Backend,
682) -> Result<FlushSumcheckProvers<impl SumcheckProver<Tower::B128> + 'b>, Error>
683where
684 U: ProverTowerUnderlier<Tower> + PackScalar<FDomain>,
685 Tower: ProverTowerFamily,
686 Tower::B128: ExtensionField<FDomain>,
687 FDomain: Field,
688 DomainFactory: EvaluationDomainFactory<FDomain>,
689 Backend: ComputationBackend,
690 PackedType<U, Tower::B128>: PackedFieldIndexable,
691 'a: 'b,
692{
693 let flush_sumcheck_metas = get_flush_dedup_sumcheck_metas(
694 oracles,
695 flush_oracle_ids,
696 flush_selectors,
697 final_layer_claims,
698 )?;
699
700 let n_claims = flush_sumcheck_metas.len();
701 let mut provers = Vec::with_capacity(n_claims);
702 let mut flush_oracle_ids_by_claim = Vec::with_capacity(n_claims);
703 let mut flush_selectors_unique_by_claim = Vec::with_capacity(n_claims);
704 for flush_sumcheck_meta in flush_sumcheck_metas {
705 let FlushSumcheckMeta {
706 composite_sum_claims,
707 flush_selectors_unique,
708 flush_oracle_ids,
709 eval_point,
710 } = flush_sumcheck_meta;
711
712 let mut multilinears =
713 Vec::with_capacity(flush_selectors_unique.len() + flush_oracle_ids.len());
714
715 for &flush_selector in &flush_selectors_unique {
716 multilinears.push(witness.get_multilin_poly(flush_selector)?);
717 }
718
719 for &oracle_id in &flush_oracle_ids {
720 multilinears.push(witness.get_multilin_poly(oracle_id)?);
721 }
722
723 let prover = GPAProver::new(
724 EvaluationOrder::LowToHigh,
725 multilinears,
726 None,
727 composite_sum_claims,
728 domain_factory.clone(),
729 &eval_point,
730 backend,
731 )?;
732
733 provers.push(prover);
734 flush_oracle_ids_by_claim.push(flush_oracle_ids);
735 flush_selectors_unique_by_claim.push(flush_selectors_unique);
736 }
737
738 Ok(FlushSumcheckProvers {
739 provers,
740 flush_selectors_unique_by_claim,
741 flush_oracle_ids_by_claim,
742 })
743}