1use std::{marker::PhantomData, mem, sync::Arc};
4
5use binius_field::{
6 packed::{copy_packed_from_scalars_slice, get_packed_slice, set_packed_slice},
7 util::powers,
8 ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, RepackedExtension,
9 TowerField,
10};
11use binius_hal::{ComputationBackend, ComputationBackendExt};
12use binius_math::{
13 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
14 MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{izip, Either};
20use tracing::instrument;
21
22use crate::{
23 polynomial::MultilinearComposite,
24 protocols::sumcheck::{
25 common::{equal_n_vars_check, CompositeSumClaim},
26 prove::{
27 common::fold_partial_eq_ind,
28 eq_ind::EqIndSumcheckProverBuilder,
29 univariate::{
30 zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
31 ZerocheckUnivariateFoldResult,
32 },
33 SumcheckProver, ZerocheckProver,
34 },
35 zerocheck::{domain_size, ZerocheckRoundEvals},
36 Error,
37 },
38};
39
40pub fn validate_witness<'a, F, P, M, Composition>(
41 multilinears: &[M],
42 zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
43) -> Result<(), Error>
44where
45 F: Field,
46 P: PackedField<Scalar = F>,
47 M: MultilinearPoly<P> + Send + Sync,
48 Composition: CompositionPoly<P> + 'a,
49{
50 let n_vars = multilinears
51 .first()
52 .map(|multilinear| multilinear.n_vars())
53 .unwrap_or_default();
54 for multilinear in multilinears {
55 if multilinear.n_vars() != n_vars {
56 bail!(Error::NumberOfVariablesMismatch);
57 }
58 }
59
60 let multilinears = multilinears.iter().collect::<Vec<_>>();
61
62 for (name, composition) in zero_claims {
63 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64 (0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
65 if witness.evaluate_on_hypercube(j)? != F::ZERO {
66 return Err(Error::ZerocheckNaiveValidationFailure {
67 composition_name: name.to_string(),
68 vertex_index: j,
69 });
70 }
71 Ok(())
72 })?;
73 }
74 Ok(())
75}
76
77pub fn high_pad_small_multilinear<PBase, P, M>(
80 min_n_vars: usize,
81 multilinear: M,
82) -> Either<M, MLEEmbeddingAdapter<PBase, P>>
83where
84 PBase: PackedField,
85 P: PackedField + RepackedExtension<PBase>,
86 M: MultilinearPoly<P>,
87{
88 let n_vars = multilinear.n_vars();
89 if n_vars >= min_n_vars {
90 return Either::Left(multilinear);
91 }
92
93 let mut padded_evals_base =
94 zeroed_vec::<PBase>(1 << min_n_vars.saturating_sub(PBase::LOG_WIDTH));
95
96 let log_embedding_degree = <P::Scalar as ExtensionField<PBase::Scalar>>::LOG_DEGREE;
97 let padded_evals = P::cast_exts_mut(&mut padded_evals_base);
98
99 multilinear
100 .subcube_evals(
101 n_vars,
102 0,
103 log_embedding_degree,
104 &mut padded_evals[..1 << n_vars.saturating_sub(PBase::LOG_WIDTH)],
105 )
106 .expect("copy evals verbatim into correctly sized array");
107
108 for repeat_idx in 0..1 << (min_n_vars - n_vars) {
109 for scalar_idx in 0..1 << n_vars {
110 let eval = get_packed_slice(&padded_evals_base, scalar_idx);
111 set_packed_slice(&mut padded_evals_base, scalar_idx | repeat_idx << n_vars, eval);
112 }
113 }
114
115 let padded_multilinear = MultilinearExtension::new(min_n_vars, padded_evals_base)
116 .expect("padded evals have correct size");
117
118 Either::Right(MLEEmbeddingAdapter::from(padded_multilinear))
119}
120
/// Zerocheck prover that handles the lowest `skip_rounds` variables in a single univariate
/// round and delegates the remaining rounds to an eq-indicator sumcheck prover.
///
/// The prover moves through the phases of `ZerocheckProverState`:
/// `RoundEval` (after `new`), then `Folding` (after `execute_univariate_round`), then
/// `Projection` (after `fold_univariate_round`). It is finally consumed by
/// `project_to_skipped_variables`.
#[derive(Debug)]
// The deeply nested generic type of the `state` field trips clippy's type_complexity lint.
#[allow(clippy::type_complexity)]
pub struct ZerocheckProverImpl<
    'a,
    FDomain,
    FBase,
    P,
    CompositionBase,
    Composition,
    M,
    DomainFactory,
    Backend,
