1use std::{env, marker::PhantomData};
4
5use binius_field::{
6 as_packed_field::PackedType,
7 linear_transformation::{PackedTransformationFactory, Transformation},
8 tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier},
9 underlier::WithUnderlier,
10 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
11 RepackedExtension, TowerField,
12};
13use binius_hal::ComputationBackend;
14use binius_hash::PseudoCompressionFunction;
15use binius_math::{
16 CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
17 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
18};
19use binius_maybe_rayon::prelude::*;
20use binius_ntt::SingleThreadedNTT;
21use binius_utils::bail;
22use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, Output};
23use itertools::chain;
24use tracing::instrument;
25
26use super::{
27 channel::Boundary,
28 error::Error,
29 verify::{make_flush_oracles, max_n_vars_and_skip_rounds},
30 ConstraintSystem, Proof,
31};
32use crate::{
33 constraint_system::{
34 common::{FDomain, FEncode, FExt, FFastExt},
35 exp,
36 },
37 fiat_shamir::{CanSample, Challenger},
38 merkle_tree::BinaryMerkleTreeProver,
39 oracle::{Constraint, MultilinearOracleSet, MultilinearPolyVariant, OracleId},
40 piop,
41 protocols::{
42 fri::CommitOutput,
43 gkr_exp,
44 gkr_gpa::{self, GrandProductBatchProveOutput, GrandProductWitness},
45 greedy_evalcheck::{self, GreedyEvalcheckProveOutput},
46 sumcheck::{
47 self, constraint_set_zerocheck_claim, prove::ZerocheckProver,
48 standard_switchover_heuristic,
49 },
50 },
51 ring_switch,
52 transcript::ProverTranscript,
53 witness::{MultilinearExtensionIndex, MultilinearWitness},
54};
55
56#[instrument("constraint_system::prove", skip_all, level = "debug")]
58pub fn prove<U, Tower, Hash, Compress, Challenger_, Backend>(
59 constraint_system: &ConstraintSystem<FExt<Tower>>,
60 log_inv_rate: usize,
61 security_bits: usize,
62 boundaries: &[Boundary<FExt<Tower>>],
63 mut witness: MultilinearExtensionIndex<PackedType<U, FExt<Tower>>>,
64 backend: &Backend,
65) -> Result<Proof, Error>
66where
67 U: ProverTowerUnderlier<Tower>,
68 Tower: ProverTowerFamily,
69 Tower::B128: PackedTop<Tower>,
70 Hash: Digest + BlockSizeUser + FixedOutputReset + Send + Sync + Clone,
71 Compress: PseudoCompressionFunction<Output<Hash>, 2> + Default + Sync,
72 Challenger_: Challenger + Default,
73 Backend: ComputationBackend,
74 PackedType<U, Tower::B128>: PackedTop<Tower>
76 + PackedFieldIndexable + RepackedExtension<PackedType<U, Tower::B8>>
78 + RepackedExtension<PackedType<U, Tower::B16>>
79 + RepackedExtension<PackedType<U, Tower::B32>>
80 + RepackedExtension<PackedType<U, Tower::B64>>
81 + RepackedExtension<PackedType<U, Tower::B128>>
82 + PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
83 PackedType<U, Tower::FastB128>: PackedTransformationFactory<PackedType<U, Tower::B128>>,
84{
85 tracing::debug!(
86 arch = env::consts::ARCH,
87 rayon_threads = binius_maybe_rayon::current_num_threads(),
88 "using computation backend: {backend:?}"
89 );
90
91 let domain_factory = DefaultEvaluationDomainFactory::<FDomain<Tower>>::default();
92 let fast_domain_factory = IsomorphicEvaluationDomainFactory::<FFastExt<Tower>>::default();
93
94 let mut transcript = ProverTranscript::<Challenger_>::new();
95 transcript.observe().write_slice(boundaries);
96
97 let ConstraintSystem {
98 mut oracles,
99 mut table_constraints,
100 mut flushes,
101 mut exponents,
102 non_zero_oracle_ids,
103 max_channel_id,
104 } = constraint_system.clone();
105
106 exponents.sort_by_key(|b| std::cmp::Reverse(b.n_vars(&oracles)));
107
108 let exp_witnesses = exp::make_exp_witnesses::<U, Tower>(&mut witness, &oracles, &exponents)?;
111
112 table_constraints.sort_by_key(|constraint_set| constraint_set.n_vars);
114
115 let merkle_prover = BinaryMerkleTreeProver::<_, Hash, _>::new(Compress::default());
117 let merkle_scheme = merkle_prover.scheme();
118
119 let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?;
120 let committed_multilins = piop::collect_committed_witnesses::<U, _>(
121 &commit_meta,
122 &oracle_to_commit_index,
123 &oracles,
124 &witness,
125 )?;
126
127 let fri_params = piop::make_commit_params_with_optimal_arity::<_, FEncode<Tower>, _>(
128 &commit_meta,
129 merkle_scheme,
130 security_bits,
131 log_inv_rate,
132 )?;
133 let ntt = SingleThreadedNTT::new(fri_params.rs_code().log_len())?
134 .precompute_twiddles()
135 .multithreaded();
136
137 let commit_span =
138 tracing::info_span!("[phase] Commit", phase = "commit", perfetto_category = "phase.main")
139 .entered();
140 let CommitOutput {
141 commitment,
142 committed,
143 codeword,
144 } = piop::commit(&fri_params, &ntt, &merkle_prover, &committed_multilins)?;
145 drop(commit_span);
146
147 let mut writer = transcript.message();
149 writer.write(&commitment);
150
151 let exp_challenge = transcript.sample_vec(exp::max_n_vars(&exponents, &oracles));
153
154 let exp_evals = gkr_exp::get_evals_in_point_from_witnesses(&exp_witnesses, &exp_challenge)?
155 .into_iter()
156 .map(|x| x.into())
157 .collect::<Vec<_>>();
158
159 let mut writer = transcript.message();
160 writer.write_scalar_slice(&exp_evals);
161
162 let exp_challenge = exp_challenge
163 .into_iter()
164 .map(|x| x.into())
165 .collect::<Vec<_>>();
166
167 let exp_claims = exp::make_claims(&exponents, &oracles, &exp_challenge, &exp_evals)?
168 .into_iter()
169 .map(|claim| claim.isomorphic())
170 .collect::<Vec<_>>();
171
172 let base_exp_output = gkr_exp::batch_prove::<_, _, FFastExt<Tower>, _, _>(
173 EvaluationOrder::HighToLow,
174 exp_witnesses,
175 &exp_claims,
176 fast_domain_factory.clone(),
177 &mut transcript,
178 backend,
179 )?
180 .isomorphic();
181
182 let exp_eval_claims = exp::make_eval_claims(&exponents, base_exp_output)?;
183
184 let non_zero_fast_witnesses =
187 make_fast_unmasked_flush_witnesses::<U, _>(&oracles, &witness, &non_zero_oracle_ids)?;
188 let non_zero_prodcheck_witnesses = non_zero_fast_witnesses
189 .into_par_iter()
190 .map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
191 .collect::<Result<Vec<_>, _>>()?;
192
193 let non_zero_products =
194 gkr_gpa::get_grand_products_from_witnesses(&non_zero_prodcheck_witnesses);
195 if non_zero_products
196 .iter()
197 .any(|count| *count == Tower::B128::zero())
198 {
199 bail!(Error::Zeros);
200 }
201
202 let mut writer = transcript.message();
203
204 writer.write_scalar_slice(&non_zero_products);
205
206 let non_zero_prodcheck_claims = gkr_gpa::construct_grand_product_claims(
207 &non_zero_oracle_ids,
208 &oracles,
209 &non_zero_products,
210 )?;
211
212 let mixing_challenge = transcript.sample();
214 let permutation_challenges = transcript.sample_vec(max_channel_id + 1);
215
216 flushes.sort_by_key(|flush| flush.channel_id);
217 let flush_oracle_ids =
218 make_flush_oracles(&mut oracles, &flushes, mixing_challenge, &permutation_challenges)?;
219
220 make_masked_flush_witnesses::<U, _>(&oracles, &mut witness, &flush_oracle_ids)?;
221
222 let flush_witnesses =
224 make_fast_unmasked_flush_witnesses::<U, _>(&oracles, &witness, &flush_oracle_ids)?;
225
226 let flush_prodcheck_witnesses = flush_witnesses
228 .into_par_iter()
229 .map(|(n_vars, evals)| GrandProductWitness::new(n_vars, evals))
230 .collect::<Result<Vec<_>, _>>()?;
231 let flush_products = gkr_gpa::get_grand_products_from_witnesses(&flush_prodcheck_witnesses);
232
233 transcript.message().write_scalar_slice(&flush_products);
234
235 let flush_prodcheck_claims =
236 gkr_gpa::construct_grand_product_claims(&flush_oracle_ids, &oracles, &flush_products)?;
237
238 let all_gpa_witnesses = [flush_prodcheck_witnesses, non_zero_prodcheck_witnesses].concat();
240 let all_gpa_claims = chain!(flush_prodcheck_claims, non_zero_prodcheck_claims)
241 .map(|claim| claim.isomorphic())
242 .collect::<Vec<_>>();
243
244 let GrandProductBatchProveOutput { final_layer_claims } =
245 gkr_gpa::batch_prove::<FFastExt<Tower>, _, FFastExt<Tower>, _, _>(
246 EvaluationOrder::LowToHigh,
247 all_gpa_witnesses,
248 &all_gpa_claims,
249 &fast_domain_factory,
250 &mut transcript,
251 backend,
252 )?;
253
254 let final_layer_claims = final_layer_claims
256 .into_iter()
257 .map(|layer_claim| layer_claim.isomorphic())
258 .collect::<Vec<_>>();
259
260 let prodcheck_eval_claims = gkr_gpa::make_eval_claims(
262 chain!(flush_oracle_ids, non_zero_oracle_ids),
263 final_layer_claims,
264 )?;
265
266 let zerocheck_span = tracing::info_span!(
268 "[phase] Zerocheck",
269 phase = "zerocheck",
270 perfetto_category = "phase.main",
271 )
272 .entered();
273
274 let (zerocheck_claims, zerocheck_oracle_metas) = table_constraints
275 .iter()
276 .cloned()
277 .map(constraint_set_zerocheck_claim)
278 .collect::<Result<Vec<_>, _>>()?
279 .into_iter()
280 .unzip::<_, _, Vec<_>, Vec<_>>();
281
282 let (max_n_vars, skip_rounds) =
283 max_n_vars_and_skip_rounds(&zerocheck_claims, FDomain::<Tower>::N_BITS);
284
285 let zerocheck_challenges = transcript.sample_vec(max_n_vars - skip_rounds);
286
287 let mut zerocheck_provers = Vec::with_capacity(table_constraints.len());
288
289 for constraint_set in table_constraints {
290 let n_vars = constraint_set.n_vars;
291 let (constraints, multilinears) =
292 sumcheck::prove::split_constraint_set(constraint_set, &witness)?;
293
294 let base_tower_level = chain!(
295 multilinears
296 .iter()
297 .map(|multilinear| 7 - multilinear.log_extension_degree()),
298 constraints
299 .iter()
300 .map(|constraint| constraint.composition.binary_tower_level())
301 )
302 .max()
303 .unwrap_or(0);
304
305 let zerocheck_challenges = &zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..];
307 let domain_factory = domain_factory.clone();
308
309 let constructor =
310 ZerocheckProverConstructor::<PackedType<U, FExt<Tower>>, FDomain<Tower>, _, _> {
311 constraints,
312 multilinears,
313 zerocheck_challenges,
314 domain_factory,
315 backend,
316 _fdomain_marker: PhantomData,
317 };
318
319 let zerocheck_prover = match base_tower_level {
320 0..=3 => constructor.create::<Tower::B8>()?,
321 4 => constructor.create::<Tower::B16>()?,
322 5 => constructor.create::<Tower::B32>()?,
323 6 => constructor.create::<Tower::B64>()?,
324 7 => constructor.create::<Tower::B128>()?,
325 _ => unreachable!(),
326 };
327
328 zerocheck_provers.push(zerocheck_prover);
329 }
330
331 let zerocheck_output = sumcheck::prove::batch_prove_zerocheck::<
332 FExt<Tower>,
333 FDomain<Tower>,
334 PackedType<U, FExt<Tower>>,
335 _,
336 _,
337 >(zerocheck_provers, skip_rounds, &mut transcript)?;
338
339 let zerocheck_eval_claims =
340 sumcheck::make_zerocheck_eval_claims(zerocheck_oracle_metas, zerocheck_output)?;
341
342 drop(zerocheck_span);
343
344 let evalcheck_span = tracing::info_span!(
345 "[phase] Evalcheck",
346 phase = "evalcheck",
347 perfetto_category = "phase.main"
348 )
349 .entered();
350
351 let GreedyEvalcheckProveOutput {
353 eval_claims,
354 memoized_data,
355 } = greedy_evalcheck::prove::<_, _, FDomain<Tower>, _, _>(
356 &mut oracles,
357 &mut witness,
358 chain!(prodcheck_eval_claims, zerocheck_eval_claims, exp_eval_claims,),
359 standard_switchover_heuristic(-2),
360 &mut transcript,
361 &domain_factory,
362 backend,
363 )?;
364
365 let system = ring_switch::EvalClaimSystem::new(
367 &oracles,
368 &commit_meta,
369 &oracle_to_commit_index,
370 &eval_claims,
371 )?;
372
373 drop(evalcheck_span);
374
375 let ring_switch_span = tracing::info_span!(
376 "[phase] Ring Switch",
377 phase = "ring_switch",
378 perfetto_category = "phase.main"
379 )
380 .entered();
381 let ring_switch::ReducedWitness {
382 transparents: transparent_multilins,
383 sumcheck_claims: piop_sumcheck_claims,
384 } = ring_switch::prove::<_, _, _, Tower, _, _>(
385 &system,
386 &committed_multilins,
387 &mut transcript,
388 memoized_data,
389 backend,
390 )?;
391 drop(ring_switch_span);
392
393 let piop_compiler_span = tracing::info_span!(
395 "[phase] PIOP Compiler",
396 phase = "piop_compiler",
397 perfetto_category = "phase.main"
398 )
399 .entered();
400 piop::prove::<_, FDomain<Tower>, _, _, _, _, _, _, _, _, _>(
401 &fri_params,
402 &ntt,
403 &merkle_prover,
404 domain_factory,
405 &commit_meta,
406 committed,
407 &codeword,
408 &committed_multilins,
409 &transparent_multilins,
410 &piop_sumcheck_claims,
411 &mut transcript,
412 &backend,
413 )?;
414 drop(piop_compiler_span);
415
416 let proof = Proof {
417 transcript: transcript.finalize(),
418 };
419
420 tracing::event!(
421 name: "proof_size",
422 tracing::Level::INFO,
423 counter = true,
424 value = proof.get_proof_size() as u64,
425 unit = "bytes",
426 );
427
428 Ok(proof)
429}
430
431type TypeErasedZerocheck<'a, P> = Box<dyn ZerocheckProver<'a, P> + 'a>;
432
433struct ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
434where
435 P: PackedField,
436{
437 constraints: Vec<Constraint<P::Scalar>>,
438 multilinears: Vec<MultilinearWitness<'a, P>>,
439 domain_factory: DomainFactory,
440 zerocheck_challenges: &'a [P::Scalar],
441 backend: &'a Backend,
442 _fdomain_marker: PhantomData<FDomain>,
443}
444
445impl<'a, P, F, FDomain, DomainFactory, Backend>
446 ZerocheckProverConstructor<'a, P, FDomain, DomainFactory, Backend>
447where
448 F: Field,
449 P: PackedField<Scalar = F>,
450 FDomain: TowerField,
451 DomainFactory: EvaluationDomainFactory<FDomain> + 'a,
452 Backend: ComputationBackend,
453{
454 fn create<FBase>(self) -> Result<TypeErasedZerocheck<'a, P>, Error>
455 where
456 FBase: TowerField + ExtensionField<FDomain> + TryFrom<F>,
457 P: PackedExtension<F, PackedSubfield = P>
458 + PackedExtension<FDomain>
459 + PackedExtension<FBase>,
460 F: TowerField,
461 {
462 let zerocheck_prover =
463 sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _, _>(
464 self.constraints,
465 self.multilinears,
466 self.domain_factory,
467 self.zerocheck_challenges,
468 self.backend,
469 )?;
470
471 let type_erased_zerocheck_prover = Box::new(zerocheck_prover) as TypeErasedZerocheck<'a, P>;
472
473 Ok(type_erased_zerocheck_prover)
474 }
475}
476
477#[instrument(skip_all, level = "debug")]
478fn make_masked_flush_witnesses<'a, U, Tower>(
479 oracles: &MultilinearOracleSet<FExt<Tower>>,
480 witness: &mut MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
481 flush_oracle_ids: &[OracleId],
482) -> Result<(), Error>
483where
484 U: ProverTowerUnderlier<Tower>,
485 Tower: ProverTowerFamily,
486{
487 let indices_to_update: Vec<(OracleId, MultilinearWitness<'a, _>)> = flush_oracle_ids
489 .par_iter()
490 .map(|&flush_oracle| match oracles.oracle(flush_oracle).variant {
491 MultilinearPolyVariant::Composite(composite) => {
492 let inner_polys = composite.inner();
493
494 let polys = inner_polys
495 .iter()
496 .map(|id| witness.get_multilin_poly(*id))
497 .collect::<Result<Vec<_>, _>>()?;
498
499 let n_vars = composite.n_vars();
500 let log_width = <PackedType<U, FExt<Tower>>>::LOG_WIDTH;
501
502 let packed_len = 1 << n_vars.saturating_sub(log_width);
503
504 let inner_c = composite.c();
505
506 let composite_data = (0..packed_len)
507 .into_par_iter()
508 .map(|i| {
509 <PackedType<U, FExt<Tower>>>::from_fn(|j| {
510 let index = i << <PackedType<U, FExt<Tower>>>::LOG_WIDTH | j;
511 let evals = polys
512 .iter()
513 .map(|poly| poly.evaluate_on_hypercube(index).unwrap_or_default())
514 .collect::<Vec<_>>();
515
516 inner_c
517 .evaluate(&evals)
518 .expect("query length is the same as poly length")
519 })
520 })
521 .collect::<Vec<_>>();
522
523 let composite_poly = MultilinearExtension::new(n_vars, composite_data)
524 .expect("data is constructed with the correct length with respect to n_vars");
525
526 Ok((flush_oracle, MLEDirectAdapter::from(composite_poly).upcast_arc_dyn()))
527 }
528 MultilinearPolyVariant::LinearCombination(lincom) => {
529 let polys = lincom
530 .polys()
531 .map(|id| witness.get_multilin_poly(id))
532 .collect::<Result<Vec<_>, _>>()?;
533
534 let packed_len = 1
535 << lincom
536 .n_vars()
537 .saturating_sub(<PackedType<U, FExt<Tower>>>::LOG_WIDTH);
538 let lin_comb_data = (0..packed_len)
539 .into_par_iter()
540 .map(|i| {
541 <PackedType<U, FExt<Tower>>>::from_fn(|j| {
542 let index = i << <PackedType<U, FExt<Tower>>>::LOG_WIDTH | j;
543 polys.iter().zip(lincom.coefficients()).fold(
544 lincom.offset(),
545 |sum, (poly, coeff)| {
546 sum + poly
547 .evaluate_on_hypercube_and_scale(index, coeff)
548 .unwrap_or(<FExt<Tower>>::ZERO)
549 },
550 )
551 })
552 })
553 .collect::<Vec<_>>();
554
555 let lincom_poly = MultilinearExtension::new(lincom.n_vars(), lin_comb_data)
556 .expect("data is constructed with the correct length with respect to n_vars");
557 Ok((flush_oracle, MLEDirectAdapter::from(lincom_poly).upcast_arc_dyn()))
558 }
559 _ => unreachable!("flush_oracles must either be composite or linear combinations"),
560 })
561 .collect::<Result<Vec<_>, Error>>()?;
562
563 witness.update_multilin_poly(indices_to_update.into_iter())?;
564 Ok(())
565}
566
567#[allow(clippy::type_complexity)]
568#[instrument(skip_all, level = "debug")]
569fn make_fast_unmasked_flush_witnesses<'a, U, Tower>(
570 oracles: &MultilinearOracleSet<FExt<Tower>>,
571 witness: &MultilinearExtensionIndex<'a, PackedType<U, FExt<Tower>>>,
572 flush_oracles: &[OracleId],
573) -> Result<Vec<(usize, Vec<PackedType<U, FFastExt<Tower>>>)>, Error>
574where
575 U: ProverTowerUnderlier<Tower>,
576 Tower: ProverTowerFamily,
577 PackedType<U, Tower::B128>: PackedTransformationFactory<PackedType<U, Tower::FastB128>>,
578{
579 let to_fast = Tower::packed_transformation_to_fast();
580
581 flush_oracles
583 .into_par_iter()
584 .map(|&flush_oracle_id| {
585 let n_vars = oracles.n_vars(flush_oracle_id);
586
587 let log_width = <PackedType<U, FFastExt<Tower>>>::LOG_WIDTH;
588
589 let poly = witness.get_multilin_poly(flush_oracle_id)?;
590
591 const MAX_SUBCUBE_VARS: usize = 8;
592 let subcube_vars = MAX_SUBCUBE_VARS.min(n_vars);
593 let subcube_packed_size = 1 << subcube_vars.saturating_sub(log_width);
594 let non_const_scalars = 1usize << n_vars;
595 let non_const_subcubes = non_const_scalars.div_ceil(1 << subcube_vars);
596
597 let mut fast_ext_result = vec![
598 PackedType::<U, FFastExt<Tower>>::one();
599 non_const_subcubes * subcube_packed_size
600 ];
601
602 fast_ext_result
603 .par_chunks_exact_mut(subcube_packed_size)
604 .enumerate()
605 .for_each(|(subcube_index, fast_subcube)| {
606 let underliers =
607 PackedType::<U, FFastExt<Tower>>::to_underliers_ref_mut(fast_subcube);
608
609 let subcube_evals =
610 PackedType::<U, FExt<Tower>>::from_underliers_ref_mut(underliers);
611 poly.subcube_evals(subcube_vars, subcube_index, 0, subcube_evals)
612 .expect("witness data populated by make_unmasked_flush_witnesses()");
613
614 for underlier in underliers.iter_mut() {
615 let src = PackedType::<U, FExt<Tower>>::from_underlier(*underlier);
616 let dest = to_fast.transform(&src);
617 *underlier = PackedType::<U, FFastExt<Tower>>::to_underlier(dest);
618 }
619 });
620
621 fast_ext_result.truncate(non_const_scalars);
622 Ok((n_vars, fast_ext_result))
623 })
624 .collect()
625}