1use std::{marker::PhantomData, mem, sync::Arc};
4
5use binius_field::{
6 ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, RepackedExtension,
7 TowerField,
8 packed::{copy_packed_from_scalars_slice, get_packed_slice, set_packed_slice},
9 util::powers,
10};
11use binius_hal::{ComputationBackend, ComputationBackendExt};
12use binius_math::{
13 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
14 MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{Either, izip};
20use tracing::instrument;
21
22use crate::{
23 polynomial::MultilinearComposite,
24 protocols::sumcheck::{
25 Error,
26 common::{CompositeSumClaim, equal_n_vars_check},
27 prove::{
28 SumcheckProver, ZerocheckProver,
29 common::fold_partial_eq_ind,
30 eq_ind::EqIndSumcheckProverBuilder,
31 univariate::{
32 ZerocheckUnivariateEvalsOutput, ZerocheckUnivariateFoldResult,
33 zerocheck_univariate_evals,
34 },
35 },
36 zerocheck::{ZerocheckRoundEvals, domain_size},
37 },
38};
39
40pub fn validate_witness<'a, F, P, M, Composition>(
41 multilinears: &[M],
42 zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
43) -> Result<(), Error>
44where
45 F: Field,
46 P: PackedField<Scalar = F>,
47 M: MultilinearPoly<P> + Send + Sync,
48 Composition: CompositionPoly<P> + 'a,
49{
50 let n_vars = multilinears
51 .first()
52 .map(|multilinear| multilinear.n_vars())
53 .unwrap_or_default();
54 for multilinear in multilinears {
55 if multilinear.n_vars() != n_vars {
56 bail!(Error::NumberOfVariablesMismatch);
57 }
58 }
59
60 let multilinears = multilinears.iter().collect::<Vec<_>>();
61
62 for (name, composition) in zero_claims {
63 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64 (0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
65 if witness.evaluate_on_hypercube(j)? != F::ZERO {
66 return Err(Error::ZerocheckNaiveValidationFailure {
67 composition_name: name.to_string(),
68 vertex_index: j,
69 });
70 }
71 Ok(())
72 })?;
73 }
74 Ok(())
75}
76
77pub fn high_pad_small_multilinear<PBase, P, M>(
80 min_n_vars: usize,
81 multilinear: M,
82) -> Either<M, MLEEmbeddingAdapter<PBase, P>>
83where
84 PBase: PackedField,
85 P: PackedField + RepackedExtension<PBase>,
86 M: MultilinearPoly<P>,
87{
88 let n_vars = multilinear.n_vars();
89 if n_vars >= min_n_vars {
90 return Either::Left(multilinear);
91 }
92
93 let mut padded_evals_base =
94 zeroed_vec::<PBase>(1 << min_n_vars.saturating_sub(PBase::LOG_WIDTH));
95
96 let log_embedding_degree = <P::Scalar as ExtensionField<PBase::Scalar>>::LOG_DEGREE;
97 let padded_evals = P::cast_exts_mut(&mut padded_evals_base);
98
99 multilinear
100 .subcube_evals(
101 n_vars,
102 0,
103 log_embedding_degree,
104 &mut padded_evals[..1 << n_vars.saturating_sub(PBase::LOG_WIDTH)],
105 )
106 .expect("copy evals verbatim into correctly sized array");
107
108 for repeat_idx in 0..1 << (min_n_vars - n_vars) {
109 for scalar_idx in 0..1 << n_vars {
110 let eval = get_packed_slice(&padded_evals_base, scalar_idx);
111 set_packed_slice(&mut padded_evals_base, scalar_idx | repeat_idx << n_vars, eval);
112 }
113 }
114
115 let padded_multilinear = MultilinearExtension::new(min_n_vars, padded_evals_base)
116 .expect("padded evals have correct size");
117
118 Either::Right(MLEEmbeddingAdapter::from(padded_multilinear))
119}
120
/// Zerocheck prover that handles the low `skip_rounds` variables with a single univariate
/// round and then delegates the remaining rounds to an eq-ind sumcheck prover.
///
/// Type parameters:
/// - `FDomain`: field of the evaluation domain produced by `DomainFactory`.
/// - `FBase`: subfield in which the base-field compositions are evaluated during the
///   univariate round.
/// - `P`: packed extension field of the witness multilinears.
/// - `CompositionBase` / `Composition`: paired base-field and extension-field versions of each
///   zero-claim composition (the constructor checks only that their arity and degree agree).
/// - `M`: witness multilinear type.
/// - `DomainFactory`: source of evaluation domains for the follow-up sumcheck.
/// - `Backend`: computation backend.
#[derive(Debug)]
// The nested `state` field type trips clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
pub struct ZerocheckProverImpl<
	'a,
	FDomain,
	FBase,
	P,
	CompositionBase,
	Composition,
	M,
	DomainFactory,
	Backend,
