binius_core/protocols/sumcheck/prove/
batch_zerocheck.rs1use std::sync::Arc;
4
5use binius_field::{ExtensionField, PackedExtension, PackedField, TowerField};
6use binius_hal::{make_portable_backend, CpuBackend};
7use binius_math::{
8 BinarySubspace, EvaluationDomain, EvaluationOrder, IsomorphicEvaluationDomainFactory,
9 MLEDirectAdapter, MultilinearPoly,
10};
11use binius_utils::{bail, sorting::is_sorted_ascending};
12
13use crate::{
14 fiat_shamir::{CanSample, Challenger},
15 protocols::sumcheck::{
16 immediate_switchover_heuristic,
17 prove::{
18 front_loaded, logging::FoldLowDimensionsData, RegularSumcheckProver, SumcheckProver,
19 },
20 zerocheck::{
21 lagrange_evals_multilinear_extension, univariatizing_reduction_claim,
22 BatchZerocheckOutput, ZerocheckRoundEvals,
23 },
24 BatchSumcheckOutput, Error,
25 },
26 transcript::ProverTranscript,
27};
28
/// A prover for a single zerocheck claim that participates in [`batch_prove`].
///
/// [`batch_prove`] drives every prover through three stages, in this order:
/// [`execute_univariate_round`](Self::execute_univariate_round) over the first `skip_rounds`
/// variables, [`fold_univariate_round`](Self::fold_univariate_round) to obtain a regular
/// sumcheck prover for the remaining variables, and finally
/// [`project_to_skipped_variables`](Self::project_to_skipped_variables), which consumes the
/// prover.
///
/// `'a` is the lifetime bound of the boxed sumcheck prover returned by
/// `fold_univariate_round`.
pub trait ZerocheckProver<'a, P: PackedField> {
    /// Returns the number of variables of this claim.
    ///
    /// [`batch_prove`] rejects a batch whose provers are not sorted by this value in
    /// ascending order.
    fn n_vars(&self) -> usize;

    /// Returns the size of the univariate-round evaluation domain when `skip_rounds`
    /// variables are skipped.
    ///
    /// NOTE(review): `None` presumably means the size is not known; [`batch_prove`] panics
    /// if any prover returns `None`.
    fn domain_size(&self, skip_rounds: usize) -> Option<usize>;

    /// Computes this claim's univariate round evaluations over the first `skip_rounds`
    /// variables.
    ///
    /// `max_domain_size` is the largest [`domain_size`](Self::domain_size) across the whole
    /// batch, and `batch_coeff` is the batching coefficient sampled from the transcript for
    /// this claim. [`batch_prove`] multiplies the returned evaluations by `batch_coeff`
    /// itself before accumulating them.
    fn execute_univariate_round(
        &mut self,
        skip_rounds: usize,
        max_domain_size: usize,
        batch_coeff: P::Scalar,
    ) -> Result<ZerocheckRoundEvals<P::Scalar>, Error>;

    /// Folds the prover state with the univariate-round `challenge`, returning a sumcheck
    /// prover for the remaining (unskipped) variables.
    fn fold_univariate_round(
        &mut self,
        challenge: P::Scalar,
    ) -> Result<Box<dyn SumcheckProver<P::Scalar> + 'a>, Error>;

    /// Consumes the prover and returns its multilinears projected onto the skipped
    /// variables, given the `challenges` of the unskipped variables.
    ///
    /// [`batch_prove`] passes the challenges of the regular sumcheck stage after reversing
    /// their order.
    fn project_to_skipped_variables(
        self: Box<Self>,
        challenges: &[P::Scalar],
    ) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error>;
}
78
79impl<'a, P: PackedField, Prover: ZerocheckProver<'a, P> + ?Sized> ZerocheckProver<'a, P>
81 for Box<Prover>
82{
83 fn n_vars(&self) -> usize {
84 (**self).n_vars()
85 }
86
87 fn domain_size(&self, skip_rounds: usize) -> Option<usize> {
88 (**self).domain_size(skip_rounds)
89 }
90
91 fn execute_univariate_round(
92 &mut self,
93 skip_rounds: usize,
94 max_domain_size: usize,
95 batch_coeff: P::Scalar,
96 ) -> Result<ZerocheckRoundEvals<P::Scalar>, Error> {
97 (**self).execute_univariate_round(skip_rounds, max_domain_size, batch_coeff)
98 }
99
100 fn fold_univariate_round(
101 &mut self,
102 challenge: P::Scalar,
103 ) -> Result<Box<dyn SumcheckProver<P::Scalar> + 'a>, Error> {
104 (**self).fold_univariate_round(challenge)
105 }
106
107 fn project_to_skipped_variables(
108 self: Box<Self>,
109 challenges: &[P::Scalar],
110 ) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error> {
111 (*self).project_to_skipped_variables(challenges)
112 }
113}
114
115fn univariatizing_reduction_prover<F, FDomain, P>(
116 mut projected_multilinears: Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>,
117 skip_rounds: usize,
118 univariatized_multilinear_evals: Vec<Vec<F>>,
119 univariate_challenge: F,
120 backend: &'_ CpuBackend,
121) -> Result<impl SumcheckProver<F> + '_, Error>
122where
123 F: TowerField + ExtensionField<FDomain>,
124 FDomain: TowerField,
125 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
126{
127 let sumcheck_claim =
128 univariatizing_reduction_claim(skip_rounds, &univariatized_multilinear_evals)?;
129
130 let subspace =
131 BinarySubspace::<FDomain::Canonical>::with_dim(skip_rounds)?.isomorphic::<FDomain>();
132 let ntt_domain = EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
133
134 projected_multilinears.push(
135 MLEDirectAdapter::from(lagrange_evals_multilinear_extension(
136 &ntt_domain,
137 univariate_challenge,
138 )?)
139 .upcast_arc_dyn(),
140 );
141
142 let prover = RegularSumcheckProver::<FDomain, P, _, _, _>::new(
145 EvaluationOrder::HighToLow,
146 projected_multilinears,
147 sumcheck_claim.composite_sums().iter().cloned(),
148 IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default(),
149 immediate_switchover_heuristic,
150 backend,
151 )?;
152
153 Ok(prover)
154}
155
156#[allow(clippy::type_complexity)]
166pub fn batch_prove<'a, F, FDomain, P, Prover, Challenger_>(
167 mut provers: Vec<Prover>,
168 skip_rounds: usize,
169 transcript: &mut ProverTranscript<Challenger_>,
170) -> Result<BatchZerocheckOutput<P::Scalar>, Error>
171where
172 F: TowerField + ExtensionField<FDomain>,
173 FDomain: TowerField,
174 P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
175 Prover: ZerocheckProver<'a, P>,
176 Challenger_: Challenger,
177{
178 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
180 bail!(Error::ClaimsOutOfOrder);
181 }
182
183 let max_domain_size = provers
184 .iter()
185 .map(|prover| {
186 prover
187 .domain_size(skip_rounds)
188 .expect("domain size must be known")
189 })
190 .max()
191 .unwrap_or(0);
192
193 let mut batch_coeffs = Vec::with_capacity(provers.len());
196 let mut round_evals =
197 ZerocheckRoundEvals::zeros(max_domain_size.saturating_sub(1 << skip_rounds));
198 for prover in &mut provers {
199 let next_batch_coeff = transcript.sample();
200 batch_coeffs.push(next_batch_coeff);
201
202 let prover_round_evals =
203 prover.execute_univariate_round(skip_rounds, max_domain_size, next_batch_coeff)?;
204
205 round_evals.add_assign_lagrange(&(prover_round_evals * next_batch_coeff))?;
206 }
207
208 transcript.message().write_scalar_slice(&round_evals.evals);
210 let univariate_challenge = transcript.sample();
211
212 let mut sumcheck_provers = Vec::with_capacity(provers.len());
214 for prover in &mut provers {
215 let sumcheck_prover = prover.fold_univariate_round(univariate_challenge)?;
216 sumcheck_provers.push(sumcheck_prover);
217 }
218
219 let regular_sumcheck_prover =
220 front_loaded::BatchProver::new_prebatched(batch_coeffs, sumcheck_provers)?;
221
222 let BatchSumcheckOutput {
223 challenges: mut unskipped_challenges,
224 multilinear_evals: mut univariatized_multilinear_evals,
225 } = regular_sumcheck_prover.run(transcript)?;
226
227 unskipped_challenges.reverse();
229
230 for evals in &mut univariatized_multilinear_evals {
232 evals
233 .pop()
234 .expect("equality indicator evaluation at last position");
235 }
236
237 let mut projected_multilinears = Vec::new();
239 let dimensions_data = FoldLowDimensionsData::new(skip_rounds, &provers);
240 let mle_fold_low_span = tracing::debug_span!(
241 "[task] Initial MLE Fold Low",
242 phase = "zerocheck",
243 perfetto_category = "task.main",
244 ?dimensions_data,
245 )
246 .entered();
247 for prover in provers {
248 let claim_projected_multilinears =
249 Box::new(prover).project_to_skipped_variables(&unskipped_challenges)?;
250
251 projected_multilinears.extend(claim_projected_multilinears);
252 }
253 drop(mle_fold_low_span);
254
255 let backend = make_portable_backend();
258 let reduction_prover = univariatizing_reduction_prover::<_, FDomain, _>(
259 projected_multilinears,
260 skip_rounds,
261 univariatized_multilinear_evals,
262 univariate_challenge,
263 &backend,
264 )?;
265
266 let batch_reduction_prover =
267 front_loaded::BatchProver::new(vec![reduction_prover], transcript)?;
268
269 let BatchSumcheckOutput {
270 challenges: mut skipped_challenges,
271 multilinear_evals: mut concat_multilinear_evals,
272 } = batch_reduction_prover.run(transcript)?;
273
274 skipped_challenges.reverse();
276
277 let mut concat_multilinear_evals = concat_multilinear_evals
278 .pop()
279 .expect("multilinear_evals.len() == 1");
280
281 concat_multilinear_evals
282 .pop()
283 .expect("Lagrange coefficients MLE eval at last position");
284
285 let output = BatchZerocheckOutput {
287 skipped_challenges,
288 unskipped_challenges,
289 concat_multilinear_evals,
290 };
291
292 Ok(output)
293}