binius_core/protocols/sumcheck/prove/
zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, sync::Arc};
4
5use binius_field::{
6	util::powers, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
7	PackedSubfield, TowerField,
8};
9use binius_hal::ComputationBackend;
10use binius_math::{
11	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearPoly,
12	MultilinearQuery,
13};
14use binius_maybe_rayon::prelude::*;
15use binius_utils::bail;
16use bytemuck::zeroed_vec;
17use getset::Getters;
18use itertools::izip;
19use tracing::instrument;
20
21use crate::{
22	polynomial::MultilinearComposite,
23	protocols::sumcheck::{
24		common::{equal_n_vars_check, CompositeSumClaim},
25		prove::{
26			eq_ind::EqIndSumcheckProverBuilder,
27			univariate::{
28				zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
29				ZerocheckUnivariateFoldResult,
30			},
31			SumcheckProver, UnivariateZerocheckProver,
32		},
33		univariate::LagrangeRoundEvals,
34		univariate_zerocheck::domain_size,
35		Error,
36	},
37	witness::MultilinearWitness,
38};
39
40pub fn validate_witness<'a, F, P, M, Composition>(
41	multilinears: &[M],
42	zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
43) -> Result<(), Error>
44where
45	F: Field,
46	P: PackedField<Scalar = F>,
47	M: MultilinearPoly<P> + Send + Sync,
48	Composition: CompositionPoly<P> + 'a,
49{
50	let n_vars = multilinears
51		.first()
52		.map(|multilinear| multilinear.n_vars())
53		.unwrap_or_default();
54	for multilinear in multilinears {
55		if multilinear.n_vars() != n_vars {
56			bail!(Error::NumberOfVariablesMismatch);
57		}
58	}
59
60	let multilinears = multilinears.iter().collect::<Vec<_>>();
61
62	for (name, composition) in zero_claims {
63		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64		(0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
65			if witness.evaluate_on_hypercube(j)? != F::ZERO {
66				return Err(Error::ZerocheckNaiveValidationFailure {
67					composition_name: name.to_string(),
68					vertex_index: j,
69				});
70			}
71			Ok(())
72		})?;
73	}
74	Ok(())
75}
76
77/// A prover that is capable of performing univariate skip.
78///
79/// By recasting `skip_rounds` first variables in a multilinear sumcheck into a univariate domain,
80/// it becomes possible to compute all of these rounds in small fields, unlocking significant
81/// performance gains. See [`zerocheck_univariate_evals`] rustdoc for a more detailed explanation.
82///
83/// This struct is an entrypoint to proving all zerochecks instances, univariatized and regular.
84/// "Regular" multilinear case is covered by calling [`Self::into_regular_zerocheck`] right away,
85/// producing a `EqIndSumcheckProver`. Univariatized case is handled by using methods from a
86/// [`UnivariateZerocheckProver`] trait, where folding results in a reduced multilinear zerocheck
87/// prover for the remaining rounds.
88#[derive(Debug, Getters)]
89pub struct UnivariateZerocheck<
90	'a,
91	FDomain,
92	FBase,
93	P,
94	CompositionBase,
95	Composition,
96	M,
97	DomainFactory,
98	SwitchoverFn,
99	Backend,
100> where
101	FDomain: Field,
102	FBase: Field,
103	P: PackedField,
104	Backend: ComputationBackend,
105{
106	n_vars: usize,
107	#[getset(get = "pub")]
108	multilinears: Vec<M>,
109	compositions: Vec<(String, CompositionBase, Composition)>,
110	zerocheck_challenges: Vec<P::Scalar>,
111	domain_factory: DomainFactory,
112	switchover_fn: SwitchoverFn,
113	backend: &'a Backend,
114	univariate_evals_output: Option<ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>>,
115	_p_base_marker: PhantomData<FBase>,
116	_fdomain_marker: PhantomData<FDomain>,
117}
118
119impl<
120		'a,
121		F,
122		FDomain,
123		FBase,
124		P,
125		CompositionBase,
126		Composition,
127		M,
128		DomainFactory,
129		SwitchoverFn,
130		Backend,
131	>
132	UnivariateZerocheck<
133		'a,
134		FDomain,
135		FBase,
136		P,
137		CompositionBase,
138		Composition,
139		M,
140		DomainFactory,
141		SwitchoverFn,
142		Backend,
143	>
144where
145	F: TowerField,
146	FDomain: Field,
147	FBase: ExtensionField<FDomain>,
148	P: PackedFieldIndexable<Scalar = F>
149		+ PackedExtension<F, PackedSubfield = P>
150		+ PackedExtension<FBase>
151		+ PackedExtension<FDomain>,
152	CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
153	Composition: CompositionPoly<P> + 'a,
154	M: MultilinearPoly<P> + Send + Sync + 'a,
155	DomainFactory: EvaluationDomainFactory<FDomain>,
156	SwitchoverFn: Fn(usize) -> usize,
157	Backend: ComputationBackend,
158{
159	pub fn new(
160		multilinears: Vec<M>,
161		zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
162		zerocheck_challenges: &[F],
163		domain_factory: DomainFactory,
164		switchover_fn: SwitchoverFn,
165		backend: &'a Backend,
166	) -> Result<Self, Error> {
167		let n_vars = equal_n_vars_check(&multilinears)?;
168
169		let compositions = zero_claims.into_iter().collect::<Vec<_>>();
170		for (_, composition_base, composition) in &compositions {
171			if composition_base.n_vars() != multilinears.len()
172				|| composition.n_vars() != multilinears.len()
173				|| composition_base.degree() != composition.degree()
174			{
175				bail!(Error::InvalidComposition {
176					actual: composition.n_vars(),
177					expected: multilinears.len(),
178				});
179			}
180		}
181		#[cfg(feature = "debug_validate_sumcheck")]
182		{
183			let compositions = compositions
184				.iter()
185				.map(|(name, _, a)| (name.clone(), a))
186				.collect::<Vec<_>>();
187			validate_witness(&multilinears, &compositions)?;
188		}
189
190		let zerocheck_challenges = zerocheck_challenges.to_vec();
191
192		Ok(Self {
193			n_vars,
194			multilinears,
195			compositions,
196			zerocheck_challenges,
197			domain_factory,
198			switchover_fn,
199			backend,
200			univariate_evals_output: None,
201			_p_base_marker: PhantomData,
202			_fdomain_marker: PhantomData,
203		})
204	}
205
206	#[instrument(skip_all, level = "debug")]
207	#[allow(clippy::type_complexity)]
208	pub fn into_regular_zerocheck(self) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
209		if self.univariate_evals_output.is_some() {
210			bail!(Error::ExpectedFold);
211		}
212
213		#[cfg(feature = "debug_validate_sumcheck")]
214		{
215			let compositions = self
216				.compositions
217				.iter()
218				.map(|(name, _, a)| (name.clone(), a))
219				.collect::<Vec<_>>();
220			validate_witness(&self.multilinears, &compositions)?;
221		}
222
223		let composite_claims = self
224			.compositions
225			.into_iter()
226			.map(|(_, _, composition)| CompositeSumClaim {
227				composition,
228				sum: F::ZERO,
229			})
230			.collect::<Vec<_>>();
231
232		let first_round_eval_1s = composite_claims.iter().map(|_| F::ZERO).collect::<Vec<_>>();
233
234		let prover = EqIndSumcheckProverBuilder::new(self.backend)
235			.with_first_round_eval_1s(&first_round_eval_1s)
236			.build(
237				EvaluationOrder::LowToHigh,
238				self.multilinears,
239				&self.zerocheck_challenges,
240				composite_claims,
241				self.domain_factory,
242				self.switchover_fn,
243			)?;
244
245		Ok(Box::new(prover) as Box<dyn SumcheckProver<F> + 'a>)
246	}
247}
248
249impl<
250		'a,
251		F,
252		FDomain,
253		FBase,
254		P,
255		CompositionBase,
256		Composition,
257		M,
258		InterpolationDomainFactory,
259		SwitchoverFn,
260		Backend,
261	> UnivariateZerocheckProver<'a, F>
262	for UnivariateZerocheck<
263		'a,
264		FDomain,
265		FBase,
266		P,
267		CompositionBase,
268		Composition,
269		M,
270		InterpolationDomainFactory,
271		SwitchoverFn,
272		Backend,
273	>
274where
275	F: TowerField,
276	FDomain: TowerField,
277	FBase: ExtensionField<FDomain>,
278	P: PackedFieldIndexable<Scalar = F>
279		+ PackedExtension<F, PackedSubfield = P>
280		+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
281		+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
282	CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
283	Composition: CompositionPoly<P> + 'static,
284	M: MultilinearPoly<P> + Send + Sync + 'a,
285	InterpolationDomainFactory: EvaluationDomainFactory<FDomain>,
286	SwitchoverFn: Fn(usize) -> usize,
287	Backend: ComputationBackend,
288{
289	fn n_vars(&self) -> usize {
290		self.n_vars
291	}
292
293	fn domain_size(&self, skip_rounds: usize) -> usize {
294		self.compositions
295			.iter()
296			.map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
297			.max()
298			.unwrap_or(0)
299	}
300
301	#[instrument(skip_all, level = "debug")]
302	fn execute_univariate_round(
303		&mut self,
304		skip_rounds: usize,
305		max_domain_size: usize,
306		batch_coeff: F,
307	) -> Result<LagrangeRoundEvals<F>, Error> {
308		if self.univariate_evals_output.is_some() {
309			bail!(Error::ExpectedFold);
310		}
311
312		// Only use base compositions in the univariate round (it's the whole point)
313		let compositions_base = self
314			.compositions
315			.iter()
316			.map(|(_, composition_base, _)| composition_base)
317			.collect::<Vec<_>>();
318
319		// Output contains values that are needed for computations that happen after
320		// the round challenge has been sampled
321		let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
322			&self.multilinears,
323			&compositions_base,
324			&self.zerocheck_challenges,
325			skip_rounds,
326			max_domain_size,
327			self.backend,
328		)?;
329
330		// Batch together Lagrange round evals using powers of batch_coeff
331		let zeros_prefix_len = 1 << skip_rounds;
332		let batched_round_evals = univariate_evals_output
333			.round_evals
334			.iter()
335			.zip(powers(batch_coeff))
336			.map(|(evals, scalar)| {
337				let round_evals = LagrangeRoundEvals {
338					zeros_prefix_len,
339					evals: evals.clone(),
340				};
341				round_evals * scalar
342			})
343			.try_fold(
344				LagrangeRoundEvals::zeros(max_domain_size),
345				|mut accum, evals| -> Result<_, Error> {
346					accum.add_assign_lagrange(&evals)?;
347					Ok(accum)
348				},
349			)?;
350
351		self.univariate_evals_output = Some(univariate_evals_output);
352
353		Ok(batched_round_evals)
354	}
355
356	#[instrument(skip_all, level = "debug")]
357	fn fold_univariate_round(
358		self: Box<Self>,
359		challenge: F,
360	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
361		if self.univariate_evals_output.is_none() {
362			bail!(Error::ExpectedExecution);
363		}
364
365		// Once the challenge is known, values required for the instantiation of the
366		// multilinear prover for the remaining rounds become known.
367		let ZerocheckUnivariateFoldResult {
368			skip_rounds,
369			subcube_lagrange_coeffs,
370			claimed_sums,
371			partial_eq_ind_evals,
372		} = self
373			.univariate_evals_output
374			.expect("validated to be Some")
375			.fold::<FDomain>(challenge)?;
376
377		// For each subcube of size 2**skip_rounds, we need to compute its
378		// inner product with Lagrange coefficients at challenge point in order
379		// to obtain the witness for the remaining multilinear rounds.
380		// REVIEW: Currently MultilinearPoly lacks a method to do that, so we
381		//         hack the needed functionality by overwriting the inner content
382		//         of a MultilinearQuery and performing an evaluate_partial_low,
383		//         which accidentally does what's needed. There should obviously
384		//         be a dedicated method for this someday.
385		let mut packed_subcube_lagrange_coeffs =
386			zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
387		P::unpack_scalars_mut(&mut packed_subcube_lagrange_coeffs)[..1 << skip_rounds]
388			.copy_from_slice(&subcube_lagrange_coeffs);
389		let lagrange_coeffs_query =
390			MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
391
392		let partial_low_multilinears = self
393			.multilinears
394			.into_par_iter()
395			.map(|multilinear| -> Result<_, Error> {
396				let multilinear =
397					multilinear.evaluate_partial_low(lagrange_coeffs_query.to_ref())?;
398				let mle_adapter = Arc::new(MLEDirectAdapter::from(multilinear));
399				Ok(mle_adapter as MultilinearWitness<'static, P>)
400			})
401			.collect::<Result<Vec<_>, _>>()?;
402
403		let composite_claims = izip!(self.compositions, claimed_sums)
404			.map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
405			.collect::<Vec<_>>();
406
407		// The remaining non-univariate zerocheck rounds are an instance of EqIndSumcheck,
408		// due to the number of zerocheck challenges being equal to the number of remaining rounds.
409		let regular_prover = EqIndSumcheckProverBuilder::new(self.backend)
410			.with_eq_ind_partial_evals(partial_eq_ind_evals)
411			.build(
412				EvaluationOrder::LowToHigh,
413				partial_low_multilinears,
414				&self.zerocheck_challenges,
415				composite_claims,
416				self.domain_factory,
417				|extension_degree| {
418					(self.switchover_fn)(extension_degree).saturating_sub(skip_rounds)
419				},
420			)?;
421
422		Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
423	}
424}