binius_core/constraint_system/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{env, iter, marker::PhantomData};
4
5use binius_field::{
6	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
7	RepackedExtension, TowerField,
8	as_packed_field::PackedType,
9	linear_transformation::{PackedTransformationFactory, Transformation},
10	tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
11	underlier::WithUnderlier,
12	util::powers,
13};
14use binius_hal::ComputationBackend;
15use binius_hash::PseudoCompressionFunction;
16use binius_math::{
17	DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
18	IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
19};
20use binius_maybe_rayon::prelude::*;
21use binius_ntt::SingleThreadedNTT;
22use binius_utils::bail;
23use bytemuck::zeroed_vec;
24use digest::{Digest, FixedOutputReset, Output, core_api::BlockSizeUser};
25use itertools::chain;
26use tracing::instrument;
27
28use super::{
29	ConstraintSystem, Proof,
30	channel::Boundary,
31	error::Error,
32	verify::{make_flush_oracles, max_n_vars_and_skip_rounds},
33};
34use crate::{
35	constraint_system::{
36		Flush,
37		channel::OracleOrConst,
38		common::{FDomain, FEncode, FExt, FFastExt},
39		exp::{self, reorder_exponents},
40	},
41	fiat_shamir::{CanSample, Challenger},
42	merkle_tree::BinaryMerkleTreeProver,
43	oracle::{Constraint, MultilinearOracleSet, OracleId},
44	piop,
45	protocols::{
46		fri::CommitOutput,
47		gkr_exp,
48		gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness},
49		greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
50		sumcheck::{
51			self, constraint_set_zerocheck_claim, prove::ZerocheckProver,
52			standard_switchover_heuristic,
53		},
54	},
55	ring_switch,
56	transcript::ProverTranscript,
57	witness::{IndexEntry, MultilinearExtensionIndex, MultilinearWitness},
58};
59
/// Generates a proof that a witness satisfies a constraint system with the standard FRI PCS.
///
/// All prover messages and sampled challenges go through a single [`ProverTranscript`], so
/// the order of the steps below is part of the protocol. The phases, in order, are:
///
/// 1. Witness finalization: build the exponentiation witnesses (this extends `witness`).
/// 2. Commit: commit to the committed multilinears and write the commitment.
/// 3. Exponentiation: batched `gkr_exp` argument.
/// 4. Product check: batched `gkr_gpa` argument over the flush oracles and the non-zero
///    oracles.
/// 5. Zerocheck: one prover per table constraint set, batched.
/// 6. Evalcheck: greedy evalcheck over the claims produced by phases 3-5.
/// 7. Ring switch followed by the PIOP compiler.
///
/// # Arguments
/// * `constraint_system` - The constraint system to prove. It is cloned, so the caller's copy
///   is never modified.
/// * `log_inv_rate` - Forwarded to `piop::make_commit_params_with_optimal_arity` when the FRI
///   parameters are chosen.
/// * `security_bits` - Forwarded to `piop::make_commit_params_with_optimal_arity` when the FRI
///   parameters are chosen.
/// * `boundaries` - Boundary values; they are the first thing written to the transcript.
/// * `witness` - The witness index. It is taken by value and extended during proving (by the
///   exponentiation witnesses, the flush witnesses, and by evalcheck).
/// * `backend` - The computation backend handed to the sub-protocol provers.
///
/// # Returns
/// A [`Proof`] wrapping the finalized transcript.
///
/// # Errors
/// Returns [`Error::Zeros`] if the grand product of any oracle in `non_zero_oracle_ids` is
/// zero, and propagates any error returned by the sub-protocols called along the way.
#[instrument("constraint_system::prove", skip_all, level = "debug")]
pub fn prove<U, Tower, Hash, Compress, Challenger_, Backend>(
	constraint_system: &ConstraintSystem<FExt<Tower>>,
	log_inv_rate: usize,
	security_bits: usize,
	boundaries: &[Boundary<FExt<Tower>>],
	mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
	backend: &Backend,
) -> Result<Proof, Error>
where
	U: ProverTowerUnderlier<Tower>,
	Tower: ProverTowerFamily,
	Tower::B128: PackedTop<Tower>,
	Hash: Digest + BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
	Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
	Challenger_: Challenger + Default,
	Backend: ComputationBackend,
	// REVIEW: Consider changing TowerFamily and associated traits to shorten/remove these bounds
	PackedType<U, Tower::B128>: PackedTop<Tower>
		+ PackedFieldIndexable // REVIEW: remove this bound after piop::commit is adjusted
		+ RepackedExtension<PackedType<U, Tower::B8>>
		+ RepackedExtension<PackedType<U, Tower::B16>>
		+ RepackedExtension<PackedType<U, Tower::B32>>
		+ RepackedExtension<PackedType<U, Tower::B64>>
		+ RepackedExtension<PackedType<U, Tower::B128>>
		+ PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
{
	tracing::debug!(
		arch = env::consts::ARCH,
		rayon_threads = binius_maybe_rayon::current_num_threads(),
		"using computation backend: {backend:?}"
	);

	// Two domain factories are used: `domain_factory` for zerocheck, evalcheck and the PIOP
	// compiler, and `fast_domain_factory` for the GKR arguments run over `FFastExt<Tower>`.
	let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
	let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();

	// The boundary values are observed before any prover message is written.
	let mut transcript = ProverTranscript::<Challenger_>::new();
	transcript.observe().write_slice(boundaries);

	// Work on a private copy: the oracle set, constraint order, flush order and exponent order
	// are all mutated below.
	let ConstraintSystem {
		mut oracles,
		mut table_constraints,
		mut flushes,
		mut exponents,
		non_zero_oracle_ids,
		max_channel_id,
	} = constraint_system.clone();

	reorder_exponents(&mut exponents, &oracles);

	let witness_span = tracing::info_span!(
		"[phase] Witness Finalization",
		phase = "witness",
		perfetto_category = "phase.main"
	)
	.entered();

	// We must generate multiplication witnesses before committing, as this function
	// adds the committed witnesses for exponentiation results to the witness index.
	let exp_compute_layer_span = tracing::info_span!(
		"[step] Compute Exponentiation Layers",
		phase = "witness",
		perfetto_category = "phase.sub"
	)
	.entered();
	let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;
	drop(exp_compute_layer_span);

	drop(witness_span);

	// Stable sort constraint sets in ascending order by number of variables.
	table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);

	// Commit polynomials
	let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
	let merkle_scheme = merkle_prover.scheme();

	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
	let committed_multilins = piop::collect_committed_witnesses::<U, _>(
		&commit_meta,
		&oracle_to_commit_index,
		&oracles,
		&witness,
	)?;

	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
		&commit_meta,
		merkle_scheme,
		security_bits,
		log_inv_rate,
	)?;
	// The NTT is sized from the Reed-Solomon code in the FRI parameters and is reused later by
	// `piop::prove`.
	let ntt = SingleThreadedNTT::new(fri_params.rs_code().log_len())?
		.precompute_twiddles()
		.multithreaded();

	let commit_span =
		tracing::info_span!("[phase] Commit", phase = "commit", perfetto_category = "phase.main")
			.entered();
	let CommitOutput {
		commitment,
		committed,
		codeword,
	} = piop::commit(&fri_params, &ntt, &merkle_prover, &committed_multilins)?;
	drop(commit_span);

	// Observe polynomial commitment
	let mut writer = transcript.message();
	writer.write(&commitment);

	let exp_span = tracing::info_span!(
		"[phase] Exponentiation",
		phase = "exp",
		perfetto_category = "phase.main"
	)
	.entered();
	// The evaluation point is sampled only after the commitment has been written.
	let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));

	let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
		.into_iter()
		.map(|x| x.into())
		.collect::<Vec<_>>();

	let mut writer = transcript.message();
	writer.write_scalar_slice(&exp_evals);

	// Convert the challenge point into the element type expected by `exp::make_claims`.
	let exp_challenge = exp_challenge
		.into_iter()
		.map(|x| x.into())
		.collect::<Vec<_>>();

	let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
		.into_iter()
		.map(|claim| claim.isomorphic())
		.collect::<Vec<_>>();

	// The batched exponentiation argument runs over `FFastExt<Tower>`; its output is mapped
	// back with `isomorphic()` before the evaluation claims are derived from it.
	let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
		EvaluationOrder::HighToLow,
		exp_witnesses,
		&exp_claims,
		fast_domain_factory.clone(),
		&mut transcript,
		backend,
	)?
	.isomorphic();

	let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
	drop(exp_span);

	// Grand product arguments
	// Grand products for non-zero checking
	let prodcheck_span = tracing::info_span!(
		"[phase] Product Check",
		phase = "prodcheck",
		perfetto_category = "phase.main"
	)
	.entered();

	let nonzero_convert_span = tracing::info_span!(
		"[task] Convert Non-Zero to Fast Field",
		phase = "prodcheck",
		perfetto_category = "task.main"
	)
	.entered();
	let non_zero_fast_witnesses =
		convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &non_zero_oracle_ids)?;
	drop(nonzero_convert_span);

	let nonzero_prodcheck_compute_layer_span = tracing::info_span!(
		"[step] Compute Non-Zero Product Layers",
		phase = "prodcheck",
		perfetto_category = "phase.sub"
	)
	.entered();
	let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
		.into_par_iter()
		.map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
		.collect::<Result<Vec<_>, _>>()?;
	drop(nonzero_prodcheck_compute_layer_span);

	// Fail early, before anything about the products is written to the transcript, if any
	// of the grand products is zero.
	let non_zero_products =
		gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
	if non_zero_products
		.iter()
		.any(|count| *count == Tower::B128::zero())
	{
		bail!(Error::Zeros);
	}

	let mut writer = transcript.message();

	writer.write_scalar_slice(&non_zero_products);

	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
		&non_zero_oracle_ids,
		&oracles,
		&non_zero_products,
	)?;

	// Grand products for flushing
	// One permutation challenge is sampled per channel id in `0..=max_channel_id`.
	let mixing_challenge = transcript.sample();
	let permutation_challenges = transcript.sample_vec(max_channel_id + 1);

	// Flushes are ordered by channel id before the flush oracles are created; the resulting
	// ids are positionally matched with `flushes` below.
	flushes.sort_by_key(|flush| flush.channel_id);
	let flush_oracle_ids =
		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;

	let flush_convert_span = tracing::info_span!(
		"[task] Convert Flushes to Fast Field",
		phase = "prodcheck",
		perfetto_category = "task.main"
	)
	.entered();
	// Populate the witness index with the witnesses of the flush oracles just created.
	make_masked_flush_witnesses::<U, _>(
		&oracles,
		&mut witness,
		&flush_oracle_ids,
		&flushes,
		mixing_challenge,
		&permutation_challenges,
	)?;

	// there are no oracle ids associated with these flush_witnesses
	let flush_witnesses =
		convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &flush_oracle_ids)?;
	drop(flush_convert_span);

	let flush_prodcheck_compute_layer_span = tracing::info_span!(
		"[step] Compute Flush Product Layers",
		phase = "prodcheck",
		perfetto_category = "phase.sub"
	)
	.entered();
	let flush_prodcheck_witnesses = flush_witnesses
		.into_par_iter()
		.map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
		.collect::<Result<Vec<_>, _>>()?;
	drop(flush_prodcheck_compute_layer_span);

	let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);

	transcript.message().write_scalar_slice(&flush_products);

	let flush_prodcheck_claims =
		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;

	// Prove grand products
	// Witnesses, claims and (further down) oracle ids are all chained flush-first, then
	// non-zero, so that they stay positionally aligned.
	let all_gpa_witnesses =
		chain!(flush_prodcheck_witnesses, non_zero_prodcheck_witnesses).collect::<Vec<_>>();
	let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
		.map(|claim| claim.isomorphic())
		.collect::<Vec<_>>();

	let GrandProductBatchProveOutput { final_layer_claims } =
		gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
			EvaluationOrder::HighToLow,
			all_gpa_witnesses,
			&all_gpa_claims,
			&fast_domain_factory,
			&mut transcript,
			backend,
		)?;

	// Apply isomorphism to the layer claims
	let final_layer_claims = final_layer_claims
		.into_iter()
		.map(|layer_claim| layer_claim.isomorphic())
		.collect::<Vec<_>>();

	// Reduce non_zero_final_layer_claims to evalcheck claims
	let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
		chain!(flush_oracle_ids, non_zero_oracle_ids),
		final_layer_claims,
	)?;
	drop(prodcheck_span);

	// Zerocheck
	let zerocheck_span = tracing::info_span!(
		"[phase] Zerocheck",
		phase = "zerocheck",
		perfetto_category = "phase.main",
	)
	.entered();

	// `table_constraints` is iterated twice: here (on clones) to derive the claims and oracle
	// metadata, and below (by value) to build one prover per constraint set.
	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
		.iter()
		.cloned()
		.map(constraint_set_zerocheck_claim)
		.collect::<Result<Vec<_>, _>>()?
		.into_iter()
		.unzip::<_, _, Vec<_>, Vec<_>>();

	let (max_n_vars, skip_rounds) =
		max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);

	// A single challenge vector is sampled and shared by all provers; each one takes a
	// suffix of it (see below).
	let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);

	let mut zerocheck_provers = Vec::with_capacity(table_constraints.len());

	for constraint_set in table_constraints {
		let n_vars = constraint_set.n_vars;
		let (constraints, multilinears) =
			sumcheck::prove::split_constraint_set(constraint_set, &witness)?;

		// The largest level required by any multilinear or any constraint composition; it
		// selects the `FBase` type parameter in the `match` below.
		// NOTE(review): `7` presumably is the tower level of the 128-bit top field - confirm.
		let base_tower_level = chain!(
			multilinears
				.iter()
				.map(|multilinear| 7 - multilinear.log_extension_degree()),
			constraints
				.iter()
				.map(|constraint| constraint.composition.binary_tower_level())
		)
		.max()
		.unwrap_or(0);

		// Per prover zerocheck challenges are justified on the high indexed variables
		let zerocheck_challenges = &zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..];
		let domain_factory = domain_factory.clone();

		let constructor =
			ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _> {
				constraints,
				multilinears,
				zerocheck_challenges,
				domain_factory,
				backend,
				_fdomain_marker: PhantomData,
			};

		// Levels 0 through 3 all use `Tower::B8`; higher levels use the matching tower field.
		let zerocheck_prover = match base_tower_level {
			0..=3 => constructor.create::<Tower::B8>()?,
			4 => constructor.create::<Tower::B16>()?,
			5 => constructor.create::<Tower::B32>()?,
			6 => constructor.create::<Tower::B64>()?,
			7 => constructor.create::<Tower::B128>()?,
			_ => unreachable!(),
		};

		zerocheck_provers.push(zerocheck_prover);
	}

	let zerocheck_output = sumcheck::prove::batch_prove_zerocheck::<
		FExt<Tower>,
		FDomain<Tower>,
		PackedType<U, FExt<Tower>>,
		_,
		_,
	>(zerocheck_provers, skip_rounds, &mut transcript)?;

	let zerocheck_eval_claims =
		sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;

	drop(zerocheck_span);

	let evalcheck_span = tracing::info_span!(
		"[phase] Evalcheck",
		phase = "evalcheck",
		perfetto_category = "phase.main"
	)
	.entered();

	// Prove evaluation claims
	// The claims are chained in the order: product check, zerocheck, exponentiation.
	let GreedyEvalcheckProveOutput {
		eval_claims,
		memoized_data,
	} = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
		&mut oracles,
		&mut witness,
		chain!(prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
		standard_switchover_heuristic(-2),
		&mut transcript,
		&domain_factory,
		backend,
	)?;

	// Reduce committed evaluation claims to PIOP sumcheck claims
	let system = ring_switch::EvalClaimSystem::new(
		&oracles,
		&commit_meta,
		&oracle_to_commit_index,
		&eval_claims,
	)?;

	drop(evalcheck_span);

	let ring_switch_span = tracing::info_span!(
		"[phase] Ring Switch",
		phase = "ring_switch",
		perfetto_category = "phase.main"
	)
	.entered();
	let ring_switch::ReducedWitness {
		transparents: transparent_multilins,
		sumcheck_claims: piop_sumcheck_claims,
	} = ring_switch::prove::<_, _, _, Tower, _>(
		&system,
		&committed_multilins,
		&mut transcript,
		memoized_data,
	)?;
	drop(ring_switch_span);

	// Prove evaluation claims using PIOP compiler
	let piop_compiler_span = tracing::info_span!(
		"[phase] PIOP Compiler",
		phase = "piop_compiler",
		perfetto_category = "phase.main"
	)
	.entered();
	piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _, _>(
		&fri_params,
		&ntt,
		&merkle_prover,
		domain_factory,
		&commit_meta,
		committed,
		&codeword,
		&committed_multilins,
		&transparent_multilins,
		&piop_sumcheck_claims,
		&mut transcript,
		&backend,
	)?;
	drop(piop_compiler_span);

	// The proof is exactly the finalized transcript.
	let proof = Proof {
		transcript: transcript.finalize(),
	};

	// Report the proof size as a tracing counter event.
	tracing::event!(
		name: "proof_size",
		tracing::Level::INFO,
		counter = true,
		value = proof.get_proof_size() as u64,
		unit = "bytes",
	);

	Ok(proof)
}
500
/// A boxed zerocheck prover behind the [`ZerocheckProver`] trait object, so that provers
/// created with different `FBase` type parameters can be stored in the same collection.
type TypeErasedZerocheck<'a, P> = Box<dyn ZerocheckProver<'a, P> + 'a>;
502
/// Bundles everything needed to build a zerocheck prover for one constraint set, so that the
/// base field can be picked afterwards through the type parameter of `create`.
struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
where
	P: PackedField,
{
	/// Constraints of the constraint set, over the scalar field of `P`.
	constraints: Vec<Constraint<P::Scalar>>,
	/// Witness multilinears that belong to the constraint set.
	multilinears: Vec<MultilinearWitness<'a, P>>,
	/// Evaluation domain factory handed to the prover.
	domain_factory: DomainFactory,
	/// Zerocheck challenges used by this prover.
	zerocheck_challenges: &'a [P::Scalar],
	/// Computation backend handed to the prover.
	backend: &'a Backend,
	/// Ties the otherwise unused `FDomain` type parameter to the struct.
	_fdomain_marker: PhantomData<FDomain>,
}
514
515impl<'a, P, F, FDomain, DomainFactory, Backend>
516	ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
517where
518	F: Field,
519	P: PackedField<Scalar = F>,
520	FDomain: TowerField,
521	DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
522	Backend: ComputationBackend,
523{
524	fn create<FBase>(self) -> Result<TypeErasedZerocheck<'a, P>, Error>
525	where
526		FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
527		P: PackedExtension<F, PackedSubfield = P>
528			+ PackedExtension<FDomain>
529			+ PackedExtension<FBase>,
530		F: TowerField,
531	{
532		let zerocheck_prover =
533			sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _>(
534				self.constraints,
535				self.multilinears,
536				self.domain_factory,
537				self.zerocheck_challenges,
538				self.backend,
539			)?;
540
541		let type_erased_zerocheck_prover = Box::new(zerocheck_prover) as TypeErasedZerocheck<'a, P>;
542
543		Ok(type_erased_zerocheck_prover)
544	}
545}
546
547#[instrument(skip_all, level = "debug")]
548fn make_masked_flush_witnesses<'a, U, Tower>(
549	oracles: &MultilinearOracleSet<FExt<Tower>>,
550	witness_index: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
551	flush_oracle_ids: &[OracleId],
552	flushes: &[Flush<FExt<Tower>>],
553	mixing_challenge: FExt<Tower>,
554	permutation_challenges: &[FExt<Tower>],
555) -> Result<(), Error>
556where
557	U: ProverTowerUnderlier<Tower>,
558	Tower: ProverTowerFamily,
559{
560	// TODO: Move me out into a separate function & deduplicate.
561	// Count the suffix zeros on all selectors.
562	for flush in flushes {
563		for &selector_id in &flush.selectors {
564			let selector = witness_index.get_multilin_poly(selector_id)?;
565			let zero_suffix_len = count_zero_suffixes(&selector);
566
567			let nonzero_prefix_len = (1 << selector.n_vars()) - zero_suffix_len;
568			witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
569				selector_id,
570				selector,
571				nonzero_prefix_len,
572			)])?;
573		}
574	}
575
576	// Find the maximum power of the mixing challenge needed.
577	let max_n_mixed = flushes
578		.iter()
579		.map(|flush| flush.oracles.len())
580		.max()
581		.unwrap_or_default();
582	let mixing_powers = powers(mixing_challenge)
583		.take(max_n_mixed)
584		.collect::<Vec<_>>();
585
586	// The function is on the critical path, parallelize.
587	let indices_to_update = flush_oracle_ids
588		.par_iter()
589		.zip(flushes)
590		.map(|(&flush_oracle, flush)| {
591			let n_vars = oracles.n_vars(flush_oracle);
592
593			let const_term = flush
594				.oracles
595				.iter()
596				.copied()
597				.zip(mixing_powers.iter())
598				.filter_map(|(oracle_or_const, coeff)| match oracle_or_const {
599					OracleOrConst::Const { base, .. } => Some(base * coeff),
600					_ => None,
601				})
602				.sum::<FExt<Tower>>();
603			let const_term = permutation_challenges[flush.channel_id] + const_term;
604
605			let inner_oracles = flush
606				.oracles
607				.iter()
608				.copied()
609				.zip(mixing_powers.iter())
610				.filter_map(|(oracle_or_const, &coeff)| match oracle_or_const {
611					OracleOrConst::Oracle(oracle_id) => Some((oracle_id, coeff)),
612					_ => None,
613				})
614				.map(|(inner_id, coeff)| {
615					let witness = witness_index.get_multilin_poly(inner_id)?;
616					Ok((witness, coeff))
617				})
618				.collect::<Result<Vec<_>, Error>>()?;
619
620			let selector_entries = flush
621				.selectors
622				.iter()
623				.map(|id| witness_index.get_index_entry(*id))
624				.collect::<Result<Vec<_>, _>>()?;
625
626			// Get the number of entries before any selector column is fully disabled.
627			let selector_prefix_len = selector_entries
628				.iter()
629				.map(|selector_entry| selector_entry.nonzero_scalars_prefix)
630				.min()
631				.unwrap_or(1 << n_vars);
632
633			let selectors = selector_entries
634				.into_iter()
635				.map(|entry| entry.multilin_poly)
636				.collect::<Vec<_>>();
637
638			let log_width = <PackedType<U, FExt<Tower>>>::LOG_WIDTH;
639			let packed_selector_prefix_len = selector_prefix_len.div_ceil(1 << log_width);
640
641			let mut witness_data = Vec::with_capacity(1 << n_vars.saturating_sub(log_width));
642			(0..packed_selector_prefix_len)
643				.into_par_iter()
644				.map(|i| {
645					<PackedType<U, FExt<Tower>>>::from_fn(|j| {
646						let index = i << log_width | j;
647
648						// Compute the product of all selectors at this point
649						let selector_off = selectors.iter().any(|selector| {
650							let sel_val = selector
651								.evaluate_on_hypercube(index)
652								.expect("index < 1 << n_vars");
653							sel_val.is_zero()
654						});
655
656						if selector_off {
657							// If any selector is zero, the result is 1
658							<FExt<Tower>>::ONE
659						} else {
660							// Otherwise, compute the linear combination
661							let mut inner_oracles_iter = inner_oracles.iter();
662
663							// Handle the first one specially because the mixing power is ONE,
664							// unless the first oracle was a constant.
665							if let Some((poly, coeff)) = inner_oracles_iter.next() {
666								let first_term = if *coeff == FExt::<Tower>::ONE {
667									poly.evaluate_on_hypercube(index).expect("index in bounds")
668								} else {
669									poly.evaluate_on_hypercube_and_scale(index, *coeff)
670										.expect("index in bounds")
671								};
672								inner_oracles_iter.fold(
673									const_term + first_term,
674									|sum, (poly, coeff)| {
675										let scaled_eval = poly
676											.evaluate_on_hypercube_and_scale(index, *coeff)
677											.expect("index in bounds");
678										sum + scaled_eval
679									},
680								)
681							} else {
682								const_term
683							}
684						}
685					})
686				})
687				.collect_into_vec(&mut witness_data);
688			witness_data.resize(witness_data.capacity(), PackedType::<U, FExt<Tower>>::one());
689
690			let witness = MLEDirectAdapter::from(
691				MultilinearExtension::new(n_vars, witness_data)
692					.expect("witness_data created with correct n_vars"),
693			);
694			// TODO: This is sketchy. The field on witness index is called "nonzero_prefix", but
695			// I'm setting it when the suffix is 1, not zero.
696			Ok((witness, selector_prefix_len))
697		})
698		.collect::<Result<Vec<_>, Error>>()?;
699
700	witness_index.update_multilin_poly_with_nonzero_scalars_prefixes(
701		iter::zip(flush_oracle_ids, indices_to_update).map(
702			|(&oracle_id, (witness, nonzero_scalars_prefix))| {
703				(oracle_id, witness.upcast_arc_dyn(), nonzero_scalars_prefix)
704			},
705		),
706	)?;
707	Ok(())
708}
709
710fn count_zero_suffixes<P: PackedField, M: MultilinearPoly<P>>(poly: &M) -> usize {
711	let zeros = P::zero();
712	if let Some(packed_evals) = poly.packed_evals() {
713		let packed_zero_suffix_len = packed_evals
714			.iter()
715			.rev()
716			.position(|&packed_eval| packed_eval != zeros)
717			.unwrap_or(packed_evals.len());
718
719		let log_scalars_per_elem = P::LOG_WIDTH + poly.log_extension_degree();
720		if poly.n_vars() < log_scalars_per_elem {
721			debug_assert_eq!(packed_evals.len(), 1, "invariant of MultilinearPoly");
722			packed_zero_suffix_len << poly.n_vars()
723		} else {
724			packed_zero_suffix_len << log_scalars_per_elem
725		}
726	} else {
727		0
728	}
729}
730
731/// Converts specified oracles' witness representations from the base extension field
732/// to the fast extension field format for optimized grand product calculations.
733///
734/// This function processes the provided list of oracle IDs, extracting the corresponding
735/// multilinear polynomials from the witness index, and converting their evaluations
736/// to the fast field representation. The conversion is performed efficiently using
737/// the tower transformation infrastructure.
738///
739/// # Performance Considerations
740/// - This function is optimized for parallel execution as it's on the critical path of the proving
741///   system.
742///
743/// # Arguments
744/// * `oracles` - Reference to the multilinear oracle set containing metadata for all oracles
745/// * `witness` - Reference to the witness index containing the multilinear polynomial evaluations
746/// * `oracle_ids` - Slice of oracle IDs for which to generate fast field representations
747///
748/// # Returns
749/// A vector of tuples, where each tuple contains:
750/// - The number of variables in the oracle's multilinear polynomial
751/// - A vector of packed field elements representing the polynomial's evaluations in the fast field
752///
753/// # Errors
754/// Returns an error if:
755/// - Any oracle ID is invalid or not found in the witness index
756/// - Subcube evaluation fails for any polynomial
757#[allow(clippy::type_complexity)]
758#[instrument(skip_all, level = "debug")]
759fn convert_witnesses_to_fast_ext<'a, U, Tower>(
760	oracles: &MultilinearOracleSet<FExt<Tower>>,
761	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
762	oracle_ids: &[OracleId],
763) -> Result<Vec<(usize, Vec<PackedType<U, FFastExt<Tower>>>)>, Error>
764where
765	U: ProverTowerUnderlier<Tower>,
766	Tower: ProverTowerFamily,
767	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
768{
769	let to_fast = Tower::packed_transformation_to_fast();
770
771	// The function is on the critical path, parallelize.
772	oracle_ids
773		.into_par_iter()
774		.map(|&flush_oracle_id| {
775			let n_vars = oracles.n_vars(flush_oracle_id);
776
777			let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
778
779			let IndexEntry {
780				multilin_poly: poly,
781				nonzero_scalars_prefix,
782			} = witness.get_index_entry(flush_oracle_id)?;
783
784			const MAX_SUBCUBE_VARS: usize = 8;
785			let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
786			let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
787			let non_const_scalars = nonzero_scalars_prefix;
788			let non_const_subcubes = non_const_scalars.div_ceil(1 << subcube_vars);
789
790			let mut fast_ext_result = zeroed_vec(non_const_subcubes * subcube_packed_size);
791			fast_ext_result
792				.par_chunks_exact_mut(subcube_packed_size)
793				.enumerate()
794				.for_each(|(subcube_index, fast_subcube)| {
795					let underliers =
796						PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
797
798					let subcube_evals =
799						PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
800					poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
801						.expect("witness data populated by make_unmasked_flush_witnesses()");
802
803					for underlier in underliers.iter_mut() {
804						let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
805						let dest = to_fast.transform(&src);
806						*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
807					}
808				});
809
810			fast_ext_result.truncate(non_const_scalars);
811			Ok((n_vars, fast_ext_result))
812		})
813		.collect()
814}