binius_core/protocols/greedy_evalcheck/prove.rs

// Copyright 2024-2025 Irreducible Inc.

3use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::EvaluationDomainFactory;
6
7use super::{error::Error, logging::RegularSumcheckDimensionsData};
8use crate::{
9	fiat_shamir::Challenger,
10	oracle::MultilinearOracleSet,
11	protocols::evalcheck::{
12		subclaims::{prove_bivariate_sumchecks_with_switchover, MemoizedData},
13		EvalcheckMultilinearClaim, EvalcheckProver,
14	},
15	transcript::ProverTranscript,
16	witness::MultilinearExtensionIndex,
17};
18
19pub struct GreedyEvalcheckProveOutput<'a, F: Field, P: PackedField, Backend: ComputationBackend> {
20	pub eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
21	pub memoized_data: MemoizedData<'a, P, Backend>,
22}
23
24#[allow(clippy::too_many_arguments)]
25pub fn prove<'a, F, P, DomainField, Challenger_, Backend>(
26	oracles: &mut MultilinearOracleSet<F>,
27	witness_index: &'a mut MultilinearExtensionIndex<P>,
28	claims: impl IntoIterator<Item = EvalcheckMultilinearClaim<F>>,
29	switchover_fn: impl Fn(usize) -> usize + Clone + 'static,
30	transcript: &mut ProverTranscript<Challenger_>,
31	domain_factory: impl EvaluationDomainFactory<DomainField>,
32	backend: &Backend,
33) -> Result<GreedyEvalcheckProveOutput<'a, F, P, Backend>, Error>
34where
35	F: TowerField + ExtensionField<DomainField>,
36	P: PackedField<Scalar = F>
37		+ PackedExtension<F, PackedSubfield = P>
38		+ PackedExtension<DomainField>,
39	DomainField: TowerField,
40	Challenger_: Challenger,
41	Backend: ComputationBackend,
42{
43	let mut evalcheck_prover =
44		EvalcheckProver::<F, P, Backend>::new(oracles, witness_index, backend);
45
46	let claims: Vec<_> = claims.into_iter().collect();
47
48	// Prove the initial evalcheck claims
49	let initial_evalcheck_round_span = tracing::debug_span!(
50		"[phase] Initial Evalcheck Round",
51		phase = "evalcheck",
52		perfetto_category = "task.main"
53	)
54	.entered();
55	evalcheck_prover.prove(claims, transcript)?;
56	drop(initial_evalcheck_round_span);
57
58	loop {
59		let _span = tracing::debug_span!(
60			"[step] Evalcheck Round",
61			phase = "evalcheck",
62			perfetto_category = "phase.sub"
63		)
64		.entered();
65		let new_sumchecks = evalcheck_prover.take_new_sumchecks_constraints().unwrap();
66		if new_sumchecks.is_empty() {
67			break;
68		}
69
70		// Reduce the new sumcheck claims for virtual polynomial openings to new evalcheck claims.
71		let dimensions_data = RegularSumcheckDimensionsData::new(new_sumchecks.iter());
72		let evalcheck_round_mle_fold_high_span = tracing::debug_span!(
73			"[task] (Evalcheck) Regular Sumcheck (Small)",
74			phase = "evalcheck",
75			perfetto_category = "task.main",
76			dimensions_data = ?dimensions_data,
77		)
78		.entered();
79		let new_evalcheck_claims =
80			prove_bivariate_sumchecks_with_switchover::<_, _, DomainField, _, _>(
81				evalcheck_prover.witness_index,
82				new_sumchecks,
83				transcript,
84				switchover_fn.clone(),
85				domain_factory.clone(),
86				backend,
87			)?;
88		drop(evalcheck_round_mle_fold_high_span);
89
90		evalcheck_prover.prove(new_evalcheck_claims, transcript)?;
91	}
92
93	let committed_claims = evalcheck_prover
94		.committed_eval_claims_mut()
95		.drain(..)
96		.collect::<Vec<_>>();
97
98	Ok(GreedyEvalcheckProveOutput {
99		eval_claims: committed_claims,
100		memoized_data: evalcheck_prover.memoized_data,
101	})
102}