binius_core/constraint_system/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{collections::HashSet, env, iter, marker::PhantomData};
4
5use binius_compute::{ComputeData, ComputeLayer, alloc::ComputeAllocator, cpu::CpuMemory};
6use binius_field::{
7	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
8	RepackedExtension, TowerField,
9	as_packed_field::PackedType,
10	linear_transformation::{PackedTransformationFactory, Transformation},
11	tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
12	underlier::WithUnderlier,
13	util::powers,
14};
15use binius_hal::ComputationBackend;
16use binius_hash::{PseudoCompressionFunction, multi_digest::ParallelDigest};
17use binius_math::{
18	CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
19	IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
20};
21use binius_maybe_rayon::prelude::*;
22use binius_ntt::SingleThreadedNTT;
23use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
24use bytemuck::zeroed_vec;
25use digest::{FixedOutputReset, Output, core_api::BlockSizeUser};
26use itertools::chain;
27use tracing::instrument;
28use tracing_profile::utils::emit_max_rss;
29
30use super::{
31	ConstraintSystem, Proof,
32	channel::Boundary,
33	error::Error,
34	verify::{make_flush_oracles, max_n_vars_and_skip_rounds},
35};
36use crate::{
37	constraint_system::{
38		Flush,
39		channel::OracleOrConst,
40		common::{FDomain, FEncode, FExt, FFastExt},
41		exp::{self, reorder_exponents},
42		verify::augment_flush_po2_step_down,
43	},
44	fiat_shamir::{CanSample, Challenger},
45	merkle_tree::BinaryMerkleTreeProver,
46	oracle::{
47		Constraint, ConstraintSetBuilder, MultilinearOracleSet, MultilinearPolyVariant, OracleId,
48		SizedConstraintSet,
49	},
50	piop,
51	protocols::{
52		evalcheck::{
53			ConstraintSetEqIndPoint, EvalPoint, EvalcheckMultilinearClaim,
54			subclaims::{MemoizedData, prove_mlecheck_with_switchover},
55		},
56		fri::CommitOutput,
57		gkr_exp,
58		gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness},
59		greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
60		sumcheck::{
61			self, constraint_set_zerocheck_claim, immediate_switchover_heuristic,
62			prove::ZerocheckProver, standard_switchover_heuristic,
63		},
64	},
65	ring_switch,
66	transcript::ProverTranscript,
67	transparent::step_down::StepDown,
68	witness::{IndexEntry, MultilinearExtensionIndex, MultilinearWitness},
69};
70
71/// Generates a proof that a witness satisfies a constraint system with the standard FRI PCS.
72#[allow(clippy::too_many_arguments)]
73#[instrument("constraint_system::prove", skip_all, level = "debug")]
74pub fn prove<
75	Hal,
76	U,
77	Tower,
78	Hash,
79	Compress,
80	Challenger_,
81	Backend,
82	HostAllocatorType,
83	DeviceAllocatorType,
84>(
85	compute_data: &mut ComputeData<Tower::B128, Hal, HostAllocatorType, DeviceAllocatorType>,
86	constraint_system: &ConstraintSystem<FExt<Tower>>,
87	log_inv_rate: usize,
88	security_bits: usize,
89	constraint_system_digest: &Output<Hash::Digest>,
90	boundaries: &[Boundary<FExt<Tower>>],
91	table_sizes: &[usize],
92	mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
93	backend: &Backend,
94) -> Result<Proof, Error>
95where
96	Hal: ComputeLayer<Tower::B128> + Default,
97	U: ProverTowerUnderlier<Tower>,
98	Tower: ProverTowerFamily,
99	Tower::B128:
100		binius_math::TowerTop + binius_math::PackedTop + PackedTop<Tower> + From<FFastExt<Tower>>,
101	Hash: ParallelDigest,
102	Hash::Digest: BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
103	Compress: PseudoCompressionFunction<Output<Hash::Digest>, 2> + Default + Sync,
104	Challenger_: Challenger + Default,
105	Backend: ComputationBackend,
106	// REVIEW: Consider changing TowerFamily and associated traits to shorten/remove these bounds
107	PackedType<U, Tower::B128>: PackedTop<Tower>
108		+ PackedFieldIndexable
109		// REVIEW: remove this bound after piop::commit is adjusted
110		+ RepackedExtension<PackedType<U, Tower::B1>>
111		+ RepackedExtension<PackedType<U, Tower::B8>>
112		+ RepackedExtension<PackedType<U, Tower::B16>>
113		+ RepackedExtension<PackedType<U, Tower::B32>>
114		+ RepackedExtension<PackedType<U, Tower::B64>>
115		+ RepackedExtension<PackedType<U, Tower::B128>>
116		+ PackedTransformationFactory<PackedType<U, Tower::FastB128>>
117		+ binius_math::PackedTop,
118	PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
119	HostAllocatorType: ComputeAllocator<Tower::B128, CpuMemory>,
120	DeviceAllocatorType: ComputeAllocator<Tower::B128, Hal::DevMem>,
121{
122	tracing::debug!(
123		arch = env::consts::ARCH,
124		rayon_threads = binius_maybe_rayon::current_num_threads(),
125		"using computation backend: {backend:?}"
126	);
127
128	let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
129	let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();
130
131	let ConstraintSystem {
132		oracles,
133		table_constraints,
134		mut flushes,
135		mut exponents,
136		mut non_zero_oracle_ids,
137		channel_count,
138		table_size_specs,
139	} = constraint_system.clone();
140
141	constraint_system.check_table_sizes(table_sizes)?;
142	let mut oracles = oracles.instantiate(table_sizes)?;
143
144	// Prepare the constraint system for proving:
145	//
146	// - Trim all the zero sized oracles.
147	// - Canonicalize the ordering.
148
149	flushes.retain(|flush| table_sizes[flush.table_id] > 0);
150	flushes.sort_by_key(|flush| flush.channel_id);
151
152	non_zero_oracle_ids.retain(|oracle| !oracles.is_zero_sized(*oracle));
153	exponents.retain(|exp| !oracles.is_zero_sized(exp.exp_result_id));
154
155	let mut table_constraints = table_constraints
156		.into_iter()
157		.filter_map(|u| {
158			if table_sizes[u.table_id] == 0 {
159				None
160			} else {
161				let n_vars = u.log_values_per_row + log2_ceil_usize(table_sizes[u.table_id]);
162				Some(SizedConstraintSet::new(n_vars, u))
163			}
164		})
165		.collect::<Vec<_>>();
166	// Stable sort constraint sets in ascending order by number of variables.
167	table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
168
169	reorder_exponents(&mut exponents, &oracles);
170
171	let mut transcript = ProverTranscript::<Challenger_>::new();
172	transcript
173		.observe()
174		.write_slice(constraint_system_digest.as_ref());
175	transcript.observe().write_slice(boundaries);
176	let mut writer = transcript.message();
177	writer.write_slice(table_sizes);
178
179	let witness_span = tracing::info_span!(
180		"[phase] Witness Finalization",
181		phase = "witness",
182		perfetto_category = "phase.main"
183	)
184	.entered();
185
186	// We must generate multiplication witnesses before committing, as this function
187	// adds the committed witnesses for exponentiation results to the witness index.
188	let exp_compute_layer_span = tracing::info_span!(
189		"[step] Compute Exponentiation Layers",
190		phase = "witness",
191		perfetto_category = "phase.sub"
192	)
193	.entered();
194	let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;
195	drop(exp_compute_layer_span);
196
197	drop(witness_span);
198
199	// Commit polynomials
200	let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
201	let merkle_scheme = merkle_prover.scheme();
202
203	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
204	let committed_multilins = piop::collect_committed_witnesses::<U, _>(
205		&commit_meta,
206		&oracle_to_commit_index,
207		&oracles,
208		&witness,
209	)?;
210
211	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
212		&commit_meta,
213		merkle_scheme,
214		security_bits,
215		log_inv_rate,
216	)?;
217	let ntt = SingleThreadedNTT::with_subspace(fri_params.rs_code().subspace())?
218		.precompute_twiddles()
219		.multithreaded();
220
221	let commit_span =
222		tracing::info_span!("[phase] Commit", phase = "commit", perfetto_category = "phase.main")
223			.entered();
224	let CommitOutput {
225		commitment,
226		committed,
227		codeword,
228	} = piop::commit(&fri_params, &ntt, &merkle_prover, &committed_multilins)?;
229	emit_max_rss();
230	drop(commit_span);
231
232	// Observe polynomial commitment
233	let mut writer = transcript.message();
234	writer.write(&commitment);
235
236	let exp_span = tracing::info_span!(
237		"[phase] Exponentiation",
238		phase = "exp",
239		perfetto_category = "phase.main"
240	)
241	.entered();
242	let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
243
244	let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
245		.into_iter()
246		.map(|x| x.into())
247		.collect::<Vec<_>>();
248
249	let mut writer = transcript.message();
250	writer.write_scalar_slice(&exp_evals);
251
252	let exp_challenge = exp_challenge
253		.into_iter()
254		.map(|x| x.into())
255		.collect::<Vec<_>>();
256
257	let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
258		.into_iter()
259		.map(|claim| claim.isomorphic())
260		.collect::<Vec<_>>();
261
262	let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
263		EvaluationOrder::HighToLow,
264		exp_witnesses,
265		&exp_claims,
266		fast_domain_factory.clone(),
267		&mut transcript,
268		backend,
269	)?
270	.isomorphic();
271
272	let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
273	emit_max_rss();
274	drop(exp_span);
275
276	// Grand product arguments
277	// Grand products for non-zero checking
278	let prodcheck_span = tracing::info_span!(
279		"[phase] Product Check",
280		phase = "prodcheck",
281		perfetto_category = "phase.main"
282	)
283	.entered();
284
285	let nonzero_convert_span = tracing::info_span!(
286		"[task] Convert Non-Zero to Fast Field",
287		phase = "prodcheck",
288		perfetto_category = "task.main"
289	)
290	.entered();
291	let non_zero_fast_witnesses =
292		convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &non_zero_oracle_ids)?;
293	emit_max_rss();
294	drop(nonzero_convert_span);
295
296	let nonzero_prodcheck_compute_layer_span = tracing::info_span!(
297		"[step] Compute Non-Zero Product Layers",
298		phase = "prodcheck",
299		perfetto_category = "phase.sub"
300	)
301	.entered();
302	let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
303		.into_par_iter()
304		.map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
305		.collect::<Result<Vec<_>, _>>()?;
306	emit_max_rss();
307	drop(nonzero_prodcheck_compute_layer_span);
308
309	let non_zero_products =
310		gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
311	if non_zero_products
312		.iter()
313		.any(|count| *count == Tower::B128::zero())
314	{
315		bail!(Error::Zeros);
316	}
317
318	let mut writer = transcript.message();
319
320	writer.write_scalar_slice(&non_zero_products);
321
322	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
323		&non_zero_oracle_ids,
324		&oracles,
325		&non_zero_products,
326	)?;
327
328	// Grand products for flushing
329	let mixing_challenge = transcript.sample();
330	let permutation_challenges = transcript.sample_vec(channel_count);
331
332	flushes.retain(|flush| table_sizes[flush.table_id] > 0);
333	flushes.sort_by_key(|flush| flush.channel_id);
334	let po2_step_down_polys =
335		augment_flush_po2_step_down(&mut oracles, &mut flushes, &table_size_specs, table_sizes)?;
336	populate_flush_po2_step_down_witnesses::<U, _>(po2_step_down_polys, &mut witness)?;
337	let flush_oracle_ids =
338		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
339
340	let flush_convert_span = tracing::info_span!(
341		"[task] Convert Flushes to Fast Field",
342		phase = "prodcheck",
343		perfetto_category = "task.main"
344	)
345	.entered();
346
347	let mut fast_witness = MultilinearExtensionIndex::<PackedType<U, FFastExt<Tower>>>::new();
348
349	make_masked_flush_witnesses::<U, _>(
350		&oracles,
351		&mut witness,
352		&mut fast_witness,
353		&flush_oracle_ids,
354		&flushes,
355		mixing_challenge,
356		&permutation_challenges,
357	)?;
358
359	// there are no oracle ids associated with these flush_witnesses
360	let flush_witnesses =
361		convert_witnesses_to_fast_ext::<U, _>(&oracles, &witness, &flush_oracle_ids)?;
362	emit_max_rss();
363	drop(flush_convert_span);
364
365	let flush_prodcheck_compute_layer_span = tracing::info_span!(
366		"[step] Compute Flush Product Layers",
367		phase = "prodcheck",
368		perfetto_category = "phase.sub"
369	)
370	.entered();
371	let flush_prodcheck_witnesses = flush_witnesses
372		.into_par_iter()
373		.map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
374		.collect::<Result<Vec<_>, _>>()?;
375	emit_max_rss();
376	drop(flush_prodcheck_compute_layer_span);
377
378	let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);
379
380	transcript.message().write_scalar_slice(&flush_products);
381
382	let flush_prodcheck_claims =
383		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
384
385	// Prove grand products
386	let all_gpa_witnesses =
387		chain!(flush_prodcheck_witnesses, non_zero_prodcheck_witnesses).collect::<Vec<_>>();
388	let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
389		.map(|claim| claim.isomorphic())
390		.collect::<Vec<_>>();
391
392	let GrandProductBatchProveOutput { final_layer_claims } =
393		gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
394			EvaluationOrder::HighToLow,
395			all_gpa_witnesses,
396			&all_gpa_claims,
397			&fast_domain_factory,
398			&mut transcript,
399			backend,
400		)?;
401
402	// Apply isomorphism to the layer claims
403	let final_layer_claims = final_layer_claims
404		.into_iter()
405		.map(|layer_claim| layer_claim.isomorphic())
406		.collect::<Vec<_>>();
407
408	// Reduce non_zero_final_layer_claims to evalcheck claims
409	let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
410		chain!(flush_oracle_ids.clone(), non_zero_oracle_ids),
411		final_layer_claims,
412	)?;
413
414	let mut flush_prodcheck_eval_claims = prodcheck_eval_claims;
415
416	let prodcheck_eval_claims = flush_prodcheck_eval_claims.split_off(flush_oracle_ids.len());
417
418	let flush_eval_claims = reduce_flush_evalcheck_claims::<U, Tower, Challenger_, Backend>(
419		flush_prodcheck_eval_claims,
420		&oracles,
421		fast_witness,
422		fast_domain_factory.clone(),
423		&mut transcript,
424		backend,
425	)?;
426
427	emit_max_rss();
428	drop(prodcheck_span);
429
430	// Zerocheck
431	let zerocheck_span = tracing::info_span!(
432		"[phase] Zerocheck",
433		phase = "zerocheck",
434		perfetto_category = "phase.main",
435	)
436	.entered();
437
438	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
439		.iter()
440		.cloned()
441		.map(constraint_set_zerocheck_claim)
442		.collect::<Result<Vec<_>, _>>()?
443		.into_iter()
444		.unzip::<_, _, Vec<_>, Vec<_>>();
445
446	let (max_n_vars, skip_rounds) =
447		max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);
448
449	let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);
450
451	let mut zerocheck_provers = Vec::with_capacity(table_constraints.len());
452
453	for constraint_set in table_constraints {
454		let n_vars = constraint_set.n_vars;
455		let (constraints, multilinears) =
456			sumcheck::prove::split_constraint_set(constraint_set, &witness)?;
457
458		let base_tower_level = chain!(
459			multilinears
460				.iter()
461				.map(|multilinear| 7 - multilinear.log_extension_degree()),
462			constraints
463				.iter()
464				.map(|constraint| constraint.composition.binary_tower_level())
465		)
466		.max()
467		.unwrap_or(0);
468
469		// Per prover zerocheck challenges are justified on the high indexed variables
470		let zerocheck_challenges = &zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..];
471		let domain_factory = domain_factory.clone();
472
473		let constructor =
474			ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _> {
475				constraints,
476				multilinears,
477				zerocheck_challenges,
478				domain_factory,
479				backend,
480				_fdomain_marker: PhantomData,
481			};
482
483		let zerocheck_prover = match base_tower_level {
484			0..=3 => constructor.create::<Tower::B8>()?,
485			4 => constructor.create::<Tower::B16>()?,
486			5 => constructor.create::<Tower::B32>()?,
487			6 => constructor.create::<Tower::B64>()?,
488			7 => constructor.create::<Tower::B128>()?,
489			_ => unreachable!(),
490		};
491
492		zerocheck_provers.push(zerocheck_prover);
493	}
494
495	let zerocheck_output = sumcheck::prove::batch_prove_zerocheck::<
496		FExt<Tower>,
497		FDomain<Tower>,
498		PackedType<U, FExt<Tower>>,
499		_,
500		_,
501	>(zerocheck_provers, skip_rounds, &mut transcript)?;
502
503	let zerocheck_eval_claims =
504		sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
505
506	emit_max_rss();
507	drop(zerocheck_span);
508
509	let evalcheck_span = tracing::info_span!(
510		"[phase] Evalcheck",
511		phase = "evalcheck",
512		perfetto_category = "phase.main"
513	)
514	.entered();
515
516	// Prove evaluation claims
517	let GreedyEvalcheckProveOutput {
518		eval_claims,
519		memoized_data,
520	} = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
521		&mut oracles,
522		&mut witness,
523		chain!(flush_eval_claims, prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
524		standard_switchover_heuristic(-2),
525		&mut transcript,
526		&domain_factory,
527		backend,
528	)?;
529
530	// Reduce committed evaluation claims to PIOP sumcheck claims
531	let system = ring_switch::EvalClaimSystem::new(
532		&oracles,
533		&commit_meta,
534		&oracle_to_commit_index,
535		&eval_claims,
536	)?;
537
538	emit_max_rss();
539	drop(evalcheck_span);
540
541	let ring_switch_span = tracing::info_span!(
542		"[phase] Ring Switch",
543		phase = "ring_switch",
544		perfetto_category = "phase.main"
545	)
546	.entered();
547
548	let hal = compute_data.hal;
549
550	let dev_alloc = &compute_data.dev_alloc;
551	let host_alloc = &compute_data.host_alloc;
552
553	let ring_switch::ReducedWitness {
554		transparents: transparent_multilins,
555		sumcheck_claims: piop_sumcheck_claims,
556	} = ring_switch::prove(
557		&system,
558		&committed_multilins,
559		&mut transcript,
560		memoized_data,
561		hal,
562		dev_alloc,
563		host_alloc,
564	)?;
565	emit_max_rss();
566	drop(ring_switch_span);
567
568	// Prove evaluation claims using PIOP compiler
569	let piop_compiler_span = tracing::info_span!(
570		"[phase] PIOP Compiler",
571		phase = "piop_compiler",
572		perfetto_category = "phase.main"
573	)
574	.entered();
575
576	piop::prove(
577		compute_data,
578		&fri_params,
579		&ntt,
580		&merkle_prover,
581		&commit_meta,
582		committed,
583		&codeword,
584		&committed_multilins,
585		transparent_multilins,
586		&piop_sumcheck_claims,
587		&mut transcript,
588	)?;
589	emit_max_rss();
590	drop(piop_compiler_span);
591
592	let proof = Proof {
593		transcript: transcript.finalize(),
594	};
595
596	tracing::event!(
597		name: "proof_size",
598		tracing::Level::INFO,
599		counter = true,
600		value = proof.get_proof_size() as u64,
601		unit = "bytes",
602	);
603
604	Ok(proof)
605}
606
607type TypeErasedZerocheck<'a, P> = Box<dyn ZerocheckProver<'a, P> + 'a>;
608
609struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
610where
611	P: PackedField,
612{
613	constraints: Vec<Constraint<P::Scalar>>,
614	multilinears: Vec<MultilinearWitness<'a, P>>,
615	domain_factory: DomainFactory,
616	zerocheck_challenges: &'a [P::Scalar],
617	backend: &'a Backend,
618	_fdomain_marker: PhantomData<FDomain>,
619}
620
621impl<'a, P, F, FDomain, DomainFactory, Backend>
622	ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
623where
624	F: Field,
625	P: PackedField<Scalar = F>,
626	FDomain: TowerField,
627	DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
628	Backend: ComputationBackend,
629{
630	fn create<FBase>(self) -> Result<TypeErasedZerocheck<'a, P>, Error>
631	where
632		FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
633		P: PackedExtension<F, PackedSubfield = P>
634			+ PackedExtension<FDomain>
635			+ PackedExtension<FBase>,
636		F: TowerField,
637	{
638		let zerocheck_prover =
639			sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _>(
640				self.constraints,
641				self.multilinears,
642				self.domain_factory,
643				self.zerocheck_challenges,
644				self.backend,
645			)?;
646
647		let type_erased_zerocheck_prover = Box::new(zerocheck_prover) as TypeErasedZerocheck<'a, P>;
648
649		Ok(type_erased_zerocheck_prover)
650	}
651}
652
653fn populate_flush_po2_step_down_witnesses<'a, U, Tower>(
654	step_down_polys: Vec<(OracleId, StepDown)>,
655	witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
656) -> Result<(), Error>
657where
658	U: ProverTowerUnderlier<Tower>,
659	Tower: ProverTowerFamily,
660{
661	for (oracle_id, step_down_poly) in step_down_polys {
662		let witness_poly = step_down_poly
663			.multilinear_extension::<PackedType<U, Tower::B1>>()?
664			.specialize_arc_dyn();
665		witness.update_multilin_poly([(oracle_id, witness_poly)])?
666	}
667	Ok(())
668}
669
670#[instrument(skip_all, level = "debug")]
671pub fn make_masked_flush_witnesses<'a, U, Tower>(
672	oracles: &MultilinearOracleSet<FExt<Tower>>,
673	witness_index: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
674	fast_witness_index: &mut MultilinearExtensionIndex<'a, PackedType<U, FFastExt<Tower>>>,
675	flush_oracle_ids: &[OracleId],
676	flushes: &[Flush<FExt<Tower>>],
677	mixing_challenge: FExt<Tower>,
678	permutation_challenges: &[FExt<Tower>],
679) -> Result<(), Error>
680where
681	U: ProverTowerUnderlier<Tower>,
682	Tower: ProverTowerFamily,
683	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
684		+ RepackedExtension<PackedType<U, Tower::B1>>,
685{
686	// TODO: Move me out into a separate function & deduplicate.
687	// Count the suffix zeros on all selectors.
688	for flush in flushes {
689		let fast_selectors =
690			convert_1b_witnesses_to_fast_ext::<U, Tower>(witness_index, &flush.selectors)?;
691
692		for (&selector_id, fast_selector) in flush.selectors.iter().zip(fast_selectors) {
693			let selector = witness_index.get_multilin_poly(selector_id)?;
694			let zero_suffix_len = count_zero_suffixes(&selector);
695
696			let nonzero_prefix_len = (1 << selector.n_vars()) - zero_suffix_len;
697			witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
698				selector_id,
699				selector,
700				nonzero_prefix_len,
701			)])?;
702
703			fast_witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
704				selector_id,
705				fast_selector,
706				nonzero_prefix_len,
707			)])?;
708		}
709	}
710
711	let inner_oracles_id = flushes
712		.iter()
713		.flat_map(|flush| {
714			flush
715				.oracles
716				.iter()
717				.filter_map(|oracle_or_const| match oracle_or_const {
718					OracleOrConst::Oracle(oracle_id) => Some(*oracle_id),
719					_ => None,
720				})
721		})
722		.collect::<HashSet<_>>();
723
724	let inner_oracles_id = inner_oracles_id.into_iter().collect::<Vec<_>>();
725
726	let fast_inner_oracles =
727		convert_witnesses_to_fast_ext::<U, Tower>(oracles, witness_index, &inner_oracles_id)?;
728
729	for ((n_vars, witness_data), id) in fast_inner_oracles.into_iter().zip(inner_oracles_id) {
730		let fast_witness = MLEDirectAdapter::from(
731			MultilinearExtension::new(n_vars, witness_data)
732				.expect("witness_data created with correct n_vars"),
733		);
734
735		let nonzero_scalars_prefix = witness_index.get_index_entry(id)?.nonzero_scalars_prefix;
736
737		fast_witness_index.update_multilin_poly_with_nonzero_scalars_prefixes([(
738			id,
739			fast_witness.upcast_arc_dyn(),
740			nonzero_scalars_prefix,
741		)])?;
742	}
743
744	// Find the maximum power of the mixing challenge needed.
745	let max_n_mixed = flushes
746		.iter()
747		.map(|flush| flush.oracles.len())
748		.max()
749		.unwrap_or_default();
750	let mixing_powers = powers(mixing_challenge)
751		.take(max_n_mixed)
752		.collect::<Vec<_>>();
753
754	// The function is on the critical path, parallelize.
755	let indices_to_update = flush_oracle_ids
756		.par_iter()
757		.zip(flushes)
758		.map(|(&flush_oracle, flush)| {
759			let n_vars = oracles.n_vars(flush_oracle);
760
761			let const_term = flush
762				.oracles
763				.iter()
764				.copied()
765				.zip(mixing_powers.iter())
766				.filter_map(|(oracle_or_const, coeff)| match oracle_or_const {
767					OracleOrConst::Const { base, .. } => Some(base * coeff),
768					_ => None,
769				})
770				.sum::<FExt<Tower>>();
771			let const_term = permutation_challenges[flush.channel_id] + const_term;
772
773			let inner_oracles = flush
774				.oracles
775				.iter()
776				.copied()
777				.zip(mixing_powers.iter())
778				.filter_map(|(oracle_or_const, &coeff)| match oracle_or_const {
779					OracleOrConst::Oracle(oracle_id) => Some((oracle_id, coeff)),
780					_ => None,
781				})
782				.map(|(inner_id, coeff)| {
783					let witness = witness_index.get_multilin_poly(inner_id)?;
784					Ok((witness, coeff))
785				})
786				.collect::<Result<Vec<_>, Error>>()?;
787
788			let selector_entries = flush
789				.selectors
790				.iter()
791				.map(|id| witness_index.get_index_entry(*id))
792				.collect::<Result<Vec<_>, _>>()?;
793
794			// Get the number of entries before any selector column is fully disabled.
795			let selector_prefix_len = selector_entries
796				.iter()
797				.map(|selector_entry| selector_entry.nonzero_scalars_prefix)
798				.min()
799				.unwrap_or(1 << n_vars);
800
801			let selectors = selector_entries
802				.into_iter()
803				.map(|entry| entry.multilin_poly)
804				.collect::<Vec<_>>();
805
806			let log_width = <PackedType<U, FExt<Tower>>>::LOG_WIDTH;
807			let packed_selector_prefix_len = selector_prefix_len.div_ceil(1 << log_width);
808
809			let mut witness_data = Vec::with_capacity(1 << n_vars.saturating_sub(log_width));
810			(0..packed_selector_prefix_len)
811				.into_par_iter()
812				.map(|i| {
813					<PackedType<U, FExt<Tower>>>::from_fn(|j| {
814						let index = i << log_width | j;
815
816						// If n_vars < P::LOG_WIDTH, fill the remaining scalars with zeroes.
817						if index >= 1 << n_vars {
818							return <FExt<Tower>>::ZERO;
819						}
820
821						// Compute the product of all selectors at this point
822						let selector_off = selectors.iter().any(|selector| {
823							let sel_val = selector
824								.evaluate_on_hypercube(index)
825								.expect("index < 1 << n_vars");
826							sel_val.is_zero()
827						});
828
829						if selector_off {
830							// If any selector is zero, the result is 1
831							<FExt<Tower>>::ONE
832						} else {
833							// Otherwise, compute the linear combination
834							let mut inner_oracles_iter = inner_oracles.iter();
835
836							// Handle the first one specially because the mixing power is ONE,
837							// unless the first oracle was a constant.
838							if let Some((poly, coeff)) = inner_oracles_iter.next() {
839								let first_term = if *coeff == FExt::<Tower>::ONE {
840									poly.evaluate_on_hypercube(index).expect("index in bounds")
841								} else {
842									poly.evaluate_on_hypercube_and_scale(index, *coeff)
843										.expect("index in bounds")
844								};
845								inner_oracles_iter.fold(
846									const_term + first_term,
847									|sum, (poly, coeff)| {
848										let scaled_eval = poly
849											.evaluate_on_hypercube_and_scale(index, *coeff)
850											.expect("index in bounds");
851										sum + scaled_eval
852									},
853								)
854							} else {
855								const_term
856							}
857						}
858					})
859				})
860				.collect_into_vec(&mut witness_data);
861			witness_data.resize(witness_data.capacity(), PackedType::<U, FExt<Tower>>::one());
862
863			let witness = MLEDirectAdapter::from(
864				MultilinearExtension::new(n_vars, witness_data)
865					.expect("witness_data created with correct n_vars"),
866			);
867			// TODO: This is sketchy. The field on witness index is called "nonzero_prefix", but
868			// I'm setting it when the suffix is 1, not zero.
869			Ok((witness, selector_prefix_len))
870		})
871		.collect::<Result<Vec<_>, Error>>()?;
872
873	witness_index.update_multilin_poly_with_nonzero_scalars_prefixes(
874		iter::zip(flush_oracle_ids, indices_to_update).map(
875			|(&oracle_id, (witness, nonzero_scalars_prefix))| {
876				(oracle_id, witness.upcast_arc_dyn(), nonzero_scalars_prefix)
877			},
878		),
879	)?;
880	Ok(())
881}
882
883fn count_zero_suffixes<P: PackedField, M: MultilinearPoly<P>>(poly: &M) -> usize {
884	let zeros = P::zero();
885	if let Some(packed_evals) = poly.packed_evals() {
886		let packed_zero_suffix_len = packed_evals
887			.iter()
888			.rev()
889			.position(|&packed_eval| packed_eval != zeros)
890			.unwrap_or(packed_evals.len());
891
892		let log_scalars_per_elem = P::LOG_WIDTH + poly.log_extension_degree();
893		if poly.n_vars() < log_scalars_per_elem {
894			debug_assert_eq!(packed_evals.len(), 1, "invariant of MultilinearPoly");
895			packed_zero_suffix_len << poly.n_vars()
896		} else {
897			packed_zero_suffix_len << log_scalars_per_elem
898		}
899	} else {
900		0
901	}
902}
903
904/// Converts specified oracles' witness representations from the base extension field
905/// to the fast extension field format for optimized grand product calculations.
906///
907/// This function processes the provided list of oracle IDs, extracting the corresponding
908/// multilinear polynomials from the witness index, and converting their evaluations
909/// to the fast field representation. The conversion is performed efficiently using
910/// the tower transformation infrastructure.
911///
912/// # Performance Considerations
913/// - This function is optimized for parallel execution as it's on the critical path of the proving
914///   system.
915///
916/// # Arguments
917/// * `oracles` - Reference to the multilinear oracle set containing metadata for all oracles
918/// * `witness` - Reference to the witness index containing the multilinear polynomial evaluations
919/// * `oracle_ids` - Slice of oracle IDs for which to generate fast field representations
920///
921/// # Returns
922/// A vector of tuples, where each tuple contains:
923/// - The number of variables in the oracle's multilinear polynomial
924/// - A vector of packed field elements representing the polynomial's evaluations in the fast field
925///
926/// # Errors
927/// Returns an error if:
928/// - Any oracle ID is invalid or not found in the witness index
929/// - Subcube evaluation fails for any polynomial
930#[allow(clippy::type_complexity)]
931#[instrument(skip_all, level = "debug")]
932fn convert_witnesses_to_fast_ext<'a, U, Tower>(
933	oracles: &MultilinearOracleSet<FExt<Tower>>,
934	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
935	oracle_ids: &[OracleId],
936) -> Result<Vec<(usize, Vec<PackedType<U, FFastExt<Tower>>>)>, Error>
937where
938	U: ProverTowerUnderlier<Tower>,
939	Tower: ProverTowerFamily,
940	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
941{
942	let to_fast = Tower::packed_transformation_to_fast();
943
944	// The function is on the critical path, parallelize.
945	oracle_ids
946		.into_par_iter()
947		.map(|&flush_oracle_id| {
948			let n_vars = oracles.n_vars(flush_oracle_id);
949
950			let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
951
952			let IndexEntry {
953				multilin_poly: poly,
954				nonzero_scalars_prefix,
955			} = witness.get_index_entry(flush_oracle_id)?;
956
957			const MAX_SUBCUBE_VARS: usize = 8;
958			let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
959			let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
960			let non_const_scalars = nonzero_scalars_prefix;
961			let non_const_subcubes = non_const_scalars.div_ceil(1 << subcube_vars);
962
963			let mut fast_ext_result = zeroed_vec(non_const_subcubes * subcube_packed_size);
964			fast_ext_result
965				.par_chunks_exact_mut(subcube_packed_size)
966				.enumerate()
967				.for_each(|(subcube_index, fast_subcube)| {
968					let underliers =
969						PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
970
971					let subcube_evals =
972						PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
973					poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
974						.expect("witness data populated by make_unmasked_flush_witnesses()");
975
976					for underlier in underliers.iter_mut() {
977						let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
978						let dest = to_fast.transform(&src);
979						*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
980					}
981				});
982
983			fast_ext_result.truncate(non_const_scalars);
984			Ok((n_vars, fast_ext_result))
985		})
986		.collect()
987}
988
989#[allow(clippy::type_complexity)]
990pub fn convert_1b_witnesses_to_fast_ext<'a, U, Tower>(
991	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
992	ids: &[OracleId],
993) -> Result<Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>, Error>
994where
995	U: ProverTowerUnderlier<Tower>,
996	Tower: ProverTowerFamily,
997	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>
998		+ RepackedExtension<PackedType<U, Tower::B1>>,
999{
1000	ids.iter()
1001		.map(|&id| {
1002			let exp_witness = witness.get_multilin_poly(id)?;
1003
1004			let packed_evals = exp_witness
1005				.packed_evals()
1006				.expect("poly contain packed_evals");
1007
1008			let packed_evals = PackedType::<U, Tower::B128>::cast_bases(packed_evals);
1009
1010			MultilinearExtension::new(exp_witness.n_vars(), packed_evals.to_vec())
1011				.map(|mle| mle.specialize_arc_dyn())
1012				.map_err(Error::from)
1013		})
1014		.collect::<Result<Vec<_>, _>>()
1015}
1016
1017#[instrument(skip_all, name = "flush::reduce_flush_evalcheck_claims")]
1018fn reduce_flush_evalcheck_claims<
1019	U,
1020	Tower: ProverTowerFamily,
1021	Challenger_,
1022	Backend: ComputationBackend,
1023>(
1024	claims: Vec<EvalcheckMultilinearClaim<FExt<Tower>>>,
1025	oracles: &MultilinearOracleSet<FExt<Tower>>,
1026	witness_index: MultilinearExtensionIndex<PackedType<U, FFastExt<Tower>>>,
1027	domain_factory: IsomorphicEvaluationDomainFactory<FFastExt<Tower>>,
1028	transcript: &mut ProverTranscript<Challenger_>,
1029	backend: &Backend,
1030) -> Result<Vec<EvalcheckMultilinearClaim<FExt<Tower>>>, Error>
1031where
1032	FExt<Tower>: From<FFastExt<Tower>>,
1033	FFastExt<Tower>: From<FExt<Tower>>,
1034	U: ProverTowerUnderlier<Tower>,
1035	Challenger_: Challenger + Default,
1036{
1037	let mut linear_claims = Vec::new();
1038
1039	#[allow(clippy::type_complexity)]
1040	let mut new_mlechecks_constraints: Vec<(
1041		EvalPoint<FFastExt<Tower>>,
1042		ConstraintSetBuilder<FFastExt<Tower>>,
1043	)> = Vec::new();
1044
1045	for claim in &claims {
1046		match &oracles[claim.id].variant {
1047			MultilinearPolyVariant::LinearCombination(_) => linear_claims.push(claim.clone()),
1048			MultilinearPolyVariant::Composite(composite) => {
1049				let eval_point = claim.eval_point.isomorphic();
1050
1051				let eval = claim.eval.into();
1052
1053				let position = new_mlechecks_constraints
1054					.iter()
1055					.position(|(ep, _)| *ep == eval_point)
1056					.unwrap_or(new_mlechecks_constraints.len());
1057
1058				let oracle_ids = composite.inner().clone();
1059
1060				let exp = <_ as CompositionPoly<FExt<Tower>>>::expression(composite.c());
1061				let fast_exp = exp.convert_field::<FFastExt<Tower>>();
1062
1063				if let Some((_, constraint_builder)) = new_mlechecks_constraints.get_mut(position) {
1064					constraint_builder.add_sumcheck(oracle_ids, fast_exp, eval);
1065				} else {
1066					let mut new_builder = ConstraintSetBuilder::new();
1067					new_builder.add_sumcheck(oracle_ids, fast_exp, eval);
1068					new_mlechecks_constraints.push((eval_point.clone(), new_builder));
1069				}
1070			}
1071			_ => unreachable!(),
1072		}
1073	}
1074
1075	let new_mlechecks = new_mlechecks_constraints
1076		.into_iter()
1077		.map(|(ep, builder)| {
1078			builder
1079				.build_one(oracles)
1080				.map(|constraint| ConstraintSetEqIndPoint {
1081					eq_ind_challenges: ep.clone(),
1082					constraint_set: constraint,
1083				})
1084				.map_err(Error::from)
1085		})
1086		.collect::<Result<Vec<_>, Error>>()?;
1087
1088	let mut memoized_data = MemoizedData::new();
1089
1090	let mut fast_new_evalcheck_claims = Vec::new();
1091
1092	for ConstraintSetEqIndPoint {
1093		eq_ind_challenges,
1094		constraint_set,
1095	} in new_mlechecks
1096	{
1097		let evalcheck_claims = prove_mlecheck_with_switchover::<_, _, FFastExt<Tower>, _, _>(
1098			&witness_index,
1099			constraint_set,
1100			eq_ind_challenges,
1101			&mut memoized_data,
1102			transcript,
1103			immediate_switchover_heuristic,
1104			domain_factory.clone(),
1105			backend,
1106		)?;
1107		fast_new_evalcheck_claims.extend(evalcheck_claims);
1108	}
1109
1110	Ok(chain!(
1111		fast_new_evalcheck_claims
1112			.into_iter()
1113			.map(|claim| claim.isomorphic::<FExt<Tower>>()),
1114		linear_claims.into_iter()
1115	)
1116	.collect::<Vec<_>>())
1117}