> where
	FDomain: Field,
	FBase: Field,
	P: PackedExtension<FBase>,
	Backend: ComputationBackend,
{
	/// Number of variables shared by all witness multilinears.
	n_vars: usize,
	/// Zerocheck challenges, copied from the constructor argument.
	zerocheck_challenges: Vec<P::Scalar>,
	/// Current protocol phase; see [`ZerocheckProverState`].
	state: ZerocheckProverState<
		Vec<M>,
		Vec<Either<M, MLEEmbeddingAdapter<P::PackedSubfield, P>>>,
		Vec<(String, CompositionBase, Composition)>,
		ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>,
		DomainFactory,
	>,
	/// Backend used for univariate evaluation, the follow-up sumcheck and projections.
	backend: &'a Backend,
	// Type-level markers for field parameters that are not stored directly.
	_p_base_marker: PhantomData<FBase>,
	_fdomain_marker: PhantomData<FDomain>,
}
162
/// Phases of [`ZerocheckProverImpl`]; each phase owns exactly the data the next step needs.
///
/// Transitions: `RoundEval` -> `Folding` (univariate round executed) -> `Projection`
/// (univariate round folded).
#[derive(Debug)]
enum ZerocheckProverState<
	Multilinears,
	PaddedMultilinears,
	Compositions,
	EvalsOutput,
	DomainFactory,
> {
	/// Placeholder stored while the real state is moved out with `mem::take`. It is also what
	/// remains if a phase transition returns an error before storing the next state.
	IllegalState,
	/// Initial phase, before the univariate round has been executed.
	RoundEval {
		multilinears: Multilinears,
		compositions: Compositions,
		domain_factory: DomainFactory,
	},
	/// Univariate round executed; waiting for the challenge to fold it.
	Folding {
		skip_rounds: usize,
		/// Witness multilinears high-padded to at least `skip_rounds` variables.
		padded_multilinears: PaddedMultilinears,
		compositions: Compositions,
		domain_factory: DomainFactory,
		univariate_evals_output: EvalsOutput,
	},
	/// Univariate round folded; only the data needed for the final projection is kept.
	Projection {
		skip_rounds: usize,
		padded_multilinears: PaddedMultilinears,
	},
}
189
190#[allow(clippy::derivable_impls)]
191impl<Multilinears, PaddedMultilinears, Compositions, EvalsOutput, DomainFactory> Default
192 for ZerocheckProverState<
193 Multilinears,
194 PaddedMultilinears,
195 Compositions,
196 EvalsOutput,
197 DomainFactory,
198 >
199{
200 fn default() -> Self {
201 Self::IllegalState
203 }
204}
205
206impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
207 ZerocheckProverImpl<'a, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
208where
209 F: TowerField,
210 FDomain: Field,
211 FBase: ExtensionField<FDomain>,
212 P: PackedField<Scalar = F>
213 + PackedExtension<F, PackedSubfield = P>
214 + PackedExtension<FBase>
215 + PackedExtension<FDomain>,
216 CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
217 Composition: CompositionPoly<P> + 'a,
218 M: MultilinearPoly<P> + Send + Sync + 'a,
219 DomainFactory: EvaluationDomainFactory<FDomain>,
220 Backend: ComputationBackend,
221{
222 pub fn new(
223 multilinears: Vec<M>,
224 zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
225 zerocheck_challenges: &[F],
226 domain_factory: DomainFactory,
227 backend: &'a Backend,
228 ) -> Result<Self, Error> {
229 let n_vars = equal_n_vars_check(&multilinears)?;
230 let n_multilinears = multilinears.len();
231
232 let compositions = zero_claims.into_iter().collect::<Vec<_>>();
233 for (_, composition_base, composition) in &compositions {
234 if composition_base.n_vars() != n_multilinears
235 || composition.n_vars() != n_multilinears
236 || composition_base.degree() != composition.degree()
237 {
238 bail!(Error::InvalidComposition {
239 actual: composition.n_vars(),
240 expected: n_multilinears,
241 });
242 }
243 }
244 #[cfg(feature = "debug_validate_sumcheck")]
245 {
246 let compositions = compositions
247 .iter()
248 .map(|(name, _, a)| (name.clone(), a))
249 .collect::<Vec<_>>();
250 validate_witness(&multilinears, &compositions)?;
251 }
252
253 let zerocheck_challenges = zerocheck_challenges.to_vec();
254 let state = ZerocheckProverState::RoundEval {
255 multilinears,
256 compositions,
257 domain_factory,
258 };
259
260 Ok(Self {
261 n_vars,
262 zerocheck_challenges,
263 state,
264 backend,
265 _p_base_marker: PhantomData,
266 _fdomain_marker: PhantomData,
267 })
268 }
269}
270
// Protocol flow: `execute_univariate_round` evaluates the low `skip_rounds` variables as one
// univariate round; `fold_univariate_round` folds them at the verifier challenge and returns an
// eq-ind sumcheck prover for the remaining variables; `project_to_skipped_variables` finally
// partially evaluates the high variables, leaving multilinears over the skipped variables.
impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
	ZerocheckProver<'a, P>
	for ZerocheckProverImpl<
		'a,
		FDomain,
		FBase,
		P,
		CompositionBase,
		Composition,
		M,
		DomainFactory,
		Backend,
	>
