binius_core/protocols/sumcheck/prove/
zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, mem, sync::Arc};
4
5use binius_field::{
6	packed::{copy_packed_from_scalars_slice, get_packed_slice, set_packed_slice},
7	util::powers,
8	ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, RepackedExtension,
9	TowerField,
10};
11use binius_hal::{ComputationBackend, ComputationBackendExt};
12use binius_math::{
13	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter,
14	MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_utils::bail;
18use bytemuck::zeroed_vec;
19use itertools::{izip, Either};
20use tracing::instrument;
21
22use crate::{
23	polynomial::MultilinearComposite,
24	protocols::sumcheck::{
25		common::{equal_n_vars_check, CompositeSumClaim},
26		prove::{
27			common::fold_partial_eq_ind,
28			eq_ind::EqIndSumcheckProverBuilder,
29			univariate::{
30				zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
31				ZerocheckUnivariateFoldResult,
32			},
33			SumcheckProver, ZerocheckProver,
34		},
35		zerocheck::{domain_size, ZerocheckRoundEvals},
36		Error,
37	},
38};
39
40pub fn validate_witness<'a, F, P, M, Composition>(
41	multilinears: &[M],
42	zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
43) -> Result<(), Error>
44where
45	F: Field,
46	P: PackedField<Scalar = F>,
47	M: MultilinearPoly<P> + Send + Sync,
48	Composition: CompositionPoly<P> + 'a,
49{
50	let n_vars = multilinears
51		.first()
52		.map(|multilinear| multilinear.n_vars())
53		.unwrap_or_default();
54	for multilinear in multilinears {
55		if multilinear.n_vars() != n_vars {
56			bail!(Error::NumberOfVariablesMismatch);
57		}
58	}
59
60	let multilinears = multilinears.iter().collect::<Vec<_>>();
61
62	for (name, composition) in zero_claims {
63		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64		(0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
65			if witness.evaluate_on_hypercube(j)? != F::ZERO {
66				return Err(Error::ZerocheckNaiveValidationFailure {
67					composition_name: name.to_string(),
68					vertex_index: j,
69				});
70			}
71			Ok(())
72		})?;
73	}
74	Ok(())
75}
76
77// Pad a small field multilinear to at least `min_n_vars` variables. Padding is on the high indexed variables
78// via concatenating `2^(min_n_vars - n_vars)` copies of evals.
79pub fn high_pad_small_multilinear<PBase, P, M>(
80	min_n_vars: usize,
81	multilinear: M,
82) -> Either<M, MLEEmbeddingAdapter<PBase, P>>
83where
84	PBase: PackedField,
85	P: PackedField + RepackedExtension<PBase>,
86	M: MultilinearPoly<P>,
87{
88	let n_vars = multilinear.n_vars();
89	if n_vars >= min_n_vars {
90		return Either::Left(multilinear);
91	}
92
93	let mut padded_evals_base =
94		zeroed_vec::<PBase>(1 << min_n_vars.saturating_sub(PBase::LOG_WIDTH));
95
96	let log_embedding_degree = <P::Scalar as ExtensionField<PBase::Scalar>>::LOG_DEGREE;
97	let padded_evals = P::cast_exts_mut(&mut padded_evals_base);
98
99	multilinear
100		.subcube_evals(
101			n_vars,
102			0,
103			log_embedding_degree,
104			&mut padded_evals[..1 << n_vars.saturating_sub(PBase::LOG_WIDTH)],
105		)
106		.expect("copy evals verbatim into correctly sized array");
107
108	for repeat_idx in 0..1 << (min_n_vars - n_vars) {
109		for scalar_idx in 0..1 << n_vars {
110			let eval = get_packed_slice(&padded_evals_base, scalar_idx);
111			set_packed_slice(&mut padded_evals_base, scalar_idx | repeat_idx << n_vars, eval);
112		}
113	}
114
115	let padded_multilinear = MultilinearExtension::new(min_n_vars, padded_evals_base)
116		.expect("padded evals have correct size");
117
118	Either::Right(MLEEmbeddingAdapter::from(padded_multilinear))
119}
120
/// Small-field aware zerocheck prover.
///
/// This is a state machine satisfying the contract of the
/// [`ZerocheckProver`](super::ZerocheckProver) trait. Object safety of the latter allows batching
/// several zerocheck provers over different base fields together.
///
/// Full zerocheck reduction is laid out in the [batch_verify_zerocheck](super::super::batch_verify_zerocheck).
/// This struct implements the univariate round, witness folding and prover construction for multilinear rounds,
/// and witness projection for the univariatizing reduction.
///
/// The prover moves through the states of `ZerocheckProverState` in a fixed order:
/// `RoundEval` (after `new`), then `Folding` (after `execute_univariate_round`), then
/// `Projection` (after `fold_univariate_round`). Finally `project_to_skipped_variables`
/// consumes the prover.
#[derive(Debug)]
// The nested generic field types (notably `state`) trip clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
pub struct ZerocheckProverImpl<
	'a,
	FDomain,
	FBase,
	P,
	CompositionBase,
	Composition,
	M,
	DomainFactory,
	Backend,
> where
	FDomain: Field,
	FBase: Field,
	P: PackedExtension<FBase>,
	Backend: ComputationBackend,
{
	/// Number of variables shared by every witness multilinear (checked in `new`).
	n_vars: usize,
	/// Zerocheck challenges copied in `new`. They are passed to the univariate round and to the
	/// `EqIndSumcheck` prover built for the remaining rounds.
	zerocheck_challenges: Vec<P::Scalar>,
	/// Current step of the prover state machine.
	state: ZerocheckProverState<
		Vec<M>,
		Vec<Either<M, MLEEmbeddingAdapter<P::PackedSubfield, P>>>,
		Vec<(String, CompositionBase, Composition)>,
		ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>,
		DomainFactory,
	>,
	/// Backend used for the univariate evaluations, the reduced prover and the projections.
	backend: &'a Backend,
	// Ties the otherwise unused field type parameters to the struct.
	_p_base_marker: PhantomData<FBase>,
	_fdomain_marker: PhantomData<FDomain>,
}
160
/// Steps of the `ZerocheckProverImpl` state machine.
///
/// Each transition method of the prover expects one specific variant, moves it out with
/// `mem::take`, and stores the next variant when it succeeds.
#[derive(Debug)]
enum ZerocheckProverState<
	Multilinears,
	PaddedMultilinears,
	Compositions,
	EvalsOutput,
	DomainFactory,
> {
	/// Placeholder left behind while the real state has been moved out with `mem::take`. It is
	/// also the `Default` value of this type.
	IllegalState,
	/// Initial state set by `new`: the unpadded witness and the zero-claim compositions, waiting
	/// for `execute_univariate_round`.
	RoundEval {
		multilinears: Multilinears,
		compositions: Compositions,
		domain_factory: DomainFactory,
	},
	/// State after the univariate round over the lowest `skip_rounds` variables. Multilinears
	/// are high-padded to at least `skip_rounds` variables. The univariate evaluation output is
	/// kept until the round challenge arrives in `fold_univariate_round`.
	Folding {
		skip_rounds: usize,
		padded_multilinears: PaddedMultilinears,
		compositions: Compositions,
		domain_factory: DomainFactory,
		univariate_evals_output: EvalsOutput,
	},
	/// State after folding. The padded multilinears are retained for
	/// `project_to_skipped_variables`.
	Projection {
		skip_rounds: usize,
		padded_multilinears: PaddedMultilinears,
	},
}
187
188#[allow(clippy::derivable_impls)]
189impl<Multilinears, PaddedMultilinears, Compositions, EvalsOutput, DomainFactory> Default
190	for ZerocheckProverState<
191		Multilinears,
192		PaddedMultilinears,
193		Compositions,
194		EvalsOutput,
195		DomainFactory,
196	>
197{
198	fn default() -> Self {
199		// Default impl is used to allow mem::take on prover state
200		Self::IllegalState
201	}
202}
203
204impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
205	ZerocheckProverImpl<'a, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
206where
207	F: TowerField,
208	FDomain: Field,
209	FBase: ExtensionField<FDomain>,
210	P: PackedField<Scalar = F>
211		+ PackedExtension<F, PackedSubfield = P>
212		+ PackedExtension<FBase>
213		+ PackedExtension<FDomain>,
214	CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
215	Composition: CompositionPoly<P> + 'a,
216	M: MultilinearPoly<P> + Send + Sync + 'a,
217	DomainFactory: EvaluationDomainFactory<FDomain>,
218	Backend: ComputationBackend,
219{
220	pub fn new(
221		multilinears: Vec<M>,
222		zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
223		zerocheck_challenges: &[F],
224		domain_factory: DomainFactory,
225		backend: &'a Backend,
226	) -> Result<Self, Error> {
227		let n_vars = equal_n_vars_check(&multilinears)?;
228		let n_multilinears = multilinears.len();
229
230		let compositions = zero_claims.into_iter().collect::<Vec<_>>();
231		for (_, composition_base, composition) in &compositions {
232			if composition_base.n_vars() != n_multilinears
233				|| composition.n_vars() != n_multilinears
234				|| composition_base.degree() != composition.degree()
235			{
236				bail!(Error::InvalidComposition {
237					actual: composition.n_vars(),
238					expected: n_multilinears,
239				});
240			}
241		}
242		#[cfg(feature = "debug_validate_sumcheck")]
243		{
244			let compositions = compositions
245				.iter()
246				.map(|(name, _, a)| (name.clone(), a))
247				.collect::<Vec<_>>();
248			validate_witness(&multilinears, &compositions)?;
249		}
250
251		let zerocheck_challenges = zerocheck_challenges.to_vec();
252		let state = ZerocheckProverState::RoundEval {
253			multilinears,
254			compositions,
255			domain_factory,
256		};
257
258		Ok(Self {
259			n_vars,
260			zerocheck_challenges,
261			state,
262			backend,
263			_p_base_marker: PhantomData,
264			_fdomain_marker: PhantomData,
265		})
266	}
267}
268
269impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, DomainFactory, Backend>
270	ZerocheckProver<'a, P>
271	for ZerocheckProverImpl<
272		'a,
273		FDomain,
274		FBase,
275		P,
276		CompositionBase,
277		Composition,
278		M,
279		DomainFactory,
280		Backend,
281	>
282where
283	F: TowerField,
284	FDomain: TowerField,
285	FBase: ExtensionField<FDomain>,
286	P: PackedField<Scalar = F>
287		+ PackedExtension<F, PackedSubfield = P>
288		+ PackedExtension<FBase>
289		+ PackedExtension<FDomain>,
290	CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
291	Composition: CompositionPoly<P> + 'static,
292	M: MultilinearPoly<P> + Send + Sync + 'a,
293	DomainFactory: EvaluationDomainFactory<FDomain>,
294	Backend: ComputationBackend,
295{
296	fn n_vars(&self) -> usize {
297		self.n_vars
298	}
299
300	fn domain_size(&self, skip_rounds: usize) -> Option<usize> {
301		let ZerocheckProverState::RoundEval { compositions, .. } = &self.state else {
302			return None;
303		};
304
305		Some(
306			compositions
307				.iter()
308				.map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
309				.max()
310				.unwrap_or(0),
311		)
312	}
313
314	fn execute_univariate_round(
315		&mut self,
316		skip_rounds: usize,
317		max_domain_size: usize,
318		batch_coeff: F,
319	) -> Result<ZerocheckRoundEvals<F>, Error> {
320		let ZerocheckProverState::RoundEval {
321			multilinears,
322			compositions,
323			domain_factory,
324		} = mem::take(&mut self.state)
325		else {
326			bail!(Error::ExpectedExecution);
327		};
328
329		// High pad "short" multilinears to at least `skip_rounds` variables.
330		let padded_multilinears = multilinears
331			.into_iter()
332			.map(|multilinear| high_pad_small_multilinear(skip_rounds, multilinear))
333			.collect::<Vec<_>>();
334
335		// Only use base compositions in the univariate round (it's the whole point)
336		let compositions_base = compositions
337			.iter()
338			.map(|(_, composition_base, _)| composition_base)
339			.collect::<Vec<_>>();
340
341		// Output contains values that are needed for computations that happen after
342		// the round challenge has been sampled
343		let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
344			&padded_multilinears,
345			&compositions_base,
346			&self.zerocheck_challenges,
347			skip_rounds,
348			max_domain_size,
349			self.backend,
350		)?;
351
352		// Batch together Lagrange round evals using powers of batch_coeff
353		let batched_round_evals = univariate_evals_output
354			.round_evals
355			.iter()
356			.zip(powers(batch_coeff))
357			.map(|(evals, scalar)| {
358				ZerocheckRoundEvals {
359					evals: evals.clone(),
360				} * scalar
361			})
362			.try_fold(
363				ZerocheckRoundEvals::zeros(max_domain_size - (1 << skip_rounds)),
364				|mut accum, evals| -> Result<_, Error> {
365					accum.add_assign_lagrange(&evals)?;
366					Ok(accum)
367				},
368			)?;
369
370		self.state = ZerocheckProverState::Folding {
371			skip_rounds,
372			padded_multilinears,
373			compositions,
374			domain_factory,
375			univariate_evals_output,
376		};
377
378		Ok(batched_round_evals)
379	}
380
381	#[instrument(skip_all, level = "debug")]
382	fn fold_univariate_round(
383		&mut self,
384		challenge: F,
385	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
386		let ZerocheckProverState::Folding {
387			skip_rounds,
388			padded_multilinears,
389			compositions,
390			domain_factory,
391			univariate_evals_output,
392		} = mem::take(&mut self.state)
393		else {
394			bail!(Error::ExpectedFold);
395		};
396
397		// Once the challenge is known, values required for the instantiation of the
398		// multilinear prover for the remaining rounds become known.
399		let ZerocheckUnivariateFoldResult {
400			remaining_rounds,
401			subcube_lagrange_coeffs,
402			claimed_sums,
403			mut partial_eq_ind_evals,
404		} = univariate_evals_output.fold::<FDomain>(challenge)?;
405
406		// For each subcube of size 2**skip_rounds, we need to compute its
407		// inner product with Lagrange coefficients at challenge point in order
408		// to obtain the witness for the remaining multilinear rounds.
409		// REVIEW: Currently MultilinearPoly lacks a method to do that, so we
410		//         hack the needed functionality by overwriting the inner content
411		//         of a MultilinearQuery and performing an evaluate_partial_low,
412		//         which accidentally does what's needed. There should obviously
413		//         be a dedicated method for this someday.
414		let mut packed_subcube_lagrange_coeffs =
415			zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
416		copy_packed_from_scalars_slice(
417			&subcube_lagrange_coeffs[..1 << skip_rounds],
418			&mut packed_subcube_lagrange_coeffs,
419		);
420		let lagrange_coeffs_query =
421			MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
422
423		let folded_multilinears = padded_multilinears
424			.par_iter()
425			.map(|multilinear| -> Result<_, Error> {
426				let folded_multilinear = multilinear
427					.evaluate_partial_low(lagrange_coeffs_query.to_ref())?
428					.into_evals();
429
430				Ok(folded_multilinear)
431			})
432			.collect::<Result<Vec<_>, _>>()?;
433
434		let composite_claims = izip!(compositions, claimed_sums)
435			.map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
436			.collect::<Vec<_>>();
437
438		// Zerocheck tensor expansion for the reduced zerocheck should be one variable less
439		fold_partial_eq_ind::<P, Backend>(
440			EvaluationOrder::HighToLow,
441			remaining_rounds,
442			&mut partial_eq_ind_evals,
443		);
444
445		// The remaining non-univariate zerocheck rounds are an instance of EqIndSumcheck,
446		// due to the number of zerocheck challenges being equal to the number of remaining rounds.
447		// Note: while univariate round happens over lowest `skip_rounds` variables, the reduced
448		// EqIndSumcheck is high-to-low.
449		let regular_prover = EqIndSumcheckProverBuilder::without_switchover(
450			remaining_rounds,
451			folded_multilinears,
452			self.backend,
453		)
454		.with_eq_ind_partial_evals(partial_eq_ind_evals)
455		.build(
456			EvaluationOrder::HighToLow,
457			&self.zerocheck_challenges,
458			composite_claims,
459			domain_factory,
460		)?;
461
462		self.state = ZerocheckProverState::Projection {
463			skip_rounds,
464			padded_multilinears,
465		};
466
467		Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
468	}
469
470	fn project_to_skipped_variables(
471		self: Box<Self>,
472		challenges: &[F],
473	) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error> {
474		let ZerocheckProverState::Projection {
475			skip_rounds,
476			padded_multilinears,
477		} = self.state
478		else {
479			bail!(Error::ExpectedProjection);
480		};
481
482		let projection_n_vars = self.n_vars.saturating_sub(skip_rounds);
483		if challenges.len() < projection_n_vars {
484			bail!(Error::IncorrectNumberOfChallenges);
485		}
486
487		let packed_skipped_projections = if self.n_vars < skip_rounds {
488			padded_multilinears
489				.into_iter()
490				.map(|multilinear| {
491					multilinear
492						.expect_right("all multilinears are high-padded")
493						.upcast_arc_dyn()
494				})
495				.collect::<Vec<_>>()
496		} else {
497			let query = self
498				.backend
499				.multilinear_query(&challenges[challenges.len() - projection_n_vars..])?;
500			padded_multilinears
501				.par_iter()
502				.map(|multilinear| {
503					let projected_mle = self
504						.backend
505						.evaluate_partial_high(multilinear, query.to_ref())
506						.expect("sumcheck_challenges.len() >= n_vars - skip_rounds");
507
508					MLEDirectAdapter::from(projected_mle).upcast_arc_dyn()
509				})
510				.collect::<Vec<_>>()
511		};
512
513		Ok(packed_skipped_projections)
514	}
515}