// binius_core/constraint_system/prove.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cmp::Reverse, env, marker::PhantomData, slice::from_mut};
4
5use binius_field::{
6	as_packed_field::{PackScalar, PackedType},
7	linear_transformation::{PackedTransformationFactory, Transformation},
8	underlier::WithUnderlier,
9	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
10	RepackedExtension, TowerField,
11};
12use binius_hal::ComputationBackend;
13use binius_hash::PseudoCompressionFunction;
14use binius_math::{
15	EvaluationDomainFactory, EvaluationOrder, IsomorphicEvaluationDomainFactory, MLEDirectAdapter,
16	MultilinearExtension, MultilinearPoly,
17};
18use binius_maybe_rayon::prelude::*;
19use binius_utils::bail;
20use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, Output};
21use either::Either;
22use itertools::{chain, izip};
23use tracing::instrument;
24
25use super::{
26	channel::Boundary,
27	error::Error,
28	verify::{
29		get_post_flush_sumcheck_eval_claims_without_eq, make_flush_oracles,
30		max_n_vars_and_skip_rounds, reorder_for_flushing_by_n_vars,
31	},
32	ConstraintSystem, Proof,
33};
34use crate::{
35	constraint_system::{
36		common::{FDomain, FEncode, FExt, FFastExt},
37		verify::{get_flush_dedup_sumcheck_metas, FlushSumcheckMeta},
38	},
39	fiat_shamir::{CanSample, Challenger},
40	merkle_tree::BinaryMerkleTreeProver,
41	oracle::{Constraint, MultilinearOracleSet, MultilinearPolyVariant, OracleId},
42	piop,
43	protocols::{
44		fri::CommitOutput,
45		gkr_gpa::{
46			self, gpa_sumcheck::prove::GPAProver, GrandProductBatchProveOutput,
47			GrandProductWitness, LayerClaim,
48		},
49		greedy_evalcheck,
50		sumcheck::{
51			self, constraint_set_zerocheck_claim,
52			prove::{SumcheckProver, UnivariateZerocheckProver},
53			standard_switchover_heuristic, zerocheck,
54		},
55	},
56	ring_switch,
57	tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
58	transcript::ProverTranscript,
59	witness::{MultilinearExtensionIndex, MultilinearWitness},
60};
61
62/// Generates a proof that a witness satisfies a constraint system with the standard FRI PCS.
63#[instrument("constraint_system::prove", skip_all, level = "debug")]
64pub fn prove<U, Tower, DomainFactory, Hash, Compress, Challenger_, Backend>(
65	constraint_system: &ConstraintSystem<FExt<Tower>>,
66	log_inv_rate: usize,
67	security_bits: usize,
68	boundaries: &[Boundary<FExt<Tower>>],
69	mut witness: MultilinearExtensionIndex<U, FExt<Tower>>,
70	domain_factory: DomainFactory,
71	backend: &Backend,
72) -> Result<Proof, Error>
73where
74	U: ProverTowerUnderlier<Tower>,
75	Tower: ProverTowerFamily,
76	Tower::B128: PackedTop<Tower>,
77	DomainFactory: EvaluationDomainFactory<FDomain<Tower>>,
78	Hash: Digest + BlockSizeUser + FixedOutputReset,
79	Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
80	Challenger_: Challenger + Default,
81	Backend: ComputationBackend,
82	// REVIEW: Consider changing TowerFamily and associated traits to shorten/remove these bounds
83	PackedType<U, Tower::B128>: PackedTop<Tower>
84		+ PackedFieldIndexable
85		+ RepackedExtension<PackedType<U, Tower::B8>>
86		+ RepackedExtension<PackedType<U, Tower::B16>>
87		+ RepackedExtension<PackedType<U, Tower::B32>>
88		+ RepackedExtension<PackedType<U, Tower::B64>>
89		+ RepackedExtension<PackedType<U, Tower::B128>>
90		+ PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
91	PackedType<U, Tower::FastB128>:
92		PackedFieldIndexable + PackedTransformationFactory<PackedType<U, Tower::B128>>,
93	PackedType<U, Tower::B8>: PackedFieldIndexable,
94	PackedType<U, Tower::B16>: PackedFieldIndexable,
95	PackedType<U, Tower::B32>: PackedFieldIndexable,
96	PackedType<U, Tower::B64>: PackedFieldIndexable,
97{
98	tracing::debug!(
99		arch = env::consts::ARCH,
100		rayon_threads = binius_maybe_rayon::current_num_threads(),
101		"using computation backend: {backend:?}"
102	);
103
104	let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();
105
106	let mut transcript = ProverTranscript::<Challenger_>::new();
107	transcript.observe().write_slice(boundaries);
108
109	let ConstraintSystem {
110		mut oracles,
111		mut table_constraints,
112		mut flushes,
113		non_zero_oracle_ids,
114		max_channel_id,
115	} = constraint_system.clone();
116
117	// Stable sort constraint sets in descending order by number of variables.
118	table_constraints.sort_by_key(|constraint_set| Reverse(constraint_set.n_vars));
119
120	// Commit polynomials
121	let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
122	let merkle_scheme = merkle_prover.scheme();
123
124	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
125	let committed_multilins = piop::collect_committed_witnesses(
126		&commit_meta,
127		&oracle_to_commit_index,
128		&oracles,
129		&witness,
130	)?;
131
132	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
133		&commit_meta,
134		merkle_scheme,
135		security_bits,
136		log_inv_rate,
137	)?;
138	let CommitOutput {
139		commitment,
140		committed,
141		codeword,
142	} = piop::commit(&fri_params, &merkle_prover, &committed_multilins)?;
143
144	// Observe polynomial commitment
145	let mut writer = transcript.message();
146	writer.write(&commitment);
147
148	// Grand product arguments
149	// Grand products for non-zero checking
150	let non_zero_fast_witnesses =
151		make_fast_masked_flush_witnesses(&oracles, &witness, &non_zero_oracle_ids, None)?;
152	let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
153		.into_par_iter()
154		.map(GrandProductWitness::new)
155		.collect::<Result<Vec<_>, _>>()?;
156
157	let non_zero_products =
158		gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
159	if non_zero_products
160		.iter()
161		.any(|count| *count == Tower::B128::zero())
162	{
163		bail!(Error::Zeros);
164	}
165
166	writer.write_scalar_slice(&non_zero_products);
167
168	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
169		&non_zero_oracle_ids,
170		&oracles,
171		&non_zero_products,
172	)?;
173
174	// Grand products for flushing
175	let mixing_challenge = transcript.sample();
176	let permutation_challenges = transcript.sample_vec(max_channel_id + 1);
177
178	flushes.sort_by_key(|flush| flush.channel_id);
179	let flush_oracle_ids =
180		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
181	let flush_selectors = flushes
182		.iter()
183		.map(|flush| flush.selector)
184		.collect::<Vec<_>>();
185
186	make_unmasked_flush_witnesses(&oracles, &mut witness, &flush_oracle_ids)?;
187	// there are no oracle ids associated with these flush_witnesses
188	let flush_witnesses = make_fast_masked_flush_witnesses(
189		&oracles,
190		&witness,
191		&flush_oracle_ids,
192		Some(&flush_selectors),
193	)?;
194
195	// This is important to do in parallel.
196	let flush_prodcheck_witnesses = flush_witnesses
197		.into_par_iter()
198		.map(GrandProductWitness::new)
199		.collect::<Result<Vec<_>, _>>()?;
200	let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);
201
202	transcript.message().write_scalar_slice(&flush_products);
203
204	let flush_prodcheck_claims =
205		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
206
207	// Prove grand products
208	let all_gpa_witnesses = [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat();
209	let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
210		.map(|claim| claim.isomorphic())
211		.collect::<Vec<_>>();
212
213	let GrandProductBatchProveOutput { final_layer_claims } =
214		gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
215			EvaluationOrder::LowToHigh,
216			all_gpa_witnesses,
217			&all_gpa_claims,
218			&fast_domain_factory,
219			&mut transcript,
220			backend,
221		)?;
222
223	// Apply isomorphism to the layer claims
224	let mut final_layer_claims = final_layer_claims
225		.into_iter()
226		.map(|layer_claim| layer_claim.isomorphic())
227		.collect::<Vec<_>>();
228
229	let non_zero_final_layer_claims = final_layer_claims.split_off(flush_oracle_ids.len());
230	let flush_final_layer_claims = final_layer_claims;
231
232	// Reduce non_zero_final_layer_claims to evalcheck claims
233	let non_zero_prodcheck_eval_claims =
234		gkr_gpa::make_eval_claims(non_zero_oracle_ids, non_zero_final_layer_claims)?;
235
236	// Reduce flush_final_layer_claims to sumcheck claims then evalcheck claims
237	let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) =
238		reorder_for_flushing_by_n_vars(
239			&oracles,
240			&flush_oracle_ids,
241			flush_selectors,
242			flush_final_layer_claims,
243		);
244
245	let FlushSumcheckProvers {
246		provers,
247		flush_selectors_unique_by_claim,
248		flush_oracle_ids_by_claim,
249	} = get_flush_sumcheck_provers::<_, _, FDomain<Tower>, _, _>(
250		&mut oracles,
251		&flush_oracle_ids,
252		&flush_selectors,
253		&flush_final_layer_claims,
254		&mut witness,
255		&domain_factory,
256		backend,
257	)?;
258
259	let flush_sumcheck_output = sumcheck::prove::batch_prove(provers, &mut transcript)?;
260
261	let flush_eval_claims = get_post_flush_sumcheck_eval_claims_without_eq(
262		&oracles,
263		&flush_selectors_unique_by_claim,
264		&flush_oracle_ids_by_claim,
265		&flush_sumcheck_output,
266	)?;
267
268	// Zerocheck
269	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
270		.iter()
271		.cloned()
272		.map(constraint_set_zerocheck_claim)
273		.collect::<Result<Vec<_>, _>>()?
274		.into_iter()
275		.unzip::<_, _, Vec<_>, Vec<_>>();
276
277	let (max_n_vars, skip_rounds) =
278		max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);
279
280	let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);
281
282	let switchover_fn = standard_switchover_heuristic(-2);
283
284	let mut univariate_provers = Vec::new();
285	let mut tail_regular_zerocheck_provers = Vec::new();
286	let mut univariatized_multilinears = Vec::new();
287
288	for constraint_set in table_constraints {
289		let skip_challenges = (max_n_vars - constraint_set.n_vars).saturating_sub(skip_rounds);
290		let univariate_decider = |n_vars| n_vars > max_n_vars - skip_rounds;
291
292		let (constraints, multilinears) =
293			sumcheck::prove::split_constraint_set(constraint_set, &witness)?;
294
295		let base_tower_level = chain!(
296			multilinears
297				.iter()
298				.map(|multilinear| 7 - multilinear.log_extension_degree()),
299			constraints
300				.iter()
301				.map(|constraint| constraint.composition.binary_tower_level())
302		)
303		.max()
304		.unwrap_or(0);
305
306		univariatized_multilinears.push(multilinears.clone());
307
308		let constructor =
309			ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _, _> {
310				constraints,
311				multilinears,
312				domain_factory: &domain_factory,
313				switchover_fn,
314				zerocheck_challenges: &zerocheck_challenges[skip_challenges..],
315				backend,
316				_fdomain_marker: PhantomData,
317			};
318
319		let either_prover = match base_tower_level {
320			0..=3 => constructor.create::<Tower::B8>(univariate_decider)?,
321			4 => constructor.create::<Tower::B16>(univariate_decider)?,
322			5 => constructor.create::<Tower::B32>(univariate_decider)?,
323			6 => constructor.create::<Tower::B64>(univariate_decider)?,
324			7 => constructor.create::<Tower::B128>(univariate_decider)?,
325			_ => unreachable!(),
326		};
327
328		match either_prover {
329			Either::Left(univariate_prover) => univariate_provers.push(univariate_prover),
330			Either::Right(zerocheck_prover) => {
331				tail_regular_zerocheck_provers.push(zerocheck_prover)
332			}
333		}
334	}
335
336	let univariate_cnt = univariate_provers.len();
337
338	let univariate_output = sumcheck::prove::batch_prove_zerocheck_univariate_round(
339		univariate_provers,
340		skip_rounds,
341		&mut transcript,
342	)?;
343
344	let univariate_challenge = univariate_output.univariate_challenge;
345
346	let sumcheck_output = sumcheck::prove::batch_prove_with_start(
347		univariate_output.batch_prove_start,
348		tail_regular_zerocheck_provers,
349		&mut transcript,
350	)?;
351
352	let zerocheck_output = zerocheck::verify_sumcheck_outputs(
353		&zerocheck_claims,
354		&zerocheck_challenges,
355		sumcheck_output,
356	)?;
357
358	let mut reduction_claims = Vec::with_capacity(univariate_cnt);
359	let mut reduction_provers = Vec::with_capacity(univariate_cnt);
360
361	for (univariatized_multilinear_evals, multilinears) in
362		izip!(&zerocheck_output.multilinear_evals, univariatized_multilinears)
363	{
364		let claim_n_vars = multilinears
365			.first()
366			.map_or(0, |multilinear| multilinear.n_vars());
367
368		let skip_challenges = (max_n_vars - claim_n_vars).saturating_sub(skip_rounds);
369		let challenges = &zerocheck_output.challenges[skip_challenges..];
370		let reduced_multilinears =
371			sumcheck::prove::reduce_to_skipped_projection(multilinears, challenges, backend)?;
372
373		let claim_skip_rounds = claim_n_vars - challenges.len();
374		let reduction_claim = sumcheck::univariate::univariatizing_reduction_claim(
375			claim_skip_rounds,
376			univariatized_multilinear_evals,
377		)?;
378
379		let reduction_prover =
380			sumcheck::prove::univariatizing_reduction_prover::<_, FDomain<Tower>, _, _>(
381				reduced_multilinears,
382				univariatized_multilinear_evals,
383				univariate_challenge,
384				&domain_factory,
385				backend,
386			)?;
387
388		reduction_claims.push(reduction_claim);
389		reduction_provers.push(reduction_prover);
390	}
391
392	let univariatizing_output = sumcheck::prove::batch_prove(reduction_provers, &mut transcript)?;
393
394	let multilinear_zerocheck_output = sumcheck::univariate::verify_sumcheck_outputs(
395		&reduction_claims,
396		univariate_challenge,
397		&zerocheck_output.challenges,
398		univariatizing_output,
399	)?;
400
401	let zerocheck_eval_claims =
402		sumcheck::make_eval_claims(zerocheck_oracle_metas, multilinear_zerocheck_output)?;
403
404	// Prove evaluation claims
405	let eval_claims = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
406		&mut oracles,
407		&mut witness,
408		[non_zero_prodcheck_eval_claims, flush_eval_claims]
409			.concat()
410			.into_iter()
411			.chain(zerocheck_eval_claims),
412		switchover_fn,
413		&mut transcript,
414		&domain_factory,
415		backend,
416	)?;
417
418	// Reduce committed evaluation claims to PIOP sumcheck claims
419	let system = ring_switch::EvalClaimSystem::new(
420		&oracles,
421		&commit_meta,
422		&oracle_to_commit_index,
423		&eval_claims,
424	)?;
425
426	let ring_switch::ReducedWitness {
427		transparents: transparent_multilins,
428		sumcheck_claims: piop_sumcheck_claims,
429	} = ring_switch::prove::<_, _, _, Tower, _, _>(
430		&system,
431		&committed_multilins,
432		&mut transcript,
433		backend,
434	)?;
435
436	// Prove evaluation claims using PIOP compiler
437	piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _>(
438		&fri_params,
439		&merkle_prover,
440		domain_factory,
441		&commit_meta,
442		committed,
443		&codeword,
444		&committed_multilins,
445		&transparent_multilins,
446		&piop_sumcheck_claims,
447		&mut transcript,
448		&backend,
449	)?;
450
451	Ok(Proof {
452		transcript: transcript.finalize(),
453	})
454}
455
456type TypeErasedUnivariateZerocheck<'a, F> = Box<dyn UnivariateZerocheckProver<'a, F> + 'a>;
457type TypeErasedSumcheck<'a, F> = Box<dyn SumcheckProver<F> + 'a>;
458type TypeErasedProver<'a, F> =
459	Either<TypeErasedUnivariateZerocheck<'a, F>, TypeErasedSumcheck<'a, F>>;
460
461struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
462where
463	P: PackedField,
464{
465	constraints: Vec<Constraint<P::Scalar>>,
466	multilinears: Vec<MultilinearWitness<'a, P>>,
467	domain_factory: DomainFactory,
468	switchover_fn: SwitchoverFn,
469	zerocheck_challenges: &'a [P::Scalar],
470	backend: &'a Backend,
471	_fdomain_marker: PhantomData<FDomain>,
472}
473
474impl<'a, P, F, FDomain, DomainFactory, SwitchoverFn, Backend>
475	ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
476where
477	F: Field,
478	P: PackedFieldIndexable<Scalar = F>,
479	FDomain: TowerField,
480	DomainFactory: EvaluationDomainFactory<FDomain>,
481	SwitchoverFn: Fn(usize) -> usize + Clone,
482	Backend: ComputationBackend,
483{
484	fn create<FBase>(
485		self,
486		is_univariate: impl FnOnce(usize) -> bool,
487	) -> Result<TypeErasedProver<'a, F>, Error>
488	where
489		FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
490		P: PackedExtension<F, PackedSubfield = P>
491			+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
492			+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
493		F: TowerField,
494	{
495		let univariate_prover =
496			sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _>(
497				self.constraints,
498				self.multilinears,
499				self.domain_factory,
500				self.switchover_fn,
501				self.zerocheck_challenges,
502				self.backend,
503			)?;
504
505		let type_erased_prover = if is_univariate(univariate_prover.n_vars()) {
506			let type_erased_univariate_prover =
507				Box::new(univariate_prover) as TypeErasedUnivariateZerocheck<'a, P::Scalar>;
508
509			Either::Left(type_erased_univariate_prover)
510		} else {
511			let zerocheck_prover = univariate_prover.into_regular_zerocheck()?;
512			let type_erased_zerocheck_prover =
513				Box::new(zerocheck_prover) as TypeErasedSumcheck<'a, P::Scalar>;
514
515			Either::Right(type_erased_zerocheck_prover)
516		};
517
518		Ok(type_erased_prover)
519	}
520}
521
522#[instrument(skip_all, level = "debug")]
523fn make_unmasked_flush_witnesses<'a, U, Tower>(
524	oracles: &MultilinearOracleSet<FExt<Tower>>,
525	witness: &mut MultilinearExtensionIndex<'a, U, FExt<Tower>>,
526	flush_oracle_ids: &[OracleId],
527) -> Result<(), Error>
528where
529	U: ProverTowerUnderlier<Tower>,
530	Tower: ProverTowerFamily,
531{
532	// The function is on the critical path, parallelize.
533	let flush_witnesses: Result<Vec<MultilinearWitness<'a, _>>, Error> = flush_oracle_ids
534		.par_iter()
535		.map(|&oracle_id| {
536			let MultilinearPolyVariant::LinearCombination(lincom) =
537				oracles.oracle(oracle_id).variant
538			else {
539				unreachable!("make_flush_oracles adds linear combination oracles");
540			};
541			let polys = lincom
542				.polys()
543				.map(|id| witness.get_multilin_poly(id))
544				.collect::<Result<Vec<_>, _>>()?;
545
546			let packed_len = 1
547				<< lincom
548					.n_vars()
549					.saturating_sub(<PackedType<U, FExt<Tower>>>::LOG_WIDTH);
550			let data = (0..packed_len)
551				.into_par_iter()
552				.map(|i| {
553					<PackedType<U, FExt<Tower>>>::from_fn(|j| {
554						let index = i << <PackedType<U, FExt<Tower>>>::LOG_WIDTH | j;
555						polys.iter().zip(lincom.coefficients()).fold(
556							lincom.offset(),
557							|sum, (poly, coeff)| {
558								sum + poly
559									.evaluate_on_hypercube_and_scale(index, coeff)
560									.unwrap_or(<FExt<Tower>>::ZERO)
561							},
562						)
563					})
564				})
565				.collect::<Vec<_>>();
566			let lincom_poly = MultilinearExtension::new(lincom.n_vars(), data)
567				.expect("data is constructed with the correct length with respect to n_vars");
568
569			Ok(MLEDirectAdapter::from(lincom_poly).upcast_arc_dyn())
570		})
571		.collect();
572
573	witness.update_multilin_poly(izip!(flush_oracle_ids.iter().copied(), flush_witnesses?))?;
574	Ok(())
575}
576
577#[allow(clippy::type_complexity)]
578#[instrument(skip_all, level = "debug")]
579fn make_fast_masked_flush_witnesses<'a, U, Tower>(
580	oracles: &MultilinearOracleSet<FExt<Tower>>,
581	witness: &MultilinearExtensionIndex<'a, U, FExt<Tower>>,
582	flush_oracles: &[OracleId],
583	flush_selectors: Option<&[OracleId]>,
584) -> Result<Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>, Error>
585where
586	U: ProverTowerUnderlier<Tower>,
587	Tower: ProverTowerFamily,
588	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
589{
590	let to_fast = Tower::packed_transformation_to_fast();
591
592	// The function is on the critical path, parallelize.
593	flush_oracles
594		.par_iter()
595		.enumerate()
596		.map(|(i, &flush_oracle_id)| {
597			let n_vars = oracles.n_vars(flush_oracle_id);
598
599			let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
600			let width = 1 << log_width;
601
602			let packed_len = 1 << n_vars.saturating_sub(log_width);
603			let mut fast_ext_result = vec![PackedType::<U, FFastExt<Tower>>::one(); packed_len];
604
605			let poly = witness.get_multilin_poly(flush_oracle_id)?;
606			let selector = flush_selectors
607				.map(|flush_selectors| witness.get_multilin_poly(flush_selectors[i]))
608				.transpose()?;
609
610			const MAX_SUBCUBE_VARS: usize = 8;
611			let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
612			let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
613
614			fast_ext_result
615				.par_chunks_mut(subcube_packed_size)
616				.enumerate()
617				.for_each(|(subcube_index, fast_subcube)| {
618					let underliers =
619						PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
620
621					let subcube_evals =
622						PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
623					poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
624						.expect("witness data populated by make_unmasked_flush_witnesses()");
625
626					for underlier in underliers.iter_mut() {
627						let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
628						let dest = to_fast.transform(&src);
629						*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
630					}
631
632					if let Some(selector) = &selector {
633						let fast_subcube =
634							PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
635
636						let mut ones_mask = PackedType::<U, FExt<Tower>>::default();
637						for (i, packed) in fast_subcube.iter_mut().enumerate() {
638							selector
639								.subcube_evals(
640									log_width,
641									(subcube_index << subcube_vars.saturating_sub(log_width)) | i,
642									0,
643									from_mut(&mut ones_mask),
644								)
645								.expect("selector n_vars equals flushed n_vars");
646
647							if ones_mask == PackedField::zero() {
648								*packed = PackedField::one();
649							} else if ones_mask != PackedField::one() {
650								for j in 0..width {
651									if ones_mask.get(j) == FExt::<Tower>::ZERO {
652										packed.set(j, FFastExt::<Tower>::ONE);
653									}
654								}
655							}
656						}
657					}
658				});
659
660			let masked_poly = MultilinearExtension::new(n_vars, fast_ext_result)
661				.expect("data is constructed with the correct length with respect to n_vars");
662			Ok(MLEDirectAdapter::from(masked_poly).upcast_arc_dyn())
663		})
664		.collect()
665}
666
667pub struct FlushSumcheckProvers<Prover> {
668	provers: Vec<Prover>,
669	flush_oracle_ids_by_claim: Vec<Vec<OracleId>>,
670	flush_selectors_unique_by_claim: Vec<Vec<OracleId>>,
671}
672
673#[instrument(skip_all, level = "debug")]
674fn get_flush_sumcheck_provers<'a, 'b, U, Tower, FDomain, DomainFactory, Backend>(
675	oracles: &mut MultilinearOracleSet<Tower::B128>,
676	flush_oracle_ids: &[OracleId],
677	flush_selectors: &[OracleId],
678	final_layer_claims: &[LayerClaim<Tower::B128>],
679	witness: &mut MultilinearExtensionIndex<'a, U, Tower::B128>,
680	domain_factory: DomainFactory,
681	backend: &'b Backend,
682) -> Result<FlushSumcheckProvers<impl SumcheckProver<Tower::B128> + 'b>, Error>
683where
684	U: ProverTowerUnderlier<Tower> + PackScalar<FDomain>,
685	Tower: ProverTowerFamily,
686	Tower::B128: ExtensionField<FDomain>,
687	FDomain: Field,
688	DomainFactory: EvaluationDomainFactory<FDomain>,
689	Backend: ComputationBackend,
690	PackedType<U, Tower::B128>: PackedFieldIndexable,
691	'a: 'b,
692{
693	let flush_sumcheck_metas = get_flush_dedup_sumcheck_metas(
694		oracles,
695		flush_oracle_ids,
696		flush_selectors,
697		final_layer_claims,
698	)?;
699
700	let n_claims = flush_sumcheck_metas.len();
701	let mut provers = Vec::with_capacity(n_claims);
702	let mut flush_oracle_ids_by_claim = Vec::with_capacity(n_claims);
703	let mut flush_selectors_unique_by_claim = Vec::with_capacity(n_claims);
704	for flush_sumcheck_meta in flush_sumcheck_metas {
705		let FlushSumcheckMeta {
706			composite_sum_claims,
707			flush_selectors_unique,
708			flush_oracle_ids,
709			eval_point,
710		} = flush_sumcheck_meta;
711
712		let mut multilinears =
713			Vec::with_capacity(flush_selectors_unique.len() + flush_oracle_ids.len());
714
715		for &flush_selector in &flush_selectors_unique {
716			multilinears.push(witness.get_multilin_poly(flush_selector)?);
717		}
718
719		for &oracle_id in &flush_oracle_ids {
720			multilinears.push(witness.get_multilin_poly(oracle_id)?);
721		}
722
723		let prover = GPAProver::new(
724			EvaluationOrder::LowToHigh,
725			multilinears,
726			None,
727			composite_sum_claims,
728			domain_factory.clone(),
729			&eval_point,
730			backend,
731		)?;
732
733		provers.push(prover);
734		flush_oracle_ids_by_claim.push(flush_oracle_ids);
735		flush_selectors_unique_by_claim.push(flush_selectors_unique);
736	}
737
738	Ok(FlushSumcheckProvers {
739		provers,
740		flush_selectors_unique_by_claim,
741		flush_oracle_ids_by_claim,
742	})
743}