binius_core/protocols/greedy_evalcheck/
prove.rs1use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::EvaluationDomainFactory;
6
7use super::{error::Error, logging::RegularSumcheckDimensionsData};
8use crate::{
9 fiat_shamir::Challenger,
10 oracle::MultilinearOracleSet,
11 protocols::evalcheck::{
12 subclaims::{prove_bivariate_sumchecks_with_switchover, MemoizedData},
13 EvalcheckMultilinearClaim, EvalcheckProver,
14 },
15 transcript::ProverTranscript,
16 witness::MultilinearExtensionIndex,
17};
18
19pub struct GreedyEvalcheckProveOutput<'a, F: Field, P: PackedField, Backend: ComputationBackend> {
20 pub eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
21 pub memoized_data: MemoizedData<'a, P, Backend>,
22}
23
24#[allow(clippy::too_many_arguments)]
25pub fn prove<'a, F, P, DomainField, Challenger_, Backend>(
26 oracles: &mut MultilinearOracleSet<F>,
27 witness_index: &'a mut MultilinearExtensionIndex<P>,
28 claims: impl IntoIterator<Item = EvalcheckMultilinearClaim<F>>,
29 switchover_fn: impl Fn(usize) -> usize + Clone + 'static,
30 transcript: &mut ProverTranscript<Challenger_>,
31 domain_factory: impl EvaluationDomainFactory<DomainField>,
32 backend: &Backend,
33) -> Result<GreedyEvalcheckProveOutput<'a, F, P, Backend>, Error>
34where
35 F: TowerField + ExtensionField<DomainField>,
36 P: PackedField<Scalar = F>
37 + PackedExtension<F, PackedSubfield = P>
38 + PackedExtension<DomainField>,
39 DomainField: TowerField,
40 Challenger_: Challenger,
41 Backend: ComputationBackend,
42{
43 let mut evalcheck_prover =
44 EvalcheckProver::<F, P, Backend>::new(oracles, witness_index, backend);
45
46 let claims: Vec<_> = claims.into_iter().collect();
47
48 let initial_evalcheck_round_span = tracing::debug_span!(
50 "[phase] Initial Evalcheck Round",
51 phase = "evalcheck",
52 perfetto_category = "task.main"
53 )
54 .entered();
55 evalcheck_prover.prove(claims, transcript)?;
56 drop(initial_evalcheck_round_span);
57
58 loop {
59 let _span = tracing::debug_span!(
60 "[step] Evalcheck Round",
61 phase = "evalcheck",
62 perfetto_category = "phase.sub"
63 )
64 .entered();
65 let new_sumchecks = evalcheck_prover.take_new_sumchecks_constraints().unwrap();
66 if new_sumchecks.is_empty() {
67 break;
68 }
69
70 let dimensions_data = RegularSumcheckDimensionsData::new(new_sumchecks.iter());
72 let evalcheck_round_mle_fold_high_span = tracing::debug_span!(
73 "[task] (Evalcheck) Regular Sumcheck (Small)",
74 phase = "evalcheck",
75 perfetto_category = "task.main",
76 dimensions_data = ?dimensions_data,
77 )
78 .entered();
79 let new_evalcheck_claims =
80 prove_bivariate_sumchecks_with_switchover::<_, _, DomainField, _, _>(
81 evalcheck_prover.witness_index,
82 new_sumchecks,
83 transcript,
84 switchover_fn.clone(),
85 domain_factory.clone(),
86 backend,
87 )?;
88 drop(evalcheck_round_mle_fold_high_span);
89
90 evalcheck_prover.prove(new_evalcheck_claims, transcript)?;
91 }
92
93 let committed_claims = evalcheck_prover
94 .committed_eval_claims_mut()
95 .drain(..)
96 .collect::<Vec<_>>();
97
98 Ok(GreedyEvalcheckProveOutput {
99 eval_claims: committed_claims,
100 memoized_data: evalcheck_prover.memoized_data,
101 })
102}