// File: binius_core/constraint_system/prove.rs

// Copyright 2024-2025 Irreducible Inc.

3use std::{cmp::Reverse, env, marker::PhantomData, slice::from_mut};
4
5use binius_field::{
6	as_packed_field::{PackScalar, PackedType},
7	linear_transformation::{PackedTransformationFactory, Transformation},
8	underlier::WithUnderlier,
9	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
10	RepackedExtension, TowerField,
11};
12use binius_hal::ComputationBackend;
13use binius_hash::PseudoCompressionFunction;
14use binius_math::{
15	DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
16	IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
17};
18use binius_maybe_rayon::prelude::*;
19use binius_utils::bail;
20use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, Output};
21use either::Either;
22use itertools::{chain, izip};
23use tracing::instrument;
24
25use super::{
26	channel::Boundary,
27	error::Error,
28	verify::{
29		get_post_flush_sumcheck_eval_claims_without_eq, make_flush_oracles,
30		max_n_vars_and_skip_rounds, reorder_for_flushing_by_n_vars,
31	},
32	ConstraintSystem, Proof,
33};
34use crate::{
35	constraint_system::{
36		common::{FDomain, FEncode, FExt, FFastExt},
37		exp,
38		verify::{get_flush_dedup_sumcheck_metas, FlushSumcheckMeta},
39	},
40	fiat_shamir::{CanSample, Challenger},
41	merkle_tree::BinaryMerkleTreeProver,
42	oracle::{Constraint, MultilinearOracleSet, MultilinearPolyVariant, OracleId},
43	piop,
44	protocols::{
45		fri::CommitOutput,
46		gkr_exp,
47		gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness, LayerClaim},
48		greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
49		sumcheck::{
50			self, constraint_set_zerocheck_claim, immediate_switchover_heuristic,
51			prove::{
52				eq_ind::EqIndSumcheckProverBuilder, SumcheckProver, UnivariateZerocheckProver,
53			},
54			standard_switchover_heuristic, zerocheck,
55		},
56	},
57	ring_switch,
58	tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
59	transcript::ProverTranscript,
60	witness::{MultilinearExtensionIndex, MultilinearWitness},
61};
62
63/// Generates a proof that a witness satisfies a constraint system with the standard FRI PCS.
64#[instrument("constraint_system::prove", skip_all, level = "debug")]
65pub fn prove<U, Tower, Hash, Compress, Challenger_, Backend>(
66	constraint_system: &ConstraintSystem<FExt<Tower>>,
67	log_inv_rate: usize,
68	security_bits: usize,
69	boundaries: &[Boundary<FExt<Tower>>],
70	mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
71	backend: &Backend,
72) -> Result<Proof, Error>
73where
74	U: ProverTowerUnderlier<Tower>,
75	Tower: ProverTowerFamily,
76	Tower::B128: PackedTop<Tower>,
77	Hash: Digest + BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
78	Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
79	Challenger_: Challenger + Default,
80	Backend: ComputationBackend,
81	// REVIEW: Consider changing TowerFamily and associated traits to shorten/remove these bounds
82	PackedType<U, Tower::B128>: PackedTop<Tower>
83		+ PackedFieldIndexable
84		+ RepackedExtension<PackedType<U, Tower::B8>>
85		+ RepackedExtension<PackedType<U, Tower::B16>>
86		+ RepackedExtension<PackedType<U, Tower::B32>>
87		+ RepackedExtension<PackedType<U, Tower::B64>>
88		+ RepackedExtension<PackedType<U, Tower::B128>>
89		+ PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
90	PackedType<U, Tower::FastB128>:
91		PackedFieldIndexable + PackedTransformationFactory<PackedType<U, Tower::B128>>,
92	PackedType<U, Tower::B8>: PackedFieldIndexable,
93	PackedType<U, Tower::B16>: PackedFieldIndexable,
94	PackedType<U, Tower::B32>: PackedFieldIndexable,
95	PackedType<U, Tower::B64>: PackedFieldIndexable,
96{
97	tracing::debug!(
98		arch = env::consts::ARCH,
99		rayon_threads = binius_maybe_rayon::current_num_threads(),
100		"using computation backend: {backend:?}"
101	);
102
103	let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
104	let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();
105
106	let mut transcript = ProverTranscript::<Challenger_>::new();
107	transcript.observe().write_slice(boundaries);
108
109	let ConstraintSystem {
110		mut oracles,
111		mut table_constraints,
112		mut flushes,
113		mut exponents,
114		non_zero_oracle_ids,
115		max_channel_id,
116	} = constraint_system.clone();
117
118	exponents.sort_by_key(|b| std::cmp::Reverse(b.n_vars(&oracles)));
119
120	// We must generate multiplication witnesses before committing, as this function
121	// adds the committed witnesses for exponentiation results to the witness index.
122	let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;
123
124	// Stable sort constraint sets in descending order by number of variables.
125	table_constraints.sort_by_key(|constraint_set| Reverse(constraint_set.n_vars));
126
127	// Commit polynomials
128	let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
129	let merkle_scheme = merkle_prover.scheme();
130
131	let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
132	let committed_multilins = piop::collect_committed_witnesses::<U, _>(
133		&commit_meta,
134		&oracle_to_commit_index,
135		&oracles,
136		&witness,
137	)?;
138
139	let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
140		&commit_meta,
141		merkle_scheme,
142		security_bits,
143		log_inv_rate,
144	)?;
145	let CommitOutput {
146		commitment,
147		committed,
148		codeword,
149	} = piop::commit(&fri_params, &merkle_prover, &committed_multilins)?;
150
151	// Observe polynomial commitment
152	let mut writer = transcript.message();
153	writer.write(&commitment);
154
155	// GKR exp
156	let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
157
158	let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
159		.into_iter()
160		.map(|x| x.into())
161		.collect::<Vec<_>>();
162
163	let mut writer = transcript.message();
164	writer.write_scalar_slice(&exp_evals);
165
166	let exp_challenge = exp_challenge
167		.into_iter()
168		.map(|x| x.into())
169		.collect::<Vec<_>>();
170
171	let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
172		.into_iter()
173		.map(|claim| claim.isomorphic())
174		.collect::<Vec<_>>();
175
176	let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
177		EvaluationOrder::HighToLow,
178		exp_witnesses,
179		&exp_claims,
180		fast_domain_factory.clone(),
181		&mut transcript,
182		backend,
183	)?
184	.isomorphic();
185
186	let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
187
188	// Grand product arguments
189	// Grand products for non-zero checking
190	let non_zero_fast_witnesses =
191		make_fast_masked_flush_witnesses::<U, _>(&oracles, &witness, &non_zero_oracle_ids, None)?;
192	let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
193		.into_par_iter()
194		.map(GrandProductWitness::new)
195		.collect::<Result<Vec<_>, _>>()?;
196
197	let non_zero_products =
198		gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
199	if non_zero_products
200		.iter()
201		.any(|count| *count == Tower::B128::zero())
202	{
203		bail!(Error::Zeros);
204	}
205
206	let mut writer = transcript.message();
207
208	writer.write_scalar_slice(&non_zero_products);
209
210	let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
211		&non_zero_oracle_ids,
212		&oracles,
213		&non_zero_products,
214	)?;
215
216	// Grand products for flushing
217	let mixing_challenge = transcript.sample();
218	let permutation_challenges = transcript.sample_vec(max_channel_id + 1);
219
220	flushes.sort_by_key(|flush| flush.channel_id);
221	let flush_oracle_ids =
222		make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
223	let flush_selectors = flushes
224		.iter()
225		.map(|flush| flush.selector)
226		.collect::<Vec<_>>();
227
228	make_unmasked_flush_witnesses::<U, _>(&oracles, &mut witness, &flush_oracle_ids)?;
229	// there are no oracle ids associated with these flush_witnesses
230	let flush_witnesses = make_fast_masked_flush_witnesses::<U, _>(
231		&oracles,
232		&witness,
233		&flush_oracle_ids,
234		Some(&flush_selectors),
235	)?;
236
237	// This is important to do in parallel.
238	let flush_prodcheck_witnesses = flush_witnesses
239		.into_par_iter()
240		.map(GrandProductWitness::new)
241		.collect::<Result<Vec<_>, _>>()?;
242	let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);
243
244	transcript.message().write_scalar_slice(&flush_products);
245
246	let flush_prodcheck_claims =
247		gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
248
249	// Prove grand products
250	let all_gpa_witnesses = [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat();
251	let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
252		.map(|claim| claim.isomorphic())
253		.collect::<Vec<_>>();
254
255	let GrandProductBatchProveOutput { final_layer_claims } =
256		gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
257			EvaluationOrder::LowToHigh,
258			all_gpa_witnesses,
259			&all_gpa_claims,
260			&fast_domain_factory,
261			&mut transcript,
262			backend,
263		)?;
264
265	// Apply isomorphism to the layer claims
266	let mut final_layer_claims = final_layer_claims
267		.into_iter()
268		.map(|layer_claim| layer_claim.isomorphic())
269		.collect::<Vec<_>>();
270
271	let non_zero_final_layer_claims = final_layer_claims.split_off(flush_oracle_ids.len());
272	let flush_final_layer_claims = final_layer_claims;
273
274	// Reduce non_zero_final_layer_claims to evalcheck claims
275	let non_zero_prodcheck_eval_claims =
276		gkr_gpa::make_eval_claims(non_zero_oracle_ids, non_zero_final_layer_claims)?;
277
278	// Reduce flush_final_layer_claims to sumcheck claims then evalcheck claims
279	let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) =
280		reorder_for_flushing_by_n_vars(
281			&oracles,
282			&flush_oracle_ids,
283			flush_selectors,
284			flush_final_layer_claims,
285		);
286
287	let FlushSumcheckProvers {
288		provers,
289		flush_selectors_unique_by_claim,
290		flush_oracle_ids_by_claim,
291	} = get_flush_sumcheck_provers::<U, _, FDomain<Tower>, _, _>(
292		&mut oracles,
293		&flush_oracle_ids,
294		&flush_selectors,
295		&flush_final_layer_claims,
296		&mut witness,
297		&domain_factory,
298		backend,
299	)?;
300
301	let flush_sumcheck_output = sumcheck::prove::batch_prove(provers, &mut transcript)?;
302
303	let flush_eval_claims = get_post_flush_sumcheck_eval_claims_without_eq(
304		&oracles,
305		&flush_selectors_unique_by_claim,
306		&flush_oracle_ids_by_claim,
307		&flush_sumcheck_output,
308	)?;
309
310	// Zerocheck
311	let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
312		.iter()
313		.cloned()
314		.map(constraint_set_zerocheck_claim)
315		.collect::<Result<Vec<_>, _>>()?
316		.into_iter()
317		.unzip::<_, _, Vec<_>, Vec<_>>();
318
319	let eq_ind_sumcheck_claims = zerocheck::reduce_to_eq_ind_sumchecks(&zerocheck_claims)?;
320
321	let (max_n_vars, skip_rounds) =
322		max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);
323
324	let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);
325
326	let switchover_fn = standard_switchover_heuristic(-2);
327
328	let mut univariate_provers = Vec::new();
329	let mut tail_regular_zerocheck_provers = Vec::new();
330	let mut univariatized_multilinears = Vec::new();
331
332	for constraint_set in table_constraints {
333		let skip_challenges = (max_n_vars - constraint_set.n_vars).saturating_sub(skip_rounds);
334		let univariate_decider = |n_vars| n_vars > max_n_vars - skip_rounds;
335
336		let (constraints, multilinears) =
337			sumcheck::prove::split_constraint_set(constraint_set, &witness)?;
338
339		let base_tower_level = chain!(
340			multilinears
341				.iter()
342				.map(|multilinear| 7 - multilinear.log_extension_degree()),
343			constraints
344				.iter()
345				.map(|constraint| constraint.composition.binary_tower_level())
346		)
347		.max()
348		.unwrap_or(0);
349
350		univariatized_multilinears.push(multilinears.clone());
351
352		let constructor =
353			ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _, _> {
354				constraints,
355				multilinears,
356				domain_factory: domain_factory.clone(),
357				switchover_fn,
358				zerocheck_challenges: &zerocheck_challenges[skip_challenges..],
359				backend,
360				_fdomain_marker: PhantomData,
361			};
362
363		let either_prover = match base_tower_level {
364			0..=3 => constructor.create::<Tower::B8>(univariate_decider)?,
365			4 => constructor.create::<Tower::B16>(univariate_decider)?,
366			5 => constructor.create::<Tower::B32>(univariate_decider)?,
367			6 => constructor.create::<Tower::B64>(univariate_decider)?,
368			7 => constructor.create::<Tower::B128>(univariate_decider)?,
369			_ => unreachable!(),
370		};
371
372		match either_prover {
373			Either::Left(univariate_prover) => univariate_provers.push(univariate_prover),
374			Either::Right(zerocheck_prover) => {
375				tail_regular_zerocheck_provers.push(zerocheck_prover)
376			}
377		}
378	}
379
380	let univariate_cnt = univariate_provers.len();
381
382	let univariate_output = sumcheck::prove::batch_prove_zerocheck_univariate_round(
383		univariate_provers,
384		skip_rounds,
385		&mut transcript,
386	)?;
387
388	let univariate_challenge = univariate_output.univariate_challenge;
389
390	let sumcheck_output = sumcheck::prove::batch_prove_with_start(
391		univariate_output.batch_prove_start,
392		tail_regular_zerocheck_provers,
393		&mut transcript,
394	)?;
395
396	let zerocheck_output = sumcheck::eq_ind::verify_sumcheck_outputs(
397		&eq_ind_sumcheck_claims,
398		&zerocheck_challenges,
399		sumcheck_output,
400	)?;
401
402	let mut reduction_claims = Vec::with_capacity(univariate_cnt);
403	let mut reduction_provers = Vec::with_capacity(univariate_cnt);
404
405	for (univariatized_multilinear_evals, multilinears) in
406		izip!(&zerocheck_output.multilinear_evals, univariatized_multilinears)
407	{
408		let claim_n_vars = multilinears
409			.first()
410			.map_or(0, |multilinear| multilinear.n_vars());
411
412		let skip_challenges = (max_n_vars - claim_n_vars).saturating_sub(skip_rounds);
413		let challenges = &zerocheck_output.challenges[skip_challenges..];
414		let reduced_multilinears =
415			sumcheck::prove::reduce_to_skipped_projection(multilinears, challenges, backend)?;
416
417		let claim_skip_rounds = claim_n_vars - challenges.len();
418		let reduction_claim = sumcheck::univariate::univariatizing_reduction_claim(
419			claim_skip_rounds,
420			univariatized_multilinear_evals,
421		)?;
422
423		let reduction_prover =
424			sumcheck::prove::univariatizing_reduction_prover::<_, FDomain<Tower>, _, _>(
425				reduced_multilinears,
426				univariatized_multilinear_evals,
427				univariate_challenge,
428				backend,
429			)?;
430
431		reduction_claims.push(reduction_claim);
432		reduction_provers.push(reduction_prover);
433	}
434
435	let univariatizing_output = sumcheck::prove::batch_prove(reduction_provers, &mut transcript)?;
436
437	let multilinear_zerocheck_output = sumcheck::univariate::verify_sumcheck_outputs(
438		&reduction_claims,
439		univariate_challenge,
440		&zerocheck_output.challenges,
441		univariatizing_output,
442	)?;
443
444	let zerocheck_eval_claims = sumcheck::make_eval_claims(
445		EvaluationOrder::LowToHigh,
446		zerocheck_oracle_metas,
447		multilinear_zerocheck_output,
448	)?;
449
450	// Prove evaluation claims
451	let GreedyEvalcheckProveOutput {
452		eval_claims,
453		memoized_data,
454	} = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
455		&mut oracles,
456		&mut witness,
457		[non_zero_prodcheck_eval_claims, flush_eval_claims]
458			.concat()
459			.into_iter()
460			.chain(zerocheck_eval_claims)
461			.chain(exp_eval_claims),
462		switchover_fn,
463		&mut transcript,
464		&domain_factory,
465		backend,
466	)?;
467
468	// Reduce committed evaluation claims to PIOP sumcheck claims
469	let system = ring_switch::EvalClaimSystem::new(
470		&oracles,
471		&commit_meta,
472		&oracle_to_commit_index,
473		&eval_claims,
474	)?;
475
476	let ring_switch::ReducedWitness {
477		transparents: transparent_multilins,
478		sumcheck_claims: piop_sumcheck_claims,
479	} = ring_switch::prove::<_, _, _, Tower, _, _>(
480		&system,
481		&committed_multilins,
482		&mut transcript,
483		memoized_data,
484		backend,
485	)?;
486
487	// Prove evaluation claims using PIOP compiler
488	piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _>(
489		&fri_params,
490		&merkle_prover,
491		domain_factory,
492		&commit_meta,
493		committed,
494		&codeword,
495		&committed_multilins,
496		&transparent_multilins,
497		&piop_sumcheck_claims,
498		&mut transcript,
499		&backend,
500	)?;
501
502	Ok(Proof {
503		transcript: transcript.finalize(),
504	})
505}
506
507type TypeErasedUnivariateZerocheck<'a, F> = Box<dyn UnivariateZerocheckProver<'a, F> + 'a>;
508type TypeErasedSumcheck<'a, F> = Box<dyn SumcheckProver<F> + 'a>;
509type TypeErasedProver<'a, F> =
510	Either<TypeErasedUnivariateZerocheck<'a, F>, TypeErasedSumcheck<'a, F>>;
511
512struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
513where
514	P: PackedField,
515{
516	constraints: Vec<Constraint<P::Scalar>>,
517	multilinears: Vec<MultilinearWitness<'a, P>>,
518	domain_factory: DomainFactory,
519	switchover_fn: SwitchoverFn,
520	zerocheck_challenges: &'a [P::Scalar],
521	backend: &'a Backend,
522	_fdomain_marker: PhantomData<FDomain>,
523}
524
525impl<'a, P, F, FDomain, DomainFactory, SwitchoverFn, Backend>
526	ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, SwitchoverFn, Backend>
527where
528	F: Field,
529	P: PackedFieldIndexable<Scalar = F>,
530	FDomain: TowerField,
531	DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
532	SwitchoverFn: Fn(usize) -> usize + Clone + 'a,
533	Backend: ComputationBackend,
534{
535	fn create<FBase>(
536		self,
537		is_univariate: impl FnOnce(usize) -> bool,
538	) -> Result<TypeErasedProver<'a, F>, Error>
539	where
540		FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
541		P: PackedExtension<F, PackedSubfield = P>
542			+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>
543			+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>,
544		F: TowerField,
545	{
546		let univariate_prover =
547			sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _, _>(
548				self.constraints,
549				self.multilinears,
550				self.domain_factory,
551				self.switchover_fn,
552				self.zerocheck_challenges,
553				self.backend,
554			)?;
555
556		let type_erased_prover = if is_univariate(univariate_prover.n_vars()) {
557			let type_erased_univariate_prover =
558				Box::new(univariate_prover) as TypeErasedUnivariateZerocheck<'a, P::Scalar>;
559
560			Either::Left(type_erased_univariate_prover)
561		} else {
562			let zerocheck_prover = univariate_prover.into_regular_zerocheck()?;
563			let type_erased_zerocheck_prover =
564				Box::new(zerocheck_prover) as TypeErasedSumcheck<'a, P::Scalar>;
565
566			Either::Right(type_erased_zerocheck_prover)
567		};
568
569		Ok(type_erased_prover)
570	}
571}
572
573#[instrument(skip_all, level = "debug")]
574fn make_unmasked_flush_witnesses<'a, U, Tower>(
575	oracles: &MultilinearOracleSet<FExt<Tower>>,
576	witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
577	flush_oracle_ids: &[OracleId],
578) -> Result<(), Error>
579where
580	U: ProverTowerUnderlier<Tower>,
581	Tower: ProverTowerFamily,
582{
583	// The function is on the critical path, parallelize.
584	let flush_witnesses: Result<Vec<MultilinearWitness<'a, _>>, Error> = flush_oracle_ids
585		.par_iter()
586		.map(|&oracle_id| {
587			let MultilinearPolyVariant::LinearCombination(lincom) =
588				oracles.oracle(oracle_id).variant
589			else {
590				unreachable!("make_flush_oracles adds linear combination oracles");
591			};
592			let polys = lincom
593				.polys()
594				.map(|id| witness.get_multilin_poly(id))
595				.collect::<Result<Vec<_>, _>>()?;
596
597			let packed_len = 1
598				<< lincom
599					.n_vars()
600					.saturating_sub(<PackedType<U, FExt<Tower>>>::LOG_WIDTH);
601			let data = (0..packed_len)
602				.into_par_iter()
603				.map(|i| {
604					<PackedType<U, FExt<Tower>>>::from_fn(|j| {
605						let index = i << <PackedType<U, FExt<Tower>>>::LOG_WIDTH | j;
606						polys.iter().zip(lincom.coefficients()).fold(
607							lincom.offset(),
608							|sum, (poly, coeff)| {
609								sum + poly
610									.evaluate_on_hypercube_and_scale(index, coeff)
611									.unwrap_or(<FExt<Tower>>::ZERO)
612							},
613						)
614					})
615				})
616				.collect::<Vec<_>>();
617			let lincom_poly = MultilinearExtension::new(lincom.n_vars(), data)
618				.expect("data is constructed with the correct length with respect to n_vars");
619
620			Ok(MLEDirectAdapter::from(lincom_poly).upcast_arc_dyn())
621		})
622		.collect();
623
624	witness.update_multilin_poly(izip!(flush_oracle_ids.iter().copied(), flush_witnesses?))?;
625	Ok(())
626}
627
628#[allow(clippy::type_complexity)]
629#[instrument(skip_all, level = "debug")]
630fn make_fast_masked_flush_witnesses<'a, U, Tower>(
631	oracles: &MultilinearOracleSet<FExt<Tower>>,
632	witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
633	flush_oracles: &[OracleId],
634	flush_selectors: Option<&[OracleId]>,
635) -> Result<Vec<MultilinearWitness<'a, PackedType<U, FFastExt<Tower>>>>, Error>
636where
637	U: ProverTowerUnderlier<Tower>,
638	Tower: ProverTowerFamily,
639	PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
640{
641	let to_fast = Tower::packed_transformation_to_fast();
642
643	// The function is on the critical path, parallelize.
644	flush_oracles
645		.par_iter()
646		.enumerate()
647		.map(|(i, &flush_oracle_id)| {
648			let n_vars = oracles.n_vars(flush_oracle_id);
649
650			let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
651			let width = 1 << log_width;
652
653			let packed_len = 1 << n_vars.saturating_sub(log_width);
654			let mut fast_ext_result = vec![PackedType::<U, FFastExt<Tower>>::one(); packed_len];
655
656			let poly = witness.get_multilin_poly(flush_oracle_id)?;
657			let selector = flush_selectors
658				.map(|flush_selectors| witness.get_multilin_poly(flush_selectors[i]))
659				.transpose()?;
660
661			const MAX_SUBCUBE_VARS: usize = 8;
662			let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
663			let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
664
665			fast_ext_result
666				.par_chunks_mut(subcube_packed_size)
667				.enumerate()
668				.for_each(|(subcube_index, fast_subcube)| {
669					let underliers =
670						PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
671
672					let subcube_evals =
673						PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
674					poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
675						.expect("witness data populated by make_unmasked_flush_witnesses()");
676
677					for underlier in underliers.iter_mut() {
678						let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
679						let dest = to_fast.transform(&src);
680						*underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
681					}
682
683					if let Some(selector) = &selector {
684						let fast_subcube =
685							PackedType::<U, FFastExt<Tower>>::from_underliers_ref_mut(underliers);
686
687						let mut ones_mask = PackedType::<U, FExt<Tower>>::default();
688						for (i, packed) in fast_subcube.iter_mut().enumerate() {
689							selector
690								.subcube_evals(
691									log_width,
692									(subcube_index << subcube_vars.saturating_sub(log_width)) | i,
693									0,
694									from_mut(&mut ones_mask),
695								)
696								.expect("selector n_vars equals flushed n_vars");
697
698							if ones_mask == PackedField::zero() {
699								*packed = PackedField::one();
700							} else if ones_mask != PackedField::one() {
701								for j in 0..width {
702									if ones_mask.get(j) == FExt::<Tower>::ZERO {
703										packed.set(j, FFastExt::<Tower>::ONE);
704									}
705								}
706							}
707						}
708					}
709				});
710
711			let masked_poly = MultilinearExtension::new(n_vars, fast_ext_result)
712				.expect("data is constructed with the correct length with respect to n_vars");
713			Ok(MLEDirectAdapter::from(masked_poly).upcast_arc_dyn())
714		})
715		.collect()
716}
717
718pub struct FlushSumcheckProvers<Prover> {
719	provers: Vec<Prover>,
720	flush_oracle_ids_by_claim: Vec<Vec<OracleId>>,
721	flush_selectors_unique_by_claim: Vec<Vec<OracleId>>,
722}
723
724#[instrument(skip_all, level = "debug")]
725fn get_flush_sumcheck_provers<'a, 'b, U, Tower, FDomain, DomainFactory, Backend>(
726	oracles: &mut MultilinearOracleSet<Tower::B128>,
727	flush_oracle_ids: &[OracleId],
728	flush_selectors: &[OracleId],
729	final_layer_claims: &[LayerClaim<Tower::B128>],
730	witness: &mut MultilinearExtensionIndex<'a, PackedType<U, Tower::B128>>,
731	domain_factory: DomainFactory,
732	backend: &'b Backend,
733) -> Result<FlushSumcheckProvers<impl SumcheckProver<Tower::B128> + 'b>, Error>
734where
735	U: ProverTowerUnderlier<Tower> + PackScalar<FDomain>,
736	Tower: ProverTowerFamily,
737	Tower::B128: ExtensionField<FDomain>,
738	FDomain: Field,
739	DomainFactory: EvaluationDomainFactory<FDomain>,
740	Backend: ComputationBackend,
741	PackedType<U, Tower::B128>: PackedFieldIndexable,
742	'a: 'b,
743{
744	let flush_sumcheck_metas = get_flush_dedup_sumcheck_metas(
745		oracles,
746		flush_oracle_ids,
747		flush_selectors,
748		final_layer_claims,
749	)?;
750
751	let n_claims = flush_sumcheck_metas.len();
752	let mut provers = Vec::with_capacity(n_claims);
753	let mut flush_oracle_ids_by_claim = Vec::with_capacity(n_claims);
754	let mut flush_selectors_unique_by_claim = Vec::with_capacity(n_claims);
755	for flush_sumcheck_meta in flush_sumcheck_metas {
756		let FlushSumcheckMeta {
757			composite_sum_claims,
758			flush_selectors_unique,
759			flush_oracle_ids,
760			eval_point,
761		} = flush_sumcheck_meta;
762
763		let mut multilinears =
764			Vec::with_capacity(flush_selectors_unique.len() + flush_oracle_ids.len());
765
766		let mut nonzero_scalars_prefixes = Vec::with_capacity(multilinears.len());
767
768		for &oracle_id in chain!(&flush_selectors_unique, &flush_oracle_ids) {
769			let entry = witness.get_index_entry(oracle_id)?;
770			multilinears.push(entry.multilin_poly);
771			nonzero_scalars_prefixes.push(entry.nonzero_scalars_prefix);
772		}
773
774		let prover = EqIndSumcheckProverBuilder::new(backend)
775			.with_nonzero_scalars_prefixes(&nonzero_scalars_prefixes)
776			.build(
777				EvaluationOrder::LowToHigh,
778				multilinears,
779				&eval_point,
780				composite_sum_claims,
781				domain_factory.clone(),
782				immediate_switchover_heuristic,
783			)?;
784
785		provers.push(prover);
786		flush_oracle_ids_by_claim.push(flush_oracle_ids);
787		flush_selectors_unique_by_claim.push(flush_selectors_unique);
788	}
789
790	Ok(FlushSumcheckProvers {
791		provers,
792		flush_selectors_unique_by_claim,
793		flush_oracle_ids_by_claim,
794	})
795}