where
	F: TowerField,
	FDomain: TowerField,
	FBase: ExtensionField<FDomain>,
	P: PackedField<Scalar = F>
		+ PackedExtension<F, PackedSubfield = P>
		+ PackedExtension<FBase>
		+ PackedExtension<FDomain>,
	CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
	Composition: CompositionPoly<P> + 'static,
	M: MultilinearPoly<P> + Send + Sync + 'a,
	DomainFactory: EvaluationDomainFactory<FDomain>,
	Backend: ComputationBackend,
{
	/// Number of variables shared by all witness multilinears.
	fn n_vars(&self) -> usize {
		self.n_vars
	}

	/// Univariate-round domain size required for `skip_rounds`: the maximum of
	/// `domain_size(degree, skip_rounds)` over all compositions, or `0` if there are none.
	///
	/// Returns `None` once the univariate round has been executed (state is no longer
	/// `RoundEval`).
	fn domain_size(&self, skip_rounds: usize) -> Option<usize> {
		let ZerocheckProverState::RoundEval { compositions, .. } = &self.state else {
			return None;
		};

		Some(
			compositions
				.iter()
				// This binds the *base-field* composition; `new` checked that its degree
				// equals the extension-field composition's degree.
				.map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
				.max()
				.unwrap_or(0),
		)
	}

	/// Executes the univariate round over the low `skip_rounds` variables and returns the
	/// round evaluations of all claims batched with powers of `batch_coeff`.
	///
	/// Transitions `RoundEval` -> `Folding`.
	///
	/// # Errors
	///
	/// [`Error::ExpectedExecution`] if not in the `RoundEval` phase, plus any error from the
	/// univariate evaluation or batching. The state is moved out with `mem::take` up front, so
	/// after any error the prover is left in `IllegalState`.
	fn execute_univariate_round(
		&mut self,
		skip_rounds: usize,
		max_domain_size: usize,
		batch_coeff: F,
	) -> Result<ZerocheckRoundEvals<F>, Error> {
		let ZerocheckProverState::RoundEval {
			multilinears,
			compositions,
			domain_factory,
		} = mem::take(&mut self.state)
		else {
			bail!(Error::ExpectedExecution);
		};

		// Multilinears with fewer than `skip_rounds` variables are high-padded so each one
		// covers the whole univariate domain.
		let padded_multilinears = multilinears
			.into_iter()
			.map(|multilinear| high_pad_small_multilinear(skip_rounds, multilinear))
			.collect::<Vec<_>>();

		// The univariate round evaluates the base-field versions of the compositions.
		let compositions_base = compositions
			.iter()
			.map(|(_, composition_base, _)| composition_base)
			.collect::<Vec<_>>();

		let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
			&padded_multilinears,
			&compositions_base,
			&self.zerocheck_challenges,
			skip_rounds,
			max_domain_size,
			self.backend,
		)?;

		// Batch the per-claim round evals as sum_i batch_coeff^i * evals_i.
		// NOTE(review): the accumulator length `max_domain_size - (1 << skip_rounds)` underflows
		// if `max_domain_size < 1 << skip_rounds`; presumably `zerocheck_univariate_evals`
		// rejects such sizes first — confirm.
		let batched_round_evals = univariate_evals_output
			.round_evals
			.iter()
			.zip(powers(batch_coeff))
			.map(|(evals, scalar)| {
				ZerocheckRoundEvals {
					evals: evals.clone(),
				} * scalar
			})
			.try_fold(
				ZerocheckRoundEvals::zeros(max_domain_size - (1 << skip_rounds)),
				|mut accum, evals| -> Result<_, Error> {
					accum.add_assign_lagrange(&evals)?;
					Ok(accum)
				},
			)?;

		self.state = ZerocheckProverState::Folding {
			skip_rounds,
			padded_multilinears,
			compositions,
			domain_factory,
			univariate_evals_output,
		};

		Ok(batched_round_evals)
	}

	/// Folds the univariate round at `challenge` and returns a regular eq-ind sumcheck prover
	/// for the remaining rounds.
	///
	/// Transitions `Folding` -> `Projection`.
	///
	/// # Errors
	///
	/// [`Error::ExpectedFold`] if not in the `Folding` phase, plus any error from folding or
	/// from building the follow-up prover. After an error the prover is left in `IllegalState`.
	#[instrument(skip_all, level = "debug")]
	fn fold_univariate_round(
		&mut self,
		challenge: F,
	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
		let ZerocheckProverState::Folding {
			skip_rounds,
			padded_multilinears,
			compositions,
			domain_factory,
			univariate_evals_output,
		} = mem::take(&mut self.state)
		else {
			bail!(Error::ExpectedFold);
		};

		let ZerocheckUnivariateFoldResult {
			remaining_rounds,
			subcube_lagrange_coeffs,
			claimed_sums,
			mut partial_eq_ind_evals,
		} = univariate_evals_output.fold::<FDomain>(challenge)?;

		// Pack the first 2^skip_rounds Lagrange coefficients and use them directly as an
		// already-expanded query over `skip_rounds` variables.
		let mut packed_subcube_lagrange_coeffs =
			zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
		copy_packed_from_scalars_slice(
			&subcube_lagrange_coeffs[..1 << skip_rounds],
			&mut packed_subcube_lagrange_coeffs,
		);
		let lagrange_coeffs_query =
			MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;

		// Fold away the low `skip_rounds` variables of every padded multilinear.
		let folded_multilinears = padded_multilinears
			.par_iter()
			.map(|multilinear| -> Result<_, Error> {
				let folded_multilinear = multilinear
					.evaluate_partial_low(lagrange_coeffs_query.to_ref())?
					.into_evals();

				Ok(folded_multilinear)
			})
			.collect::<Result<Vec<_>, _>>()?;

		// Pair each extension-field composition with its claimed sum (izip stops at the
		// shorter input).
		let composite_claims = izip!(compositions, claimed_sums)
			.map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
			.collect::<Vec<_>>();

		fold_partial_eq_ind::<P, Backend>(
			EvaluationOrder::HighToLow,
			remaining_rounds,
			&mut partial_eq_ind_evals,
		);

		let regular_prover = EqIndSumcheckProverBuilder::without_switchover(
			remaining_rounds,
			folded_multilinears,
			self.backend,
		)
		.with_eq_ind_partial_evals(partial_eq_ind_evals)
		.build(
			EvaluationOrder::HighToLow,
			&self.zerocheck_challenges,
			composite_claims,
			domain_factory,
		)?;

		// Keep only what the final projection needs.
		self.state = ZerocheckProverState::Projection {
			skip_rounds,
			padded_multilinears,
		};

		Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
	}

	/// Projects each witness multilinear onto the skipped (low) variables by partially
	/// evaluating its high `n_vars - skip_rounds` variables at the last `n_vars - skip_rounds`
	/// entries of `challenges`.
	///
	/// # Errors
	///
	/// - [`Error::ExpectedProjection`] if not in the `Projection` phase.
	/// - [`Error::IncorrectNumberOfChallenges`] if fewer than `n_vars - skip_rounds` challenges
	///   are supplied.
	/// - Any error from building the multilinear query.
	fn project_to_skipped_variables(
		self: Box<Self>,
		challenges: &[F],
	) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error> {
		let ZerocheckProverState::Projection {
			skip_rounds,
			padded_multilinears,
		} = self.state
		else {
			bail!(Error::ExpectedProjection);
		};

		let projection_n_vars = self.n_vars.saturating_sub(skip_rounds);
		if challenges.len() < projection_n_vars {
			bail!(Error::IncorrectNumberOfChallenges);
		}

		let packed_skipped_projections = if self.n_vars < skip_rounds {
			// Every multilinear had fewer than `skip_rounds` variables, so
			// `high_pad_small_multilinear` returned the padded `Right` variant; there are no
			// high variables left to project.
			padded_multilinears
				.into_iter()
				.map(|multilinear| {
					multilinear
						.expect_right("all multilinears are high-padded")
						.upcast_arc_dyn()
				})
				.collect::<Vec<_>>()
		} else {
			let query = self
				.backend
				.multilinear_query(&challenges[challenges.len() - projection_n_vars..])?;
			padded_multilinears
				.par_iter()
				.map(|multilinear| {
					// Cannot fail on length: the challenge count was checked above.
					let projected_mle = self
						.backend
						.evaluate_partial_high(multilinear, query.to_ref())
						.expect("sumcheck_challenges.len() >= n_vars - skip_rounds");

					MLEDirectAdapter::from(projected_mle).upcast_arc_dyn()
				})
				.collect::<Vec<_>>()
		};

		Ok(packed_skipped_projections)
	}
}