1use std::{collections::HashSet, env, iter, marker::PhantomData};
4
5use binius_compute::{ComputeData, ComputeLayer, alloc::ComputeAllocator, cpu::CpuMemory};
6use binius_field::{
7 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
8 RepackedExtension, TowerField,
9 as_packed_field::PackedType,
10 linear_transformation::{PackedTransformationFactory, Transformation},
11 tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
12 underlier::WithUnderlier,
13 util::powers,
14};
15use binius_hal::ComputationBackend;
16use binius_hash::{PseudoCompressionFunction, multi_digest::ParallelDigest};
17use binius_math::{
18 CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
19 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
20};
21use binius_maybe_rayon::prelude::*;
22use binius_ntt::SingleThreadedNTT;
23use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
24use bytemuck::zeroed_vec;
25use digest::{FixedOutputReset, Output, core_api::BlockSizeUser};
26use itertools::chain;
27use tracing::instrument;
28use tracing_profile::utils::emit_max_rss;
29
30use super::{
31 ConstraintSystem, Proof,
32 channel::Boundary,
33 error::Error,
34 verify::{make_flush_oracles, max_n_vars_and_skip_rounds},
35};
36use crate::{
37 constraint_system::{
38 Flush,
39 channel::OracleOrConst,
40 common::{FDomain, FEncode, FExt, FFastExt},
41 exp::{self, reorder_exponents},
42 verify::augment_flush_po2_step_down,
43 },
44 fiat_shamir::{CanSample, Challenger},
45 merkle_tree::BinaryMerkleTreeProver,
46 oracle::{
47 Constraint, ConstraintSetBuilder, MultilinearOracleSet, MultilinearPolyVariant, OracleId,
48 SizedConstraintSet,
49 },
50 piop,
51 protocols::{
52 evalcheck::{
53 ConstraintSetEqIndPoint, EvalPoint, EvalcheckMultilinearClaim,
54 subclaims::{MemoizedData, prove_mlecheck_with_switchover},
55 },
56 fri::CommitOutput,
57 gkr_exp,
58 gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness},
59 greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
60 sumcheck::{
61 self, constraint_set_zerocheck_claim, immediate_switchover_heuristic,
62 prove::ZerocheckProver, standard_switchover_heuristic,
63 },
64 },
65 ring_switch,
66 transcript::ProverTranscript,
67 transparent::step_down::StepDown,
68 witness::{IndexEntry, MultilinearExtensionIndex, MultilinearWitness},
69};
70
71#[allow(clippy::too_many_arguments)]
73#[instrument("constraint_system::prove", skip_all, level = "debug")]
74pub fn prove<
75 Hal,
76 U,
77 Tower,
78 Hash,
79 Compress,
80 Challenger_,
81 Backend,
82 HostAllocatorType,
83 DeviceAllocatorType,
84>(
85 compute_data: &mut ComputeData<Tower::B128, Hal, HostAllocatorType, DeviceAllocatorType>,
86 constraint_system: &ConstraintSystem<FExt<Tower>>,
87 log_inv_rate: usize,
88 security_bits: usize,
89 constraint_system_digest: &Output<Hash::Digest>,
90 boundaries: &[Boundary<FExt<Tower>>],
91 table_sizes: &[usize],
92 mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
93 backend: &Backend,
94) -> Result<Proof, Error>
95where
96 Hal: ComputeLayer<Tower::B128> + Default,
97 U: ProverTowerUnderlier<Tower>,
98 Tower: ProverTowerFamily,
99 Tower::B128:
100 binius_math::TowerTop + binius_math::PackedTop + PackedTop<Tower> + From<FFastExt<Tower>>,
101 Hash: ParallelDigest,
102 Hash::Digest: BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
103 Compress: PseudoCompressionFunction<Output<Hash::Digest>, 2> + Default + Sync,
104 Challenger_: Challenger + Default,
105 Backend: ComputationBackend,
106 PackedType<U, Tower::B128>: PackedTop<Tower>
108 + PackedFieldIndexable
109 + RepackedExtension<PackedType<U, Tower::B1>>
111 + RepackedExtension<PackedType<U, Tower::B8>>
112 + RepackedExtension<PackedType<U, Tower::B16>>
113 + RepackedExtension<PackedType<U, Tower::B32>>
114 + RepackedExtension<PackedType<U, Tower::B64>>
115 + RepackedExtension<PackedType<U, Tower::B128>>
116 + PackedTransformationFactory<PackedType<U, Tower::FastB128>>
117 + binius_math::PackedTop,
118 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
119 HostAllocatorType: ComputeAllocator<Tower::B128, CpuMemory>,
120 DeviceAllocatorType: ComputeAllocator<Tower::B128, Hal::DevMem>,
121{
122 tracing::debug!(
123 arch = env::consts::ARCH,
124 rayon_threads = binius_maybe_rayon::current_num_threads(),
125 "using computation backend: {backend:?}"
126 );
127
128 let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
129 let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();
130
131 let ConstraintSystem {
132 oracles,
133 table_constraints,
134 mut flushes,
135 mut exponents,
136 mut non_zero_oracle_ids,
137 channel_count,
138 table_size_specs,
139 } = constraint_system.clone();
140
141 constraint_system.check_table_sizes(table_sizes)?;
142 let mut oracles = oracles.instantiate(table_sizes)?;
143
144 flushes.retain(|flush| table_sizes[flush.table_id] > 0);
150 flushes.sort_by_key(|flush| flush.channel_id);
151
152 non_zero_oracle_ids.retain(|oracle| !oracles.is_zero_sized(*oracle));
153 exponents.retain(|exp| !oracles.is_zero_sized(exp.exp_result_id));
154
155 let mut table_constraints = table_constraints
156 .into_iter()
157 .filter_map(|u| {
158 if table_sizes[u.table_id] == 0 {
159 None
160 } else {
161 let n_vars = u.log_values_per_row + log2_ceil_usize(table_sizes[u.table_id]);
162 Some(SizedConstraintSet::new(n_vars, u))
163 }
164 })
165 .collect::<Vec<_>>();
166 table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
168
169 reorder_exponents(&mut exponents, &oracles);
170
171 let mut transcript = ProverTranscript::<Challenger_>::new();
172 transcript
173 .observe()
174 .write_slice(constraint_system_digest.as_ref());
175 transcript.observe().write_slice(boundaries);
176 let mut writer = transcript.message();
177 writer.write_slice(table_sizes);
178
179 let witness_span = tracing::info_span!(
180 "[phase] Witness Finalization",
181 phase = "witness",
182 perfetto_category = "phase.main"
183 )
184 .entered();
185
186 let exp_compute_layer_span = tracing::info_span!(
189 "[step] Compute Exponentiation Layers",
190 phase = "witness",
191 perfetto_category = "phase.sub"
192 )
193 .entered();
194 let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;
195 drop(exp_compute_layer_span);
196
197 drop(witness_span);
198
199 let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
201 let merkle_scheme = merkle_prover.scheme();
202
203 let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
204 let committed_multilins = piop::collect_committed_witnesses::<U, _>(
205 &commit_meta,
206 &oracle_to_commit_index,
207 &oracles,
208 &witness,
209 )?;
210
211 let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
212 &commit_meta,
213 merkle_scheme,
214 security_bits,
215 log_inv_rate,
216 )?;
217 let ntt = SingleThreadedNTT::with_subspace(fri_params.rs_code().subspace())?
218 .precompute_twiddles()
219 .multithreaded();
220
221 let commit_span =
222 tracing::info_span!("[phase] Commit", phase = "commit", perfetto_category = "phase.main")
223 .entered();
224 let CommitOutput {
225 commitment,
226 committed,
227 codeword,
228 } = piop::commit(&fri_params, &ntt, &merkle_prover, &committed_multilins)?;
229 emit_max_rss();
230 drop(commit_span);
231
232 let mut writer = transcript.message();
234 writer.write(&commitment);
235
236 let exp_span = tracing::info_span!(
237 "[phase] Exponentiation",
238 phase = "exp",
239 perfetto_category = "phase.main"
240 )
241 .entered();
242 let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
243
244 let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
245 .into_iter()
246 .map(|x| x.into())
247 .collect::<Vec<_>>();
248
249 let mut writer = transcript.message();
250 writer.write_scalar_slice(&exp_evals);
251
252 let exp_challenge = exp_challenge
253 .into_iter()
254 .map(|x| x.into())
255 .collect::<Vec<_>>();
256
257 let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
258 .into_iter()
259 .map(|claim| claim.isomorphic())
260 .collect::<Vec<_>>();
261
262 let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
263 EvaluationOrder::HighToLow,
264 exp_witnesses,
265 &exp_claims,
266 fast_domain_factory.clone(),
267 &mut transcript,
268 backend,
269 )?
270 .isomorphic();
271
272 let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
273 emit_max_rss();
274 drop(exp_span);
275
276 let prodcheck_span = tracing::info_span!(
279 "[phase] Product Check",
280 phase = "prodcheck",
281 perfetto_category = "phase.main"
282 )
283 .entered();
284
285 let nonzero_convert_span = tracing::info_span!(
286 "[task] Convert Non-Zero to Fast Field",
287 phase = "prodcheck",
288 perfetto_category = "task.main"
289 )
290 .entered();
291 let non_zero_fast_witnesses =
292 convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &non_zero_oracle_ids)?;
293 emit_max_rss();
294 drop(nonzero_convert_span);
295
296 let nonzero_prodcheck_compute_layer_span = tracing::info_span!(
297 "[step] Compute Non-Zero Product Layers",
298 phase = "prodcheck",
299 perfetto_category = "phase.sub"
300 )
301 .entered();
302 let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
303 .into_par_iter()
304 .map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
305 .collect::<Result<Vec<_>, _>>()?;
306 emit_max_rss();
307 drop(nonzero_prodcheck_compute_layer_span);
308
309 let non_zero_products =
310 gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
311 if non_zero_products
312 .iter()
313 .any(|count| *count == Tower::B128::zero())
314 {
315 bail!(Error::Zeros);
316 }
317
318 let mut writer = transcript.message();
319
320 writer.write_scalar_slice(&non_zero_products);
321
322 let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
323 &non_zero_oracle_ids,
324 &oracles,
325 &non_zero_products,
326 )?;
327
328 let mixing_challenge = transcript.sample();
330 let permutation_challenges = transcript.sample_vec(channel_count);
331
332 flushes.retain(|flush| table_sizes[flush.table_id] > 0);
333 flushes.sort_by_key(|flush| flush.channel_id);
334 let po2_step_down_polys =
335 augment_flush_po2_step_down(&mut oracles, &mut flushes, &table_size_specs, table_sizes)?;
336 populate_flush_po2_step_down_witnesses::<U, _>(po2_step_down_polys, &mut witness)?;
337 let flush_oracle_ids =
338 make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
339
340 let flush_convert_span = tracing::info_span!(
341 "[task] Convert Flushes to Fast Field",
342 phase = "prodcheck",
343 perfetto_category = "task.main"
344 )
345 .entered();
346
347 let mut fast_witness = MultilinearExtensionIndex::<PackedType<U, FFastExt<Tower>>>::new();
348
349 make_masked_flush_witnesses::<U, _>(
350 &oracles,
351 &mut witness,
352 &mut fast_witness,
353 &flush_oracle_ids,
354 &flushes,
355 mixing_challenge,
356 &permutation_challenges,
357 )?;
358
359 let flush_witnesses =
361 convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &flush_oracle_ids)?;
362 emit_max_rss();
363 drop(flush_convert_span);
364
365 let flush_prodcheck_compute_layer_span = tracing::info_span!(
366 "[step] Compute Flush Product Layers",
367 phase = "prodcheck",
368 perfetto_category = "phase.sub"
369 )
370 .entered();
371 let flush_prodcheck_witnesses = flush_witnesses
372 .into_par_iter()
373 .map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
374 .collect::<Result<Vec<_>, _>>()?;
375 emit_max_rss();
376 drop(flush_prodcheck_compute_layer_span);
377
378 let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);
379
380 transcript.message().write_scalar_slice(&flush_products);
381
382 let flush_prodcheck_claims =
383 gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
384
385 let all_gpa_witnesses =
387 chain!(flush_prodcheck_witnesses, non_zero_prodcheck_witnesses).collect::<Vec<_>>();
388 let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
389 .map(|claim| claim.isomorphic())
390 .collect::<Vec<_>>();
391
392 let GrandProductBatchProveOutput { final_layer_claims } =
393 gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
394 EvaluationOrder::HighToLow,
395 all_gpa_witnesses,
396 &all_gpa_claims,
397 &fast_domain_factory,
398 &mut transcript,
399 backend,
400 )?;
401
402 let final_layer_claims = final_layer_claims
404 .into_iter()
405 .map(|layer_claim| layer_claim.isomorphic())
406 .collect::<Vec<_>>();
407
408 let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
410 chain!(flush_oracle_ids.clone(), non_zero_oracle_ids),
411 final_layer_claims,
412 )?;
413
414 let mut flush_prodcheck_eval_claims = prodcheck_eval_claims;
415
416 let prodcheck_eval_claims = flush_prodcheck_eval_claims.split_off(flush_oracle_ids.len());
417
418 let flush_eval_claims = reduce_flush_evalcheck_claims::<U, Tower, Challenger_, Backend>(
419 flush_prodcheck_eval_claims,
420 &oracles,
421 fast_witness,
422 fast_domain_factory.clone(),
423 &mut transcript,
424 backend,
425 )?;
426
427 emit_max_rss();
428 drop(prodcheck_span);
429
430 let zerocheck_span = tracing::info_span!(
432 "[phase] Zerocheck",
433 phase = "zerocheck",
434 perfetto_category = "phase.main",
435 )
436 .entered();
437
438 let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
439 .iter()
440 .cloned()
441 .map(constraint_set_zerocheck_claim)
442 .collect::<Result<Vec<_>, _>>()?
443 .into_iter()
444 .unzip::<_, _, Vec<_>, Vec<_>>();
445
446 let (max_n_vars, skip_rounds) =
447 max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);
448
449 let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);
450
451 let mut zerocheck_provers = Vec::with_capacity(table_constraints.len());
452
453 for constraint_set in table_constraints {
454 let n_vars = constraint_set.n_vars;
455 let (constraints, multilinears) =
456 sumcheck::prove::split_constraint_set(constraint_set, &witness)?;
457
458 let base_tower_level = chain!(
459 multilinears
460 .iter()
461 .map(|multilinear| 7 - multilinear.log_extension_degree()),
462 constraints
463 .iter()
464 .map(|constraint| constraint.composition.binary_tower_level())
465 )
466 .max()
467 .unwrap_or(0);
468
469 let zerocheck_challenges = &zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..];
471 let domain_factory = domain_factory.clone();
472
473 let constructor =
474 ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _> {
475 constraints,
476 multilinears,
477 zerocheck_challenges,
478 domain_factory,
479 backend,
480 _fdomain_marker: PhantomData,
481 };
482
483 let zerocheck_prover = match base_tower_level {
484 0..=3 => constructor.create::<Tower::B8>()?,
485 4 => constructor.create::<Tower::B16>()?,
486 5 => constructor.create::<Tower::B32>()?,
487 6 => constructor.create::<Tower::B64>()?,
488 7 => constructor.create::<Tower::B128>()?,
489 _ => unreachable!(),
490 };
491
492 zerocheck_provers.push(zerocheck_prover);
493 }
494
495 let zerocheck_output = sumcheck::prove::batch_prove_zerocheck::<
496 FExt<Tower>,
497 FDomain<Tower>,
498 PackedType<U, FExt<Tower>>,
499 _,
500 _,
501 >(zerocheck_provers, skip_rounds, &mut transcript)?;
502
503 let zerocheck_eval_claims =
504 sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
505
506 emit_max_rss();
507 drop(zerocheck_span);
508
509 let evalcheck_span = tracing::info_span!(
510 "[phase] Evalcheck",
511 phase = "evalcheck",
512 perfetto_category = "phase.main"
513 )
514 .entered();
515
516 let GreedyEvalcheckProveOutput {
518 eval_claims,
519 memoized_data,
520 } = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
521 &mut oracles,
522 &mut witness,
523 chain!(flush_eval_claims, prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
524 standard_switchover_heuristic(-2),
525 &mut transcript,
526 &domain_factory,
527 backend,
528 )?;
529
530 let system = ring_switch::EvalClaimSystem::new(
532 &oracles,
533 &commit_meta,
534 &oracle_to_commit_index,
535 &eval_claims,
536 )?;
537
538 emit_max_rss();
539 drop(evalcheck_span);
540
541 let ring_switch_span = tracing::info_span!(
542 "[phase] Ring Switch",
543 phase = "ring_switch",
544 perfetto_category = "phase.main"
545 )
546 .entered();
547
548 let hal = compute_data.hal;
549
550 let dev_alloc = &compute_data.dev_alloc;
551 let host_alloc = &compute_data.host_alloc;
552
553 let ring_switch::ReducedWitness {
554 transparents: transparent_multilins,
555 sumcheck_claims: piop_sumcheck_claims,
556 } = ring_switch::prove(
557 &system,
558 &committed_multilins,
559 &mut transcript,
560 memoized_data,
561 hal,
562 dev_alloc,
563 host_alloc,
564 )?;
565 emit_max_rss();
566 drop(ring_switch_span);
567
568 let piop_compiler_span = tracing::info_span!(
570 "[phase] PIOP Compiler",
571 phase = "piop_compiler",
572 perfetto_category = "phase.main"
573 )
574 .entered();
575
576 piop::prove(
577 compute_data,
578 &fri_params,
579 &ntt,
580 &merkle_prover,
581 &commit_meta,
582 committed,
583 &codeword,
584 &committed_multilins,
585 transparent_multilins,
586 &piop_sumcheck_claims,
587 &mut transcript,
588 )?;
589 emit_max_rss();
590 drop(piop_compiler_span);
591
592 let proof = Proof {
593 transcript: transcript.finalize(),
594 };
595
596 tracing::event!(
597 name: "proof_size",
598 tracing::Level::INFO,
599 counter = true,
600 value = proof.get_proof_size() as u64,
601 unit = "bytes",
602 );
603
604 Ok(proof)
605}
606
607type TypeErasedZerocheck<'a, P> = Box<dyn ZerocheckProver<'a, P> + 'a>;
608
609struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
610where
611 P: PackedField,
612{
613 constraints: Vec<Constraint<P::Scalar>>,
614 multilinears: Vec<MultilinearWitness<'a, P>>,
615 domain_factory: DomainFactory,
616 zerocheck_challenges: &'a [P::Scalar],
617 backend: &'a Backend,
618 _fdomain_marker: PhantomData<FDomain>,
619}
620
621impl<'a, P, F, FDomain, DomainFactory, Backend>
622 ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
623where
624 F: Field,
625 P: PackedField<Scalar = F>,
626 FDomain: TowerField,
627 DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
628 Backend: ComputationBackend,
629{
630 fn create<FBase>(self) -> Result<TypeErasedZerocheck<'a, P>, Error>
631 where
632 FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
633 P: PackedExtension<F, PackedSubfield = P>
634 + PackedExtension<FDomain>
635 + PackedExtension<FBase>,
636 F: TowerField,
637 {
638 let zerocheck_prover =
639 sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _>(
640 self.constraints,
641 self.multilinears,
642 self.domain_factory,
643 self.zerocheck_challenges,
644 self.backend,
645 )?;
646
647 let type_erased_zerocheck_prover = Box::new(zerocheck_prover) as TypeErasedZerocheck<'a, P>;
648
649 Ok(type_erased_zerocheck_prover)
650 }
651}
652
653fn populate_flush_po2_step_down_witnesses<'a, U, Tower>(
654 step_down_polys: Vec<(OracleId, StepDown)>,
655 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
656) -> Result<(), Error>
657where
658 U: ProverTowerUnderlier<Tower>,
659 Tower: ProverTowerFamily,
660{
661 for (oracle_id, step_down_poly) in step_down_polys {
662 let witness_poly = step_down_poly
663 .multilinear_extension::<PackedType<U, Tower::B1>>()?
664 .specialize_arc_dyn();
665 witness.update_multilin_poly([(oracle_id, witness_poly)])?
666 }
667 Ok(())
668}
669
670#[instrument(skip_all, level = "debug")]
671pub fn make_masked_flush_witnesses<'a, U, Tower>(
672 oracles: &MultilinearOracleSet<FExt<Tower>>,
673 witness_index: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
674 fast_witness_index: &mut MultilinearExtensionIndex<'a, PackedType<U, FFastExt<Tower>>>,
675 flush_oracle_ids: &[OracleId],
676 flushes: &[Flush<FExt<Tower>>],
677 mixing_challenge: FExt<Tower>,
678 permutation_challenges: &[FExt<Tower>],
679) -> Result<(), Error>
680where
681 U: ProverTowerUnderlier<Tower>,
682 Tower: ProverTowerFamily,
683 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
684 + RepackedExtension<PackedType<U, Tower::B1>>,
685{
686 for flush in flushes {
689 let fast_selectors =
690 convert_1b_witnesses_to_fast_ext::<U, Tower>(witness_index, &flush.selectors)?;
691
692 for (&selector_id, fast_selector) in flush.selectors.iter().zip(fast_selectors) {
693 let selector = witness_index.get_multilin_poly(selector_id)?;
694 let zero_suffix_len = count_zero_suffixes(&selector);
695
696 let nonzero_prefix_len = (1 << selector.n_vars()) - zero_suffix_len;
697 witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
698 selector_id,
699 selector,
700 nonzero_prefix_len,
701 )])?;
702
703 fast_witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
704 selector_id,
705 fast_selector,
706 nonzero_prefix_len,
707 )])?;
708 }
709 }
710
711 let inner_oracles_id = flushes
712 .iter()
713 .flat_map(|flush| {
714 flush
715 .oracles
716 .iter()
717 .filter_map(|oracle_or_const| match oracle_or_const {
718 OracleOrConst::Oracle(oracle_id) => Some(*oracle_id),
719 _ => None,
720 })
721 })
722 .collect::<HashSet<_>>();
723
724 let inner_oracles_id = inner_oracles_id.into_iter().collect::<Vec<_>>();
725
726 let fast_inner_oracles =
727 convert_witnesses_to_fast_ext::<U, Tower>(oracles, witness_index, &inner_oracles_id)?;
728
729 for ((n_vars, witness_data), id) in fast_inner_oracles.into_iter().zip(inner_oracles_id) {
730 let fast_witness = MLEDirectAdapter::from(
731 MultilinearExtension::new(n_vars, witness_data)
732 .expect("witness_data created with correct n_vars"),
733 );
734
735 let nonzero_scalars_prefix = witness_index.get_index_entry(id)?.nonzero_scalars_prefix;
736
737 fast_witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
738 id,
739 fast_witness.upcast_arc_dyn(),
740 nonzero_scalars_prefix,
741 )])?;
742 }
743
744 let max_n_mixed = flushes
746 .iter()
747 .map(|flush| flush.oracles.len())
748 .max()
749 .unwrap_or_default();
750 let mixing_powers = powers(mixing_challenge)
751 .take(max_n_mixed)
752 .collect::<Vec<_>>();
753
754 let indices_to_update = flush_oracle_ids
756 .par_iter()
757 .zip(flushes)
758 .map(|(&flush_oracle, flush)| {
759 let n_vars = oracles.n_vars(flush_oracle);
760
761 let const_term = flush
762 .oracles
763 .iter()
764 .copied()
765 .zip(mixing_powers.iter())
766 .filter_map(|(oracle_or_const, coeff)| match oracle_or_const {
767 OracleOrConst::Const { base, .. } => Some(base * coeff),
768 _ => None,
769 })
770 .sum::<FExt<Tower>>();
771 let const_term = permutation_challenges[flush.channel_id] + const_term;
772
773 let inner_oracles = flush
774 .oracles
775 .iter()
776 .copied()
777 .zip(mixing_powers.iter())
778 .filter_map(|(oracle_or_const, &coeff)| match oracle_or_const {
779 OracleOrConst::Oracle(oracle_id) => Some((oracle_id, coeff)),
780 _ => None,
781 })
782 .map(|(inner_id, coeff)| {
783 let witness = witness_index.get_multilin_poly(inner_id)?;
784 Ok((witness, coeff))
785 })
786 .collect::<Result<Vec<_>, Error>>()?;
787
788 let selector_entries = flush
789 .selectors
790 .iter()
791 .map(|id| witness_index.get_index_entry(*id))
792 .collect::<Result<Vec<_>, _>>()?;
793
794 let selector_prefix_len = selector_entries
796 .iter()
797 .map(|selector_entry| selector_entry.nonzero_scalars_prefix)
798 .min()
799 .unwrap_or(1 << n_vars);
800
801 let selectors = selector_entries
802 .into_iter()
803 .map(|entry| entry.multilin_poly)
804 .collect::<Vec<_>>();
805
806 let log_width = <PackedType<U, FExt<Tower>>>::LOG_WIDTH;
807 let packed_selector_prefix_len = selector_prefix_len.div_ceil(1 << log_width);
808
809 let mut witness_data = Vec::with_capacity(1 << n_vars.saturating_sub(log_width));
810 (0..packed_selector_prefix_len)
811 .into_par_iter()
812 .map(|i| {
813 <PackedType<U, FExt<Tower>>>::from_fn(|j| {
814 let index = i << log_width | j;
815
816 if index >= 1 << n_vars {
818 return <FExt<Tower>>::ZERO;
819 }
820
821 let selector_off = selectors.iter().any(|selector| {
823 let sel_val = selector
824 .evaluate_on_hypercube(index)
825 .expect("index < 1 << n_vars");
826 sel_val.is_zero()
827 });
828
829 if selector_off {
830 <FExt<Tower>>::ONE
832 } else {
833 let mut inner_oracles_iter = inner_oracles.iter();
835
836 if let Some((poly, coeff)) = inner_oracles_iter.next() {
839 let first_term = if *coeff == FExt::<Tower>::ONE {
840 poly.evaluate_on_hypercube(index).expect("index in bounds")
841 } else {
842 poly.evaluate_on_hypercube_and_scale(index, *coeff)
843 .expect("index in bounds")
844 };
845 inner_oracles_iter.fold(
846 const_term + first_term,
847 |sum, (poly, coeff)| {
848 let scaled_eval = poly
849 .evaluate_on_hypercube_and_scale(index, *coeff)
850 .expect("index in bounds");
851 sum + scaled_eval
852 },
853 )
854 } else {
855 const_term
856 }
857 }
858 })
859 })
860 .collect_into_vec(&mut witness_data);
861 witness_data.resize(witness_data.capacity(), PackedType::<U, FExt<Tower>>::one());
862
863 let witness = MLEDirectAdapter::from(
864 MultilinearExtension::new(n_vars, witness_data)
865 .expect("witness_data created with correct n_vars"),
866 );
867 Ok((witness, selector_prefix_len))
870 })
871 .collect::<Result<Vec<_>, Error>>()?;
872
873 witness_index.update_multilin_poly_with_nonzero_scalars_prefixes(
874 iter::zip(flush_oracle_ids, indices_to_update).map(
875 |(&oracle_id, (witness, nonzero_scalars_prefix))| {
876 (oracle_id, witness.upcast_arc_dyn(), nonzero_scalars_prefix)
877 },
878 ),
879 )?;
880 Ok(())
881}
882
883fn count_zero_suffixes<P: PackedField, M: MultilinearPoly<P>>(poly: &M) -> usize {
884 let zeros = P::zero();
885 if let Some(packed_evals) = poly.packed_evals() {
886 let packed_zero_suffix_len = packed_evals
887 .iter()
888 .rev()
889 .position(|&packed_eval| packed_eval != zeros)
890 .unwrap_or(packed_evals.len());
891
892 let log_scalars_per_elem = P::LOG_WIDTH + poly.log_extension_degree();
893 if poly.n_vars() < log_scalars_per_elem {
894 debug_assert_eq!(packed_evals.len(), 1, "invariant of MultilinearPoly");
895 packed_zero_suffix_len << poly.n_vars()
896 } else {
897 packed_zero_suffix_len << log_scalars_per_elem
898 }
899 } else {
900 0
901 }
902}
903
904#[allow(clippy::type_complexity)]
931#[instrument(skip_all, level = "debug")]
932fn convert_witnesses_to_fast_ext<'a, U, Tower>(
933 oracles: &MultilinearOracleSet<FExt<Tower>>,
934 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
935 oracle_ids: &[OracleId],
936) -> Result<Vec<(usize, Vec<PackedType<U, FFastExt<Tower>>>)>, Error>
937where
938 U: ProverTowerUnderlier<Tower>,
939 Tower: ProverTowerFamily,
940 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
941{
942 let to_fast = Tower::packed_transformation_to_fast();
943
944 oracle_ids
946 .into_par_iter()
947 .map(|&flush_oracle_id| {
948 let n_vars = oracles.n_vars(flush_oracle_id);
949
950 let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
951
952 let IndexEntry {
953 multilin_poly: poly,
954 nonzero_scalars_prefix,
955 } = witness.get_index_entry(flush_oracle_id)?;
956
957 const MAX_SUBCUBE_VARS: usize = 8;
958 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
959 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
960 let non_const_scalars = nonzero_scalars_prefix;
961 let non_const_subcubes = non_const_scalars.div_ceil(1 << subcube_vars);
962
963 let mut fast_ext_result = zeroed_vec(non_const_subcubes * subcube_packed_size);
964 fast_ext_result
965 .par_chunks_exact_mut(subcube_packed_size)
966 .enumerate()
967 .for_each(|(subcube_index, fast_subcube)| {
968 let underliers =
969 PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
970
971 let subcube_evals =
972 PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
973 poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
974 .expect("witness data populated by make_unmasked_flush_witnesses()");
975
976 for underlier in underliers.iter_mut() {
977 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
978 let dest = to_fast.transform(&src);
979 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
980 }
981 });
982
983 fast_ext_result.truncate(non_const_scalars);
984 Ok((n_vars, fast_ext_result))
985 })
986 .collect()
987}
988
989#[allow(clippy::type_complexity)]
990pub fn convert_1b_witnesses_to_fast_ext<'a, U, Tower>(
991 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
992 ids: &[OracleId],
993) -> Result<Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>, Error>
994where
995 U: ProverTowerUnderlier<Tower>,
996 Tower: ProverTowerFamily,
997 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
998 + RepackedExtension<PackedType<U, Tower::B1>>,
999{
1000 ids.iter()
1001 .map(|&id| {
1002 let exp_witness = witness.get_multilin_poly(id)?;
1003
1004 let packed_evals = exp_witness
1005 .packed_evals()
1006 .expect("poly contain packed_evals");
1007
1008 let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
1009
1010 MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
1011 .map(|mle| mle.specialize_arc_dyn())
1012 .map_err(Error::from)
1013 })
1014 .collect::<Result<Vec<_>, _>>()
1015}
1016
1017#[instrument(skip_all, name = "flush::reduce_flush_evalcheck_claims")]
1018fn reduce_flush_evalcheck_claims<
1019 U,
1020 Tower: ProverTowerFamily,
1021 Challenger_,
1022 Backend: ComputationBackend,
1023>(
1024 claims: Vec<EvalcheckMultilinearClaim<FExt<Tower>>>,
1025 oracles: &MultilinearOracleSet<FExt<Tower>>,
1026 witness_index: MultilinearExtensionIndex<PackedType<U, FFastExt<Tower>>>,
1027 domain_factory: IsomorphicEvaluationDomainFactory<FFastExt<Tower>>,
1028 transcript: &mut ProverTranscript<Challenger_>,
1029 backend: &Backend,
1030) -> Result<Vec<EvalcheckMultilinearClaim<FExt<Tower>>>, Error>
1031where
1032 FExt<Tower>: From<FFastExt<Tower>>,
1033 FFastExt<Tower>: From<FExt<Tower>>,
1034 U: ProverTowerUnderlier<Tower>,
1035 Challenger_: Challenger + Default,
1036{
1037 let mut linear_claims = Vec::new();
1038
1039 #[allow(clippy::type_complexity)]
1040 let mut new_mlechecks_constraints: Vec<(
1041 EvalPoint<FFastExt<Tower>>,
1042 ConstraintSetBuilder<FFastExt<Tower>>,
1043 )> = Vec::new();
1044
1045 for claim in &claims {
1046 match &oracles[claim.id].variant {
1047 MultilinearPolyVariant::LinearCombination(_) => linear_claims.push(claim.clone()),
1048 MultilinearPolyVariant::Composite(composite) => {
1049 let eval_point = claim.eval_point.isomorphic();
1050
1051 let eval = claim.eval.into();
1052
1053 let position = new_mlechecks_constraints
1054 .iter()
1055 .position(|(ep, _)| *ep == eval_point)
1056 .unwrap_or(new_mlechecks_constraints.len());
1057
1058 let oracle_ids = composite.inner().clone();
1059
1060 let exp = <_ as CompositionPoly<FExt<Tower>>>::expression(composite.c());
1061 let fast_exp = exp.convert_field::<FFastExt<Tower>>();
1062
1063 if let Some((_, constraint_builder)) = new_mlechecks_constraints.get_mut(position) {
1064 constraint_builder.add_sumcheck(oracle_ids, fast_exp, eval);
1065 } else {
1066 let mut new_builder = ConstraintSetBuilder::new();
1067 new_builder.add_sumcheck(oracle_ids, fast_exp, eval);
1068 new_mlechecks_constraints.push((eval_point.clone(), new_builder));
1069 }
1070 }
1071 _ => unreachable!(),
1072 }
1073 }
1074
1075 let new_mlechecks = new_mlechecks_constraints
1076 .into_iter()
1077 .map(|(ep, builder)| {
1078 builder
1079 .build_one(oracles)
1080 .map(|constraint| ConstraintSetEqIndPoint {
1081 eq_ind_challenges: ep.clone(),
1082 constraint_set: constraint,
1083 })
1084 .map_err(Error::from)
1085 })
1086 .collect::<Result<Vec<_>, Error>>()?;
1087
1088 let mut memoized_data = MemoizedData::new();
1089
1090 let mut fast_new_evalcheck_claims = Vec::new();
1091
1092 for ConstraintSetEqIndPoint {
1093 eq_ind_challenges,
1094 constraint_set,
1095 } in new_mlechecks
1096 {
1097 let evalcheck_claims = prove_mlecheck_with_switchover::<_, _, FFastExt<Tower>, _, _>(
1098 &witness_index,
1099 constraint_set,
1100 eq_ind_challenges,
1101 &mut memoized_data,
1102 transcript,
1103 immediate_switchover_heuristic,
1104 domain_factory.clone(),
1105 backend,
1106 )?;
1107 fast_new_evalcheck_claims.extend(evalcheck_claims);
1108 }
1109
1110 Ok(chain!(
1111 fast_new_evalcheck_claims
1112 .into_iter()
1113 .map(|claim| claim.isomorphic::<FExt<Tower>>()),
1114 linear_claims.into_iter()
1115 )
1116 .collect::<Vec<_>>())
1117}