binius_core/protocols/greedy_evalcheck/
prove.rs1use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::EvaluationDomainFactory;
6
7use super::{error::Error, logging::RegularSumcheckDimensionsData};
8use crate::{
9 fiat_shamir::Challenger,
10 oracle::MultilinearOracleSet,
11 protocols::evalcheck::{
12 ConstraintSetEqIndPoint, EvalcheckMultilinearClaim, EvalcheckProver,
13 subclaims::{
14 MemoizedData, prove_bivariate_sumchecks_with_switchover, prove_mlecheck_with_switchover,
15 },
16 },
17 transcript::ProverTranscript,
18 witness::MultilinearExtensionIndex,
19};
20
21pub struct GreedyEvalcheckProveOutput<'a, F: Field, P: PackedField> {
22 pub eval_claims: Vec<EvalcheckMultilinearClaim<F>>,
23 pub memoized_data: MemoizedData<'a, P>,
24}
25
26#[allow(clippy::too_many_arguments)]
27pub fn prove<'a, F, P, DomainField, Challenger_, Backend>(
28 oracles: &mut MultilinearOracleSet<F>,
29 witness_index: &'a mut MultilinearExtensionIndex<P>,
30 claims: impl IntoIterator<Item = EvalcheckMultilinearClaim<F>>,
31 switchover_fn: impl Fn(usize) -> usize + Clone + 'static,
32 transcript: &mut ProverTranscript<Challenger_>,
33 domain_factory: impl EvaluationDomainFactory<DomainField>,
34 backend: &Backend,
35) -> Result<GreedyEvalcheckProveOutput<'a, F, P>, Error>
36where
37 F: TowerField + ExtensionField<DomainField>,
38 P: PackedField<Scalar = F>
39 + PackedExtension<F, PackedSubfield = P>
40 + PackedExtension<DomainField>,
41 DomainField: TowerField,
42 Challenger_: Challenger,
43 Backend: ComputationBackend,
44{
45 let mut evalcheck_prover = EvalcheckProver::<F, P>::new(oracles, witness_index);
46
47 let claims: Vec<_> = claims.into_iter().collect();
48
49 let initial_evalcheck_round_span = tracing::debug_span!(
51 "[step] Initial Evalcheck Round",
52 phase = "evalcheck",
53 perfetto_category = "phase.sub"
54 )
55 .entered();
56 evalcheck_prover.prove(claims, transcript)?;
57 drop(initial_evalcheck_round_span);
58
59 loop {
60 let _span = tracing::debug_span!(
61 "[step] Evalcheck Round",
62 phase = "evalcheck",
63 perfetto_category = "phase.sub"
64 )
65 .entered();
66
67 let new_bivariate_sumchecks =
68 evalcheck_prover.take_new_bivariate_sumchecks_constraints()?;
69
70 let new_mlechecks = evalcheck_prover.take_new_mlechecks_constraints()?;
71
72 let mut new_evalcheck_claims =
73 Vec::with_capacity(new_bivariate_sumchecks.len() + new_mlechecks.len());
74
75 if !new_bivariate_sumchecks.is_empty() {
76 let dimensions_data =
79 RegularSumcheckDimensionsData::new(new_bivariate_sumchecks.iter());
80 let evalcheck_round_mle_fold_high_span = tracing::debug_span!(
81 "[task] (Evalcheck) Regular Sumcheck (Small)",
82 phase = "evalcheck",
83 perfetto_category = "task.main",
84 dimensions_data = ?dimensions_data,
85 )
86 .entered();
87 let evalcheck_claims =
88 prove_bivariate_sumchecks_with_switchover::<_, _, DomainField, _, _>(
89 evalcheck_prover.witness_index,
90 new_bivariate_sumchecks,
91 transcript,
92 switchover_fn.clone(),
93 domain_factory.clone(),
94 backend,
95 )?;
96
97 new_evalcheck_claims.extend(evalcheck_claims);
98 drop(evalcheck_round_mle_fold_high_span);
99 }
100
101 if !new_mlechecks.is_empty() {
102 let dimensions_data = RegularSumcheckDimensionsData::new(
104 new_mlechecks
105 .iter()
106 .map(|new_mlecheck| &new_mlecheck.constraint_set),
107 );
108 let evalcheck_round_mle_fold_high_span = tracing::debug_span!(
109 "[task] (Evalcheck) MLE check",
110 phase = "evalcheck",
111 perfetto_category = "task.main",
112 dimensions_data = ?dimensions_data,
113 )
114 .entered();
115
116 for ConstraintSetEqIndPoint {
117 eq_ind_challenges,
118 constraint_set,
119 } in new_mlechecks
120 {
121 let evalcheck_claims = prove_mlecheck_with_switchover::<_, _, DomainField, _, _>(
122 evalcheck_prover.witness_index,
123 constraint_set,
124 eq_ind_challenges,
125 &mut evalcheck_prover.memoized_data,
126 transcript,
127 switchover_fn.clone(),
128 domain_factory.clone(),
129 backend,
130 )?;
131 new_evalcheck_claims.extend(evalcheck_claims);
132 }
133
134 drop(evalcheck_round_mle_fold_high_span);
135 }
136
137 if new_evalcheck_claims.is_empty() {
138 break;
139 }
140
141 evalcheck_prover.prove(new_evalcheck_claims, transcript)?;
142 }
143
144 let committed_claims = evalcheck_prover
145 .committed_eval_claims_mut()
146 .drain(..)
147 .collect::<Vec<_>>();
148
149 Ok(GreedyEvalcheckProveOutput {
150 eval_claims: committed_claims,
151 memoized_data: evalcheck_prover.memoized_data,
152 })
153}