> where
    FDomain: Field,
    FBase: Field,
    P: PackedExtension<FBase>,
    Backend: ComputationBackend,
{
    /// Number of variables shared by all witness multilinears.
    n_vars: usize,
    /// Zerocheck challenges supplied at construction. They are forwarded to the univariate
    /// evaluation and to the eq-indicator sumcheck prover built when folding.
    zerocheck_challenges: Vec<P::Scalar>,
    /// Current protocol phase, together with the data that phase owns.
    state: ZerocheckProverState<
        Vec<M>,
        Vec<Either<M, MLEEmbeddingAdapter<P::PackedSubfield, P>>>,
        Vec<(String, CompositionBase, Composition)>,
        ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>,
        DomainFactory,
    >,
    /// Backend used for the univariate evaluation, the multilinear queries and the
    /// partial evaluations.
    backend: &'a Backend,
    _p_base_marker: PhantomData<FBase>,
    _fdomain_marker: PhantomData<FDomain>,
}
160
/// Phases of the zerocheck prover; each variant owns exactly the data the next call needs.
#[derive(Debug)]
enum ZerocheckProverState<
    Multilinears,
    PaddedMultilinears,
    Compositions,
    EvalsOutput,
    DomainFactory,
> {
    /// Placeholder left behind by `mem::take` while a phase transition is in progress.
    /// It is also the `Default` value, and it remains if a transition fails midway.
    IllegalState,
    /// Initial phase: waiting for the univariate round to be executed.
    RoundEval {
        multilinears: Multilinears,
        compositions: Compositions,
        domain_factory: DomainFactory,
    },
    /// After the univariate round. The multilinears have been high-padded to at least
    /// `skip_rounds` variables, and the univariate evaluation output is kept for folding.
    Folding {
        skip_rounds: usize,
        padded_multilinears: PaddedMultilinears,
        compositions: Compositions,
        domain_factory: DomainFactory,
        univariate_evals_output: EvalsOutput,
    },
    /// After folding. Only the padded multilinears are kept, so they can be projected
    /// onto the skipped variables.
    Projection {
        skip_rounds: usize,
        padded_multilinears: PaddedMultilinears,
    },
}
187
188#[allow(clippy::derivable_impls)]
189impl<Multilinears, PaddedMultilinears, Compositions, EvalsOutput, DomainFactory> Default
190 for ZerocheckProverState<
191 Multilinears,
192 PaddedMultilinears,
193 Compositions,
194 EvalsOutput,
195 DomainFactory,
196 >
197{
198 fn default() -> Self {
199 Self::IllegalState
201 }
202}
203
204impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
205 ZerocheckProverImpl<'a, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
206where
207 F: TowerField,
208 FDomain: Field,
209 FBase: ExtensionField<FDomain>,
210 P: PackedField<Scalar = F>
211 + PackedExtension<F, PackedSubfield = P>
212 + PackedExtension<FBase>
213 + PackedExtension<FDomain>,
214 CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
215 Composition: CompositionPoly<P> + 'a,
216 M: MultilinearPoly<P> + Send + Sync + 'a,
217 DomainFactory: EvaluationDomainFactory<FDomain>,
218 Backend: ComputationBackend,
219{
220 pub fn new(
221 multilinears: Vec<M>,
222 zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
223 zerocheck_challenges: &[F],
224 domain_factory: DomainFactory,
225 backend: &'a Backend,
226 ) -> Result<Self, Error> {
227 let n_vars = equal_n_vars_check(&multilinears)?;
228 let n_multilinears = multilinears.len();
229
230 let compositions = zero_claims.into_iter().collect::<Vec<_>>();
231 for (_, composition_base, composition) in &compositions {
232 if composition_base.n_vars() != n_multilinears
233 || composition.n_vars() != n_multilinears
234 || composition_base.degree() != composition.degree()
235 {
236 bail!(Error::InvalidComposition {
237 actual: composition.n_vars(),
238 expected: n_multilinears,
239 });
240 }
241 }
242 #[cfg(feature = "debug_validate_sumcheck")]
243 {
244 let compositions = compositions
245 .iter()
246 .map(|(name, _, a)| (name.clone(), a))
247 .collect::<Vec<_>>();
248 validate_witness(&multilinears, &compositions)?;
249 }
250
251 let zerocheck_challenges = zerocheck_challenges.to_vec();
252 let state = ZerocheckProverState::RoundEval {
253 multilinears,
254 compositions,
255 domain_factory,
256 };
257
258 Ok(Self {
259 n_vars,
260 zerocheck_challenges,
261 state,
262 backend,
263 _p_base_marker: PhantomData,
264 _fdomain_marker: PhantomData,
265 })
266 }
267}
268
269impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
270 ZerocheckProver<'a, P>
271 for ZerocheckProverImpl<
272 'a,
273 FDomain,
274 FBase,
275 P,
276 CompositionBase,
277 Composition,
278 M,
279 DomainFactory,
280 Backend,
281 >
282where
283 F: TowerField,
284 FDomain: TowerField,
285 FBase: ExtensionField<FDomain>,
286 P: PackedField<Scalar = F>
287 + PackedExtension<F, PackedSubfield = P>
288 + PackedExtension<FBase>
289 + PackedExtension<FDomain>,
290 CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
291 Composition: CompositionPoly<P> + 'static,
292 M: MultilinearPoly<P> + Send + Sync + 'a,
293 DomainFactory: EvaluationDomainFactory<FDomain>,
294 Backend: ComputationBackend,
295{
296 fn n_vars(&self) -> usize {
297 self.n_vars
298 }
299
300 fn domain_size(&self, skip_rounds: usize) -> Option<usize> {
301 let ZerocheckProverState::RoundEval { compositions, .. } = &self.state else {
302 return None;
303 };
304
305 Some(
306 compositions
307 .iter()
308 .map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
309 .max()
310 .unwrap_or(0),
311 )
312 }
313
314 fn execute_univariate_round(
315 &mut self,
316 skip_rounds: usize,
317 max_domain_size: usize,
318 batch_coeff: F,
319 ) -> Result<ZerocheckRoundEvals<F>, Error> {
320 let ZerocheckProverState::RoundEval {
321 multilinears,
322 compositions,
323 domain_factory,
324 } = mem::take(&mut self.state)
325 else {
326 bail!(Error::ExpectedExecution);
327 };
328
329 let padded_multilinears = multilinears
331 .into_iter()
332 .map(|multilinear| high_pad_small_multilinear(skip_rounds, multilinear))
333 .collect::<Vec<_>>();
334
335 let compositions_base = compositions
337 .iter()
338 .map(|(_, composition_base, _)| composition_base)
339 .collect::<Vec<_>>();
340
341 let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
344 &padded_multilinears,
345 &compositions_base,
346 &self.zerocheck_challenges,
347 skip_rounds,
348 max_domain_size,
349 self.backend,
350 )?;
351
352 let batched_round_evals = univariate_evals_output
354 .round_evals
355 .iter()
356 .zip(powers(batch_coeff))
357 .map(|(evals, scalar)| {
358 ZerocheckRoundEvals {
359 evals: evals.clone(),
360 } * scalar
361 })
362 .try_fold(
363 ZerocheckRoundEvals::zeros(max_domain_size - (1 << skip_rounds)),
364 |mut accum, evals| -> Result<_, Error> {
365 accum.add_assign_lagrange(&evals)?;
366 Ok(accum)
367 },
368 )?;
369
370 self.state = ZerocheckProverState::Folding {
371 skip_rounds,
372 padded_multilinears,
373 compositions,
374 domain_factory,
375 univariate_evals_output,
376 };
377
378 Ok(batched_round_evals)
379 }
380
381 #[instrument(skip_all, level = "debug")]
382 fn fold_univariate_round(
383 &mut self,
384 challenge: F,
385 ) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
386 let ZerocheckProverState::Folding {
387 skip_rounds,
388 padded_multilinears,
389 compositions,
390 domain_factory,
391 univariate_evals_output,
392 } = mem::take(&mut self.state)
393 else {
394 bail!(Error::ExpectedFold);
395 };
396
397 let ZerocheckUnivariateFoldResult {
400 remaining_rounds,
401 subcube_lagrange_coeffs,
402 claimed_sums,
403 mut partial_eq_ind_evals,
404 } = univariate_evals_output.fold::<FDomain>(challenge)?;
405
406 let mut packed_subcube_lagrange_coeffs =
415 zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
416 copy_packed_from_scalars_slice(
417 &subcube_lagrange_coeffs[..1 << skip_rounds],
418 &mut packed_subcube_lagrange_coeffs,
419 );
420 let lagrange_coeffs_query =
421 MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
422
423 let folded_multilinears = padded_multilinears
424 .par_iter()
425 .map(|multilinear| -> Result<_, Error> {
426 let folded_multilinear = multilinear
427 .evaluate_partial_low(lagrange_coeffs_query.to_ref())?
428 .into_evals();
429
430 Ok(folded_multilinear)
431 })
432 .collect::<Result<Vec<_>, _>>()?;
433
434 let composite_claims = izip!(compositions, claimed_sums)
435 .map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
436 .collect::<Vec<_>>();
437
438 fold_partial_eq_ind::<P, Backend>(
440 EvaluationOrder::HighToLow,
441 remaining_rounds,
442 &mut partial_eq_ind_evals,
443 );
444
445 let regular_prover = EqIndSumcheckProverBuilder::without_switchover(
450 remaining_rounds,
451 folded_multilinears,
452 self.backend,
453 )
454 .with_eq_ind_partial_evals(partial_eq_ind_evals)
455 .build(
456 EvaluationOrder::HighToLow,
457 &self.zerocheck_challenges,
458 composite_claims,
459 domain_factory,
460 )?;
461
462 self.state = ZerocheckProverState::Projection {
463 skip_rounds,
464 padded_multilinears,
465 };
466
467 Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
468 }
469
470 fn project_to_skipped_variables(
471 self: Box<Self>,
472 challenges: &[F],
473 ) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error> {
474 let ZerocheckProverState::Projection {
475 skip_rounds,
476 padded_multilinears,
477 } = self.state
478 else {
479 bail!(Error::ExpectedProjection);
480 };
481
482 let projection_n_vars = self.n_vars.saturating_sub(skip_rounds);
483 if challenges.len() < projection_n_vars {
484 bail!(Error::IncorrectNumberOfChallenges);
485 }
486
487 let packed_skipped_projections = if self.n_vars < skip_rounds {
488 padded_multilinears
489 .into_iter()
490 .map(|multilinear| {
491 multilinear
492 .expect_right("all multilinears are high-padded")
493 .upcast_arc_dyn()
494 })
495 .collect::<Vec<_>>()
496 } else {
497 let query = self
498 .backend
499 .multilinear_query(&challenges[challenges.len() - projection_n_vars..])?;
500 padded_multilinears
501 .par_iter()
502 .map(|multilinear| {
503 let projected_mle = self
504 .backend
505 .evaluate_partial_high(multilinear, query.to_ref())
506 .expect("sumcheck_challenges.len() >= n_vars - skip_rounds");
507
508 MLEDirectAdapter::from(projected_mle).upcast_arc_dyn()
509 })
510 .collect::<Vec<_>>()
511 };
512
513 Ok(packed_skipped_projections)
514 }
515}