binius_core/protocols/sumcheck/prove/
zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, mem, sync::Arc};
4
5use binius_field::{
6	ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, RepackedExtension,
7	TowerField,
8	packed::{copy_packed_from_scalars_slice, get_packed_slice, set_packed_slice},
9	util::powers,
10};
11use binius_hal::{ComputationBackend, ComputationBackendExt};
12use binius_math::{
13	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
14	MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{Either, izip};
20use tracing::instrument;
21
22use crate::{
23	polynomial::MultilinearComposite,
24	protocols::sumcheck::{
25		Error,
26		common::{CompositeSumClaim, equal_n_vars_check},
27		prove::{
28			SumcheckProver, ZerocheckProver,
29			common::fold_partial_eq_ind,
30			eq_ind::EqIndSumcheckProverBuilder,
31			univariate::{
32				ZerocheckUnivariateEvalsOutput, ZerocheckUnivariateFoldResult,
33				zerocheck_univariate_evals,
34			},
35		},
36		zerocheck::{ZerocheckRoundEvals, domain_size},
37	},
38};
39
40pub fn validate_witness<'a, F, P, M, Composition>(
41	multilinears: &[M],
42	zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
43) -> Result<(), Error>
44where
45	F: Field,
46	P: PackedField<Scalar = F>,
47	M: MultilinearPoly<P> + Send + Sync,
48	Composition: CompositionPoly<P> + 'a,
49{
50	let n_vars = multilinears
51		.first()
52		.map(|multilinear| multilinear.n_vars())
53		.unwrap_or_default();
54	for multilinear in multilinears {
55		if multilinear.n_vars() != n_vars {
56			bail!(Error::NumberOfVariablesMismatch);
57		}
58	}
59
60	let multilinears = multilinears.iter().collect::<Vec<_>>();
61
62	for (name, composition) in zero_claims {
63		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64		(0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
65			if witness.evaluate_on_hypercube(j)? != F::ZERO {
66				return Err(Error::ZerocheckNaiveValidationFailure {
67					composition_name: name.to_string(),
68					vertex_index: j,
69				});
70			}
71			Ok(())
72		})?;
73	}
74	Ok(())
75}
76
77// Pad a small field multilinear to at least `min_n_vars` variables. Padding is on the high indexed
78// variables via concatenating `2^(min_n_vars - n_vars)` copies of evals.
79pub fn high_pad_small_multilinear<PBase, P, M>(
80	min_n_vars: usize,
81	multilinear: M,
82) -> Either<M, MLEEmbeddingAdapter<PBase, P>>
83where
84	PBase: PackedField,
85	P: PackedField + RepackedExtension<PBase>,
86	M: MultilinearPoly<P>,
87{
88	let n_vars = multilinear.n_vars();
89	if n_vars >= min_n_vars {
90		return Either::Left(multilinear);
91	}
92
93	let mut padded_evals_base =
94		zeroed_vec::<PBase>(1 << min_n_vars.saturating_sub(PBase::LOG_WIDTH));
95
96	let log_embedding_degree = <P::Scalar as ExtensionField<PBase::Scalar>>::LOG_DEGREE;
97	let padded_evals = P::cast_exts_mut(&mut padded_evals_base);
98
99	multilinear
100		.subcube_evals(
101			n_vars,
102			0,
103			log_embedding_degree,
104			&mut padded_evals[..1 << n_vars.saturating_sub(PBase::LOG_WIDTH)],
105		)
106		.expect("copy evals verbatim into correctly sized array");
107
108	for repeat_idx in 0..1 << (min_n_vars - n_vars) {
109		for scalar_idx in 0..1 << n_vars {
110			let eval = get_packed_slice(&padded_evals_base, scalar_idx);
111			set_packed_slice(&mut padded_evals_base, scalar_idx | repeat_idx << n_vars, eval);
112		}
113	}
114
115	let padded_multilinear = MultilinearExtension::new(min_n_vars, padded_evals_base)
116		.expect("padded evals have correct size");
117
118	Either::Right(MLEEmbeddingAdapter::from(padded_multilinear))
119}
120
121/// Small-field aware zerocheck prover.
122///
123/// This is a state machine satisfying the contract of [ZerocheckProver](`super::ZerocheckProver`)
124/// trait. Object safety of the latter allows batching several zerocheck provers over different base
125/// fields together.
126///
127/// Full zerocheck reduction is laid out in the
128/// [batch_verify_zerocheck](super::super::batch_verify_zerocheck). This struct implements the
129/// univariate round, witness folding and prover construction for multilinear rounds, and witness
130/// projection for the univariatizing reduction.
131#[derive(Debug)]
132#[allow(clippy::type_complexity)]
133pub struct ZerocheckProverImpl<
134	'a,
135	FDomain,
136	FBase,
137	P,
138	CompositionBase,
139	Composition,
140	M,
141	DomainFactory,
142	Backend,
143> where
144	FDomain: Field,
145	FBase: Field,
146	P: PackedExtension<FBase>,
147	Backend: ComputationBackend,
148{
149	n_vars: usize,
150	zerocheck_challenges: Vec<P::Scalar>,
151	state: ZerocheckProverState<
152		Vec<M>,
153		Vec<Either<M, MLEEmbeddingAdapter<P::PackedSubfield, P>>>,
154		Vec<(String, CompositionBase, Composition)>,
155		ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>,
156		DomainFactory,
157	>,
158	backend: &'a Backend,
159	_p_base_marker: PhantomData<FBase>,
160	_fdomain_marker: PhantomData<FDomain>,
161}
162
163#[derive(Debug)]
164enum ZerocheckProverState<
165	Multilinears,
166	PaddedMultilinears,
167	Compositions,
168	EvalsOutput,
169	DomainFactory,
170> {
171	IllegalState,
172	RoundEval {
173		multilinears: Multilinears,
174		compositions: Compositions,
175		domain_factory: DomainFactory,
176	},
177	Folding {
178		skip_rounds: usize,
179		padded_multilinears: PaddedMultilinears,
180		compositions: Compositions,
181		domain_factory: DomainFactory,
182		univariate_evals_output: EvalsOutput,
183	},
184	Projection {
185		skip_rounds: usize,
186		padded_multilinears: PaddedMultilinears,
187	},
188}
189
190#[allow(clippy::derivable_impls)]
191impl<Multilinears, PaddedMultilinears, Compositions, EvalsOutput, DomainFactory> Default
192	for ZerocheckProverState<
193		Multilinears,
194		PaddedMultilinears,
195		Compositions,
196		EvalsOutput,
197		DomainFactory,
198	>
199{
200	fn default() -> Self {
201		// Default impl is used to allow mem::take on prover state
202		Self::IllegalState
203	}
204}
205
206impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
207	ZerocheckProverImpl<'a, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
208where
209	F: TowerField,
210	FDomain: Field,
211	FBase: ExtensionField<FDomain>,
212	P: PackedField<Scalar = F>
213		+ PackedExtension<F, PackedSubfield = P>
214		+ PackedExtension<FBase>
215		+ PackedExtension<FDomain>,
216	CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
217	Composition: CompositionPoly<P> + 'a,
218	M: MultilinearPoly<P> + Send + Sync + 'a,
219	DomainFactory: EvaluationDomainFactory<FDomain>,
220	Backend: ComputationBackend,
221{
222	pub fn new(
223		multilinears: Vec<M>,
224		zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
225		zerocheck_challenges: &[F],
226		domain_factory: DomainFactory,
227		backend: &'a Backend,
228	) -> Result<Self, Error> {
229		let n_vars = equal_n_vars_check(&multilinears)?;
230		let n_multilinears = multilinears.len();
231
232		let compositions = zero_claims.into_iter().collect::<Vec<_>>();
233		for (_, composition_base, composition) in &compositions {
234			if composition_base.n_vars() != n_multilinears
235				|| composition.n_vars() != n_multilinears
236				|| composition_base.degree() != composition.degree()
237			{
238				bail!(Error::InvalidComposition {
239					actual: composition.n_vars(),
240					expected: n_multilinears,
241				});
242			}
243		}
244		#[cfg(feature = "debug_validate_sumcheck")]
245		{
246			let compositions = compositions
247				.iter()
248				.map(|(name, _, a)| (name.clone(), a))
249				.collect::<Vec<_>>();
250			validate_witness(&multilinears, &compositions)?;
251		}
252
253		let zerocheck_challenges = zerocheck_challenges.to_vec();
254		let state = ZerocheckProverState::RoundEval {
255			multilinears,
256			compositions,
257			domain_factory,
258		};
259
260		Ok(Self {
261			n_vars,
262			zerocheck_challenges,
263			state,
264			backend,
265			_p_base_marker: PhantomData,
266			_fdomain_marker: PhantomData,
267		})
268	}
269}
270
271impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
272	ZerocheckProver<'a, P>
273	for ZerocheckProverImpl<
274		'a,
275		FDomain,
276		FBase,
277		P,
278		CompositionBase,
279		Composition,
280		M,
281		DomainFactory,
282		Backend,
283	>
284where
285	F: TowerField,
286	FDomain: TowerField,
287	FBase: ExtensionField<FDomain>,
288	P: PackedField<Scalar = F>
289		+ PackedExtension<F, PackedSubfield = P>
290		+ PackedExtension<FBase>
291		+ PackedExtension<FDomain>,
292	CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
293	Composition: CompositionPoly<P> + 'static,
294	M: MultilinearPoly<P> + Send + Sync + 'a,
295	DomainFactory: EvaluationDomainFactory<FDomain>,
296	Backend: ComputationBackend,
297{
298	fn n_vars(&self) -> usize {
299		self.n_vars
300	}
301
302	fn domain_size(&self, skip_rounds: usize) -> Option<usize> {
303		let ZerocheckProverState::RoundEval { compositions, .. } = &self.state else {
304			return None;
305		};
306
307		Some(
308			compositions
309				.iter()
310				.map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
311				.max()
312				.unwrap_or(0),
313		)
314	}
315
316	fn execute_univariate_round(
317		&mut self,
318		skip_rounds: usize,
319		max_domain_size: usize,
320		batch_coeff: F,
321	) -> Result<ZerocheckRoundEvals<F>, Error> {
322		let ZerocheckProverState::RoundEval {
323			multilinears,
324			compositions,
325			domain_factory,
326		} = mem::take(&mut self.state)
327		else {
328			bail!(Error::ExpectedExecution);
329		};
330
331		// High pad "short" multilinears to at least `skip_rounds` variables.
332		let padded_multilinears = multilinears
333			.into_iter()
334			.map(|multilinear| high_pad_small_multilinear(skip_rounds, multilinear))
335			.collect::<Vec<_>>();
336
337		// Only use base compositions in the univariate round (it's the whole point)
338		let compositions_base = compositions
339			.iter()
340			.map(|(_, composition_base, _)| composition_base)
341			.collect::<Vec<_>>();
342
343		// Output contains values that are needed for computations that happen after
344		// the round challenge has been sampled
345		let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
346			&padded_multilinears,
347			&compositions_base,
348			&self.zerocheck_challenges,
349			skip_rounds,
350			max_domain_size,
351			self.backend,
352		)?;
353
354		// Batch together Lagrange round evals using powers of batch_coeff
355		let batched_round_evals = univariate_evals_output
356			.round_evals
357			.iter()
358			.zip(powers(batch_coeff))
359			.map(|(evals, scalar)| {
360				ZerocheckRoundEvals {
361					evals: evals.clone(),
362				} * scalar
363			})
364			.try_fold(
365				ZerocheckRoundEvals::zeros(max_domain_size - (1 << skip_rounds)),
366				|mut accum, evals| -> Result<_, Error> {
367					accum.add_assign_lagrange(&evals)?;
368					Ok(accum)
369				},
370			)?;
371
372		self.state = ZerocheckProverState::Folding {
373			skip_rounds,
374			padded_multilinears,
375			compositions,
376			domain_factory,
377			univariate_evals_output,
378		};
379
380		Ok(batched_round_evals)
381	}
382
383	#[instrument(skip_all, level = "debug")]
384	fn fold_univariate_round(
385		&mut self,
386		challenge: F,
387	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
388		let ZerocheckProverState::Folding {
389			skip_rounds,
390			padded_multilinears,
391			compositions,
392			domain_factory,
393			univariate_evals_output,
394		} = mem::take(&mut self.state)
395		else {
396			bail!(Error::ExpectedFold);
397		};
398
399		// Once the challenge is known, values required for the instantiation of the
400		// multilinear prover for the remaining rounds become known.
401		let ZerocheckUnivariateFoldResult {
402			remaining_rounds,
403			subcube_lagrange_coeffs,
404			claimed_sums,
405			mut partial_eq_ind_evals,
406		} = univariate_evals_output.fold::<FDomain>(challenge)?;
407
408		// For each subcube of size 2**skip_rounds, we need to compute its
409		// inner product with Lagrange coefficients at challenge point in order
410		// to obtain the witness for the remaining multilinear rounds.
411		// REVIEW: Currently MultilinearPoly lacks a method to do that, so we
412		//         hack the needed functionality by overwriting the inner content
413		//         of a MultilinearQuery and performing an evaluate_partial_low,
414		//         which accidentally does what's needed. There should obviously
415		//         be a dedicated method for this someday.
416		let mut packed_subcube_lagrange_coeffs =
417			zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
418		copy_packed_from_scalars_slice(
419			&subcube_lagrange_coeffs[..1 << skip_rounds],
420			&mut packed_subcube_lagrange_coeffs,
421		);
422		let lagrange_coeffs_query =
423			MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
424
425		let folded_multilinears = padded_multilinears
426			.par_iter()
427			.map(|multilinear| -> Result<_, Error> {
428				let folded_multilinear = multilinear
429					.evaluate_partial_low(lagrange_coeffs_query.to_ref())?
430					.into_evals();
431
432				Ok(folded_multilinear)
433			})
434			.collect::<Result<Vec<_>, _>>()?;
435
436		let composite_claims = izip!(compositions, claimed_sums)
437			.map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
438			.collect::<Vec<_>>();
439
440		// Zerocheck tensor expansion for the reduced zerocheck should be one variable less
441		fold_partial_eq_ind::<P, Backend>(
442			EvaluationOrder::HighToLow,
443			remaining_rounds,
444			&mut partial_eq_ind_evals,
445		);
446
447		// The remaining non-univariate zerocheck rounds are an instance of EqIndSumcheck,
448		// due to the number of zerocheck challenges being equal to the number of remaining rounds.
449		// Note: while univariate round happens over lowest `skip_rounds` variables, the reduced
450		// EqIndSumcheck is high-to-low.
451		let regular_prover = EqIndSumcheckProverBuilder::without_switchover(
452			remaining_rounds,
453			folded_multilinears,
454			self.backend,
455		)
456		.with_eq_ind_partial_evals(partial_eq_ind_evals)
457		.build(
458			EvaluationOrder::HighToLow,
459			&self.zerocheck_challenges,
460			composite_claims,
461			domain_factory,
462		)?;
463
464		self.state = ZerocheckProverState::Projection {
465			skip_rounds,
466			padded_multilinears,
467		};
468
469		Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
470	}
471
472	fn project_to_skipped_variables(
473		self: Box<Self>,
474		challenges: &[F],
475	) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error> {
476		let ZerocheckProverState::Projection {
477			skip_rounds,
478			padded_multilinears,
479		} = self.state
480		else {
481			bail!(Error::ExpectedProjection);
482		};
483
484		let projection_n_vars = self.n_vars.saturating_sub(skip_rounds);
485		if challenges.len() < projection_n_vars {
486			bail!(Error::IncorrectNumberOfChallenges);
487		}
488
489		let packed_skipped_projections = if self.n_vars < skip_rounds {
490			padded_multilinears
491				.into_iter()
492				.map(|multilinear| {
493					multilinear
494						.expect_right("all multilinears are high-padded")
495						.upcast_arc_dyn()
496				})
497				.collect::<Vec<_>>()
498		} else {
499			let query = self
500				.backend
501				.multilinear_query(&challenges[challenges.len() - projection_n_vars..])?;
502			padded_multilinears
503				.par_iter()
504				.map(|multilinear| {
505					let projected_mle = self
506						.backend
507						.evaluate_partial_high(multilinear, query.to_ref())
508						.expect("sumcheck_challenges.len() >= n_vars - skip_rounds");
509
510					MLEDirectAdapter::from(projected_mle).upcast_arc_dyn()
511				})
512				.collect::<Vec<_>>()
513		};
514
515		Ok(packed_skipped_projections)
516	}
517}