1use std::{env, iter, marker::PhantomData};
4
5use binius_field::{
6 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
7 RepackedExtension, TowerField,
8 as_packed_field::PackedType,
9 linear_transformation::{PackedTransformationFactory, Transformation},
10 tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
11 underlier::WithUnderlier,
12 util::powers,
13};
14use binius_hal::ComputationBackend;
15use binius_hash::PseudoCompressionFunction;
16use binius_math::{
17 DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
18 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
19};
20use binius_maybe_rayon::prelude::*;
21use binius_ntt::SingleThreadedNTT;
22use binius_utils::bail;
23use bytemuck::zeroed_vec;
24use digest::{Digest, FixedOutputReset, Output, core_api::BlockSizeUser};
25use itertools::chain;
26use tracing::instrument;
27
28use super::{
29 ConstraintSystem, Proof,
30 channel::Boundary,
31 error::Error,
32 verify::{make_flush_oracles, max_n_vars_and_skip_rounds},
33};
34use crate::{
35 constraint_system::{
36 Flush,
37 channel::OracleOrConst,
38 common::{FDomain, FEncode, FExt, FFastExt},
39 exp::{self, reorder_exponents},
40 },
41 fiat_shamir::{CanSample, Challenger},
42 merkle_tree::BinaryMerkleTreeProver,
43 oracle::{Constraint, MultilinearOracleSet, OracleId},
44 piop,
45 protocols::{
46 fri::CommitOutput,
47 gkr_exp,
48 gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness},
49 greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
50 sumcheck::{
51 self, constraint_set_zerocheck_claim, prove::ZerocheckProver,
52 standard_switchover_heuristic,
53 },
54 },
55 ring_switch,
56 transcript::ProverTranscript,
57 witness::{IndexEntry, MultilinearExtensionIndex, MultilinearWitness},
58};
59
/// Generates a [`Proof`] for `constraint_system` using the supplied `witness`.
///
/// All prover messages and verifier challenges go through a single Fiat–Shamir
/// [`ProverTranscript`]. The proof proceeds in the phases named by the tracing spans below:
/// witness finalization (exponentiation layers), commitment of the committed oracles, the
/// exponentiation GKR argument, the product check (non-zero oracles and channel flushes),
/// zerocheck over the table constraints, greedy evalcheck, ring switching, and finally the PIOP
/// compiler (`piop::prove`) over the committed data.
///
/// * `log_inv_rate` and `security_bits` are passed to
///   `piop::make_commit_params_with_optimal_arity` to choose the FRI parameters.
/// * `boundaries` are observed into the transcript before any other transcript operation.
/// * `witness` is mutated in place: the exponentiation step, the masked flush witnesses and
///   greedy evalcheck all receive it mutably and may add or update entries.
/// * `backend` is handed to the GKR, zerocheck, evalcheck and PIOP provers.
///
/// # Errors
///
/// Returns [`Error::Zeros`] if the grand product over any oracle listed in the constraint
/// system's `non_zero_oracle_ids` is zero. Errors from every sub-protocol are propagated.
#[instrument("constraint_system::prove", skip_all, level = "debug")]
pub fn prove<U, Tower, Hash, Compress, Challenger_, Backend>(
	constraint_system: &ConstraintSystem<FExt<Tower>>,
	log_inv_rate: usize,
	security_bits: usize,
	boundaries: &[Boundary<FExt<Tower>>],
	mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
	backend: &Backend,
) -> Result<Proof, Error>
where
	U: ProverTowerUnderlier<Tower>,
	Tower: ProverTowerFamily,
	Tower::B128: PackedTop<Tower>,
	Hash: Digest + BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
	Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
	Challenger_: Challenger + Default,
	Backend: ComputationBackend,
	PackedType<U, Tower::B128>: PackedTop<Tower>
		+ PackedFieldIndexable
		+ RepackedExtension<PackedType<U, Tower::B8>>
		+ RepackedExtension<PackedType<U, Tower::B16>>
		+ RepackedExtension<PackedType<U, Tower::B32>>
		+ RepackedExtension<PackedType<U, Tower::B64>>
		+ RepackedExtension<PackedType<U, Tower::B128>>
		+ PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
{
	tracing::debug!(
		arch = env::consts::ARCH,
		rayon_threads = binius_maybe_rayon::current_num_threads(),
		"using computation backend: {backend:?}"
	);

	let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
	let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();

	// NOTE(review): every transcript write/sample below must keep its relative order; presumably
	// the verifier replays exactly this sequence — confirm against `verify.rs` before reordering.
	// Boundaries go through `observe()` (not `message()`) ahead of every other transcript step.
	let mut transcript = ProverTranscript::<Challenger_>::new();
	transcript.observe().write_slice(boundaries);

	// Work on an owned copy: oracles, flushes, constraints and exponents are mutated below.
	let ConstraintSystem {
		mut oracles,
		mut table_constraints,
		mut flushes,
		mut exponents,
		non_zero_oracle_ids,
		max_channel_id,
	} = constraint_system.clone();

	reorder_exponents(&mut exponents, &oracles);

	let witness_span = tracing::info_span!(
		"[phase] Witness Finalization",
		phase = "witness",
		perfetto_category = "phase.main"
	)
	.entered();

	let exp_compute_layer_span = tracing::info_span!(
		"[step] Compute Exponentiation Layers",
		phase = "witness",
		perfetto_category = "phase.sub"
	)
	.entered();
	// Builds the exponentiation witnesses; `witness` is passed mutably and may gain entries.
	let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;
	drop(exp_compute_layer_span);

	drop(witness_span);

	// The zerocheck provers further down are created in this n_vars-sorted order.
	table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);

	let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
	let merkle_scheme = merkle_prover.scheme();

	// Gather the witnesses of the committed oracles and pick FRI parameters for committing them.
	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
	let committed_multilins = piop::collect_committed_witnesses::<U, _>(
		&commit_meta,
		&oracle_to_commit_index,
		&oracles,
		&witness,
	)?;

	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
		&commit_meta,
		merkle_scheme,
		security_bits,
		log_inv_rate,
	)?;
	// The NTT size is taken from the Reed–Solomon code selected by the FRI parameters.
	let ntt = SingleThreadedNTT::new(fri_params.rs_code().log_len())?
		.precompute_twiddles()
		.multithreaded();

	let commit_span =
		tracing::info_span!("[phase] Commit", phase = "commit", perfetto_category = "phase.main")
			.entered();
	let CommitOutput {
		commitment,
		committed,
		codeword,
	} = piop::commit(&fri_params, &ntt, &merkle_prover, &committed_multilins)?;
	drop(commit_span);

	// The commitment is the first prover message.
	let mut writer = transcript.message();
	writer.write(&commitment);

	let exp_span = tracing::info_span!(
		"[phase] Exponentiation",
		phase = "exp",
		perfetto_category = "phase.main"
	)
	.entered();
	// The evaluation point for the exponentiation claims is sampled after the commitment.
	let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));

	let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
		.into_iter()
		.map(|x| x.into())
		.collect::<Vec<_>>();

	let mut writer = transcript.message();
	writer.write_scalar_slice(&exp_evals);

	// Convert the challenge with `Into` into the field type expected by `exp::make_claims`.
	let exp_challenge = exp_challenge
		.into_iter()
		.map(|x| x.into())
		.collect::<Vec<_>>();

	// Claims are mapped to the isomorphic fast field for the GKR prover; its output is mapped
	// back with `.isomorphic()` afterwards.
	let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
		.into_iter()
		.map(|claim| claim.isomorphic())
		.collect::<Vec<_>>();

	let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
		EvaluationOrder::HighToLow,
		exp_witnesses,
		&exp_claims,
		fast_domain_factory.clone(),
		&mut transcript,
		backend,
	)?
	.isomorphic();

	let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
	drop(exp_span);

	let prodcheck_span = tracing::info_span!(
		"[phase] Product Check",
		phase = "prodcheck",
		perfetto_category = "phase.main"
	)
	.entered();

	let nonzero_convert_span = tracing::info_span!(
		"[task] Convert Non-Zero to Fast Field",
		phase = "prodcheck",
		perfetto_category = "task.main"
	)
	.entered();
	let non_zero_fast_witnesses =
		convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &non_zero_oracle_ids)?;
	drop(nonzero_convert_span);

	let nonzero_prodcheck_compute_layer_span = tracing::info_span!(
		"[step] Compute Non-Zero Product Layers",
		phase = "prodcheck",
		perfetto_category = "phase.sub"
	)
	.entered();
	let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
		.into_par_iter()
		.map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
		.collect::<Result<Vec<_>, _>>()?;
	drop(nonzero_prodcheck_compute_layer_span);

	// In a field, a product is zero only if some factor is zero. So a zero product means an
	// oracle declared non-zero has a zero entry, and proving is aborted.
	let non_zero_products =
		gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
	if non_zero_products
		.iter()
		.any(|count| *count == Tower::B128::zero())
	{
		bail!(Error::Zeros);
	}

	let mut writer = transcript.message();

	writer.write_scalar_slice(&non_zero_products);

	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
		&non_zero_oracle_ids,
		&oracles,
		&non_zero_products,
	)?;

	// Flush challenges: one mixing challenge, plus one permutation challenge per channel id in
	// 0..=max_channel_id. These are indexed by `flush.channel_id` when masking flushes.
	let mixing_challenge = transcript.sample();
	let permutation_challenges = transcript.sample_vec(max_channel_id + 1);

	// Flushes are ordered by channel id before the flush oracles are derived. Presumably this
	// gives prover and verifier the same oracle order — verify against `make_flush_oracles`.
	flushes.sort_by_key(|flush| flush.channel_id);
	let flush_oracle_ids =
		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;

	let flush_convert_span = tracing::info_span!(
		"[task] Convert Flushes to Fast Field",
		phase = "prodcheck",
		perfetto_category = "task.main"
	)
	.entered();
	make_masked_flush_witnesses::<U, _>(
		&oracles,
		&mut witness,
		&flush_oracle_ids,
		&flushes,
		mixing_challenge,
		&permutation_challenges,
	)?;

	let flush_witnesses =
		convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &flush_oracle_ids)?;
	drop(flush_convert_span);

	let flush_prodcheck_compute_layer_span = tracing::info_span!(
		"[step] Compute Flush Product Layers",
		phase = "prodcheck",
		perfetto_category = "phase.sub"
	)
	.entered();
	let flush_prodcheck_witnesses = flush_witnesses
		.into_par_iter()
		.map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
		.collect::<Result<Vec<_>, _>>()?;
	drop(flush_prodcheck_compute_layer_span);

	let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);

	transcript.message().write_scalar_slice(&flush_products);

	let flush_prodcheck_claims =
		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;

	// Flush instances come first, then non-zero ones. The `make_eval_claims` call below chains
	// the oracle ids in the same order, so the two must stay in sync.
	let all_gpa_witnesses =
		chain!(flush_prodcheck_witnesses, non_zero_prodcheck_witnesses).collect::<Vec<_>>();
	let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
		.map(|claim| claim.isomorphic())
		.collect::<Vec<_>>();

	let GrandProductBatchProveOutput { final_layer_claims } =
		gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
			EvaluationOrder::HighToLow,
			all_gpa_witnesses,
			&all_gpa_claims,
			&fast_domain_factory,
			&mut transcript,
			backend,
		)?;

	// Map the fast-field layer claims back with `.isomorphic()`.
	let final_layer_claims = final_layer_claims
		.into_iter()
		.map(|layer_claim| layer_claim.isomorphic())
		.collect::<Vec<_>>();

	let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
		chain!(flush_oracle_ids, non_zero_oracle_ids),
		final_layer_claims,
	)?;
	drop(prodcheck_span);

	let zerocheck_span = tracing::info_span!(
		"[phase] Zerocheck",
		phase = "zerocheck",
		perfetto_category = "phase.main",
	)
	.entered();

	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
		.iter()
		.cloned()
		.map(constraint_set_zerocheck_claim)
		.collect::<Result<Vec<_>, _>>()?
		.into_iter()
		.unzip::<_, _, Vec<_>, Vec<_>>();

	let (max_n_vars, skip_rounds) =
		max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);

	// One challenge per non-skipped round of the largest constraint set. Smaller sets use a
	// suffix of this vector (see the slice inside the loop).
	let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);

	let mut zerocheck_provers = Vec::with_capacity(table_constraints.len());

	for constraint_set in table_constraints {
		let n_vars = constraint_set.n_vars;
		let (constraints, multilinears) =
			sumcheck::prove::split_constraint_set(constraint_set, &witness)?;

		// Largest tower level needed by any multilinear or composition. A multilinear's level
		// is 7 (the level matched to B128 below) minus its log extension degree. The level
		// picks the base field of the zerocheck prover.
		let base_tower_level = chain!(
			multilinears
				.iter()
				.map(|multilinear| 7 - multilinear.log_extension_degree()),
			constraints
				.iter()
				.map(|constraint| constraint.composition.binary_tower_level())
		)
		.max()
		.unwrap_or(0);

		let zerocheck_challenges = &zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..];
		let domain_factory = domain_factory.clone();

		let constructor =
			ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _> {
				constraints,
				multilinears,
				zerocheck_challenges,
				domain_factory,
				backend,
				_fdomain_marker: PhantomData,
			};

		// Levels 0..=3 all use B8 as base field. This assumes no level exceeds 7.
		let zerocheck_prover = match base_tower_level {
			0..=3 => constructor.create::<Tower::B8>()?,
			4 => constructor.create::<Tower::B16>()?,
			5 => constructor.create::<Tower::B32>()?,
			6 => constructor.create::<Tower::B64>()?,
			7 => constructor.create::<Tower::B128>()?,
			_ => unreachable!(),
		};

		zerocheck_provers.push(zerocheck_prover);
	}

	let zerocheck_output = sumcheck::prove::batch_prove_zerocheck::<
		FExt<Tower>,
		FDomain<Tower>,
		PackedType<U, FExt<Tower>>,
		_,
		_,
	>(zerocheck_provers, skip_rounds, &mut transcript)?;

	let zerocheck_eval_claims =
		sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;

	drop(zerocheck_span);

	let evalcheck_span = tracing::info_span!(
		"[phase] Evalcheck",
		phase = "evalcheck",
		perfetto_category = "phase.main"
	)
	.entered();

	// Reduce the evaluation claims from the product check, zerocheck and exponentiation phases
	// together.
	let GreedyEvalcheckProveOutput {
		eval_claims,
		memoized_data,
	} = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
		&mut oracles,
		&mut witness,
		chain!(prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
		standard_switchover_heuristic(-2),
		&mut transcript,
		&domain_factory,
		backend,
	)?;

	// Map the remaining evaluation claims onto the committed oracles.
	let system = ring_switch::EvalClaimSystem::new(
		&oracles,
		&commit_meta,
		&oracle_to_commit_index,
		&eval_claims,
	)?;

	drop(evalcheck_span);

	let ring_switch_span = tracing::info_span!(
		"[phase] Ring Switch",
		phase = "ring_switch",
		perfetto_category = "phase.main"
	)
	.entered();
	let ring_switch::ReducedWitness {
		transparents: transparent_multilins,
		sumcheck_claims: piop_sumcheck_claims,
	} = ring_switch::prove::<_, _, _, Tower, _>(
		&system,
		&committed_multilins,
		&mut transcript,
		memoized_data,
	)?;
	drop(ring_switch_span);

	let piop_compiler_span = tracing::info_span!(
		"[phase] PIOP Compiler",
		phase = "piop_compiler",
		perfetto_category = "phase.main"
	)
	.entered();
	// Final phase: prove the ring-switched sumcheck claims against the FRI commitment made above.
	piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _, _>(
		&fri_params,
		&ntt,
		&merkle_prover,
		domain_factory,
		&commit_meta,
		committed,
		&codeword,
		&committed_multilins,
		&transparent_multilins,
		&piop_sumcheck_claims,
		&mut transcript,
		&backend,
	)?;
	drop(piop_compiler_span);

	let proof = Proof {
		transcript: transcript.finalize(),
	};

	// Report the proof size as a tracing counter.
	tracing::event!(
		name: "proof_size",
		tracing::Level::INFO,
		counter = true,
		value = proof.get_proof_size() as u64,
		unit = "bytes",
	);

	Ok(proof)
}
500
501type TypeErasedZerocheck<'a, P> = Box<dyn ZerocheckProver<'a, P> + 'a>;
502
503struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
504where
505 P: PackedField,
506{
507 constraints: Vec<Constraint<P::Scalar>>,
508 multilinears: Vec<MultilinearWitness<'a, P>>,
509 domain_factory: DomainFactory,
510 zerocheck_challenges: &'a [P::Scalar],
511 backend: &'a Backend,
512 _fdomain_marker: PhantomData<FDomain>,
513}
514
515impl<'a, P, F, FDomain, DomainFactory, Backend>
516 ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
517where
518 F: Field,
519 P: PackedField<Scalar = F>,
520 FDomain: TowerField,
521 DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
522 Backend: ComputationBackend,
523{
524 fn create<FBase>(self) -> Result<TypeErasedZerocheck<'a, P>, Error>
525 where
526 FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
527 P: PackedExtension<F, PackedSubfield = P>
528 + PackedExtension<FDomain>
529 + PackedExtension<FBase>,
530 F: TowerField,
531 {
532 let zerocheck_prover =
533 sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _>(
534 self.constraints,
535 self.multilinears,
536 self.domain_factory,
537 self.zerocheck_challenges,
538 self.backend,
539 )?;
540
541 let type_erased_zerocheck_prover = Box::new(zerocheck_prover) as TypeErasedZerocheck<'a, P>;
542
543 Ok(type_erased_zerocheck_prover)
544 }
545}
546
547#[instrument(skip_all, level = "debug")]
548fn make_masked_flush_witnesses<'a, U, Tower>(
549 oracles: &MultilinearOracleSet<FExt<Tower>>,
550 witness_index: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
551 flush_oracle_ids: &[OracleId],
552 flushes: &[Flush<FExt<Tower>>],
553 mixing_challenge: FExt<Tower>,
554 permutation_challenges: &[FExt<Tower>],
555) -> Result<(), Error>
556where
557 U: ProverTowerUnderlier<Tower>,
558 Tower: ProverTowerFamily,
559{
560 for flush in flushes {
563 for &selector_id in &flush.selectors {
564 let selector = witness_index.get_multilin_poly(selector_id)?;
565 let zero_suffix_len = count_zero_suffixes(&selector);
566
567 let nonzero_prefix_len = (1 << selector.n_vars()) - zero_suffix_len;
568 witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
569 selector_id,
570 selector,
571 nonzero_prefix_len,
572 )])?;
573 }
574 }
575
576 let max_n_mixed = flushes
578 .iter()
579 .map(|flush| flush.oracles.len())
580 .max()
581 .unwrap_or_default();
582 let mixing_powers = powers(mixing_challenge)
583 .take(max_n_mixed)
584 .collect::<Vec<_>>();
585
586 let indices_to_update = flush_oracle_ids
588 .par_iter()
589 .zip(flushes)
590 .map(|(&flush_oracle, flush)| {
591 let n_vars = oracles.n_vars(flush_oracle);
592
593 let const_term = flush
594 .oracles
595 .iter()
596 .copied()
597 .zip(mixing_powers.iter())
598 .filter_map(|(oracle_or_const, coeff)| match oracle_or_const {
599 OracleOrConst::Const { base, .. } => Some(base * coeff),
600 _ => None,
601 })
602 .sum::<FExt<Tower>>();
603 let const_term = permutation_challenges[flush.channel_id] + const_term;
604
605 let inner_oracles = flush
606 .oracles
607 .iter()
608 .copied()
609 .zip(mixing_powers.iter())
610 .filter_map(|(oracle_or_const, &coeff)| match oracle_or_const {
611 OracleOrConst::Oracle(oracle_id) => Some((oracle_id, coeff)),
612 _ => None,
613 })
614 .map(|(inner_id, coeff)| {
615 let witness = witness_index.get_multilin_poly(inner_id)?;
616 Ok((witness, coeff))
617 })
618 .collect::<Result<Vec<_>, Error>>()?;
619
620 let selector_entries = flush
621 .selectors
622 .iter()
623 .map(|id| witness_index.get_index_entry(*id))
624 .collect::<Result<Vec<_>, _>>()?;
625
626 let selector_prefix_len = selector_entries
628 .iter()
629 .map(|selector_entry| selector_entry.nonzero_scalars_prefix)
630 .min()
631 .unwrap_or(1 << n_vars);
632
633 let selectors = selector_entries
634 .into_iter()
635 .map(|entry| entry.multilin_poly)
636 .collect::<Vec<_>>();
637
638 let log_width = <PackedType<U, FExt<Tower>>>::LOG_WIDTH;
639 let packed_selector_prefix_len = selector_prefix_len.div_ceil(1 << log_width);
640
641 let mut witness_data = Vec::with_capacity(1 << n_vars.saturating_sub(log_width));
642 (0..packed_selector_prefix_len)
643 .into_par_iter()
644 .map(|i| {
645 <PackedType<U, FExt<Tower>>>::from_fn(|j| {
646 let index = i << log_width | j;
647
648 let selector_off = selectors.iter().any(|selector| {
650 let sel_val = selector
651 .evaluate_on_hypercube(index)
652 .expect("index < 1 << n_vars");
653 sel_val.is_zero()
654 });
655
656 if selector_off {
657 <FExt<Tower>>::ONE
659 } else {
660 let mut inner_oracles_iter = inner_oracles.iter();
662
663 if let Some((poly, coeff)) = inner_oracles_iter.next() {
666 let first_term = if *coeff == FExt::<Tower>::ONE {
667 poly.evaluate_on_hypercube(index).expect("index in bounds")
668 } else {
669 poly.evaluate_on_hypercube_and_scale(index, *coeff)
670 .expect("index in bounds")
671 };
672 inner_oracles_iter.fold(
673 const_term + first_term,
674 |sum, (poly, coeff)| {
675 let scaled_eval = poly
676 .evaluate_on_hypercube_and_scale(index, *coeff)
677 .expect("index in bounds");
678 sum + scaled_eval
679 },
680 )
681 } else {
682 const_term
683 }
684 }
685 })
686 })
687 .collect_into_vec(&mut witness_data);
688 witness_data.resize(witness_data.capacity(), PackedType::<U, FExt<Tower>>::one());
689
690 let witness = MLEDirectAdapter::from(
691 MultilinearExtension::new(n_vars, witness_data)
692 .expect("witness_data created with correct n_vars"),
693 );
694 Ok((witness, selector_prefix_len))
697 })
698 .collect::<Result<Vec<_>, Error>>()?;
699
700 witness_index.update_multilin_poly_with_nonzero_scalars_prefixes(
701 iter::zip(flush_oracle_ids, indices_to_update).map(
702 |(&oracle_id, (witness, nonzero_scalars_prefix))| {
703 (oracle_id, witness.upcast_arc_dyn(), nonzero_scalars_prefix)
704 },
705 ),
706 )?;
707 Ok(())
708}
709
710fn count_zero_suffixes<P: PackedField, M: MultilinearPoly<P>>(poly: &M) -> usize {
711 let zeros = P::zero();
712 if let Some(packed_evals) = poly.packed_evals() {
713 let packed_zero_suffix_len = packed_evals
714 .iter()
715 .rev()
716 .position(|&packed_eval| packed_eval != zeros)
717 .unwrap_or(packed_evals.len());
718
719 let log_scalars_per_elem = P::LOG_WIDTH + poly.log_extension_degree();
720 if poly.n_vars() < log_scalars_per_elem {
721 debug_assert_eq!(packed_evals.len(), 1, "invariant of MultilinearPoly");
722 packed_zero_suffix_len << poly.n_vars()
723 } else {
724 packed_zero_suffix_len << log_scalars_per_elem
725 }
726 } else {
727 0
728 }
729}
730
731#[allow(clippy::type_complexity)]
758#[instrument(skip_all, level = "debug")]
759fn convert_witnesses_to_fast_ext<'a, U, Tower>(
760 oracles: &MultilinearOracleSet<FExt<Tower>>,
761 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
762 oracle_ids: &[OracleId],
763) -> Result<Vec<(usize, Vec<PackedType<U, FFastExt<Tower>>>)>, Error>
764where
765 U: ProverTowerUnderlier<Tower>,
766 Tower: ProverTowerFamily,
767 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
768{
769 let to_fast = Tower::packed_transformation_to_fast();
770
771 oracle_ids
773 .into_par_iter()
774 .map(|&flush_oracle_id| {
775 let n_vars = oracles.n_vars(flush_oracle_id);
776
777 let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
778
779 let IndexEntry {
780 multilin_poly: poly,
781 nonzero_scalars_prefix,
782 } = witness.get_index_entry(flush_oracle_id)?;
783
784 const MAX_SUBCUBE_VARS: usize = 8;
785 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
786 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
787 let non_const_scalars = nonzero_scalars_prefix;
788 let non_const_subcubes = non_const_scalars.div_ceil(1 << subcube_vars);
789
790 let mut fast_ext_result = zeroed_vec(non_const_subcubes * subcube_packed_size);
791 fast_ext_result
792 .par_chunks_exact_mut(subcube_packed_size)
793 .enumerate()
794 .for_each(|(subcube_index, fast_subcube)| {
795 let underliers =
796 PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
797
798 let subcube_evals =
799 PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
800 poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
801 .expect("witness data populated by make_unmasked_flush_witnesses()");
802
803 for underlier in underliers.iter_mut() {
804 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
805 let dest = to_fast.transform(&src);
806 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
807 }
808 });
809
810 fast_ext_result.truncate(non_const_scalars);
811 Ok((n_vars, fast_ext_result))
812 })
813 .collect()
814}