binius_core/protocols/sumcheck/
zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use binius_field::{util::eq, Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly};
7use binius_utils::{bail, sorting::is_sorted_ascending};
8use getset::CopyGetters;
9
10use super::error::{Error, VerificationError};
11use crate::protocols::sumcheck::{BatchSumcheckOutput, CompositeSumClaim, SumcheckClaim};
12
13#[derive(Debug, CopyGetters)]
14pub struct ZerocheckClaim<F: Field, Composition> {
15	#[getset(get_copy = "pub")]
16	n_vars: usize,
17	#[getset(get_copy = "pub")]
18	n_multilinears: usize,
19	composite_zeros: Vec<Composition>,
20	_marker: PhantomData<F>,
21}
22
23impl<F: Field, Composition> ZerocheckClaim<F, Composition>
24where
25	Composition: CompositionPoly<F>,
26{
27	pub fn new(
28		n_vars: usize,
29		n_multilinears: usize,
30		composite_zeros: Vec<Composition>,
31	) -> Result<Self, Error> {
32		for composition in &composite_zeros {
33			if composition.n_vars() != n_multilinears {
34				bail!(Error::InvalidComposition {
35					actual: composition.n_vars(),
36					expected: n_multilinears,
37				});
38			}
39		}
40		Ok(Self {
41			n_vars,
42			n_multilinears,
43			composite_zeros,
44			_marker: PhantomData,
45		})
46	}
47
48	/// Returns the maximum individual degree of all composite polynomials.
49	pub fn max_individual_degree(&self) -> usize {
50		self.composite_zeros
51			.iter()
52			.map(|composite_zero| composite_zero.degree())
53			.max()
54			.unwrap_or(0)
55	}
56
57	pub fn composite_zeros(&self) -> &[Composition] {
58		&self.composite_zeros
59	}
60}
61
62/// Requirement: zerocheck challenges have been sampled before this is called
63pub fn reduce_to_sumchecks<F: Field, Composition: CompositionPoly<F>>(
64	claims: &[ZerocheckClaim<F, Composition>],
65) -> Result<Vec<SumcheckClaim<F, ExtraProduct<&Composition>>>, Error> {
66	// Check that the claims are in descending order by n_vars
67	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
68		bail!(Error::ClaimsOutOfOrder);
69	}
70
71	claims
72		.iter()
73		.map(|zerocheck_claim| {
74			let ZerocheckClaim {
75				n_vars,
76				n_multilinears,
77				composite_zeros,
78				..
79			} = zerocheck_claim;
80			SumcheckClaim::new(
81				*n_vars,
82				*n_multilinears + 1,
83				composite_zeros
84					.iter()
85					.map(|composition| CompositeSumClaim {
86						composition: ExtraProduct { inner: composition },
87						sum: F::ZERO,
88					})
89					.collect(),
90			)
91		})
92		.collect()
93}
94
95/// Verify the validity of the sumcheck outputs for a reduced zerocheck.
96///
97/// This takes in the output of the reduced sumcheck protocol and returns the output for the
98/// zerocheck instance. This simply strips off the multilinear evaluation of the eq indicator
99/// polynomial and verifies that the value is correct.
100///
101/// Note that due to univariatization of some rounds the number of challenges may be less than
102/// the maximum number of variables among claims.
103pub fn verify_sumcheck_outputs<F: Field, Composition: CompositionPoly<F>>(
104	claims: &[ZerocheckClaim<F, Composition>],
105	zerocheck_challenges: &[F],
106	sumcheck_output: BatchSumcheckOutput<F>,
107) -> Result<BatchSumcheckOutput<F>, Error> {
108	let BatchSumcheckOutput {
109		challenges: sumcheck_challenges,
110		mut multilinear_evals,
111	} = sumcheck_output;
112
113	assert_eq!(multilinear_evals.len(), claims.len());
114
115	// Check that the claims are in descending order by n_vars
116	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
117		bail!(Error::ClaimsOutOfOrder);
118	}
119
120	let max_n_vars = claims
121		.first()
122		.map(|claim| claim.n_vars())
123		.unwrap_or_default();
124
125	assert!(sumcheck_challenges.len() <= max_n_vars);
126	assert_eq!(zerocheck_challenges.len(), sumcheck_challenges.len());
127
128	let mut eq_ind_eval = F::ONE;
129	let mut last_n_vars = 0;
130	for (claim, multilinear_evals) in claims.iter().zip(multilinear_evals.iter_mut()).rev() {
131		assert_eq!(claim.n_multilinears() + 1, multilinear_evals.len());
132
133		while last_n_vars < claim.n_vars() && last_n_vars < sumcheck_challenges.len() {
134			let sumcheck_challenge =
135				sumcheck_challenges[sumcheck_challenges.len() - 1 - last_n_vars];
136			let zerocheck_challenge =
137				zerocheck_challenges[zerocheck_challenges.len() - 1 - last_n_vars];
138			eq_ind_eval *= eq(sumcheck_challenge, zerocheck_challenge);
139			last_n_vars += 1;
140		}
141
142		let multilinear_evals_last = multilinear_evals
143			.pop()
144			.expect("checked above that multilinear_evals length is at least 1");
145		if eq_ind_eval != multilinear_evals_last {
146			return Err(VerificationError::IncorrectEqIndEvaluation.into());
147		}
148	}
149
150	Ok(BatchSumcheckOutput {
151		challenges: sumcheck_challenges,
152		multilinear_evals,
153	})
154}
155
156#[derive(Debug)]
157pub struct ExtraProduct<Composition> {
158	pub inner: Composition,
159}
160
161impl<P, Composition> CompositionPoly<P> for ExtraProduct<Composition>
162where
163	P: PackedField,
164	Composition: CompositionPoly<P>,
165{
166	fn n_vars(&self) -> usize {
167		self.inner.n_vars() + 1
168	}
169
170	fn degree(&self) -> usize {
171		self.inner.degree() + 1
172	}
173
174	fn expression(&self) -> ArithExpr<P::Scalar> {
175		self.inner.expression() * ArithExpr::Var(self.inner.n_vars())
176	}
177
178	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
179		let n_vars = self.n_vars();
180		if query.len() != n_vars {
181			bail!(binius_math::Error::IncorrectQuerySize { expected: n_vars });
182		}
183
184		let inner_eval = self.inner.evaluate(&query[..n_vars - 1])?;
185		Ok(inner_eval * query[n_vars - 1])
186	}
187
188	fn binary_tower_level(&self) -> usize {
189		self.inner.binary_tower_level()
190	}
191}
192
193#[cfg(test)]
194mod tests {
195	use std::{iter, sync::Arc};
196
197	use binius_field::{
198		BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension,
199		PackedFieldIndexable, PackedSubfield, RepackedExtension,
200	};
201	use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt};
202	use binius_math::{
203		EvaluationDomainFactory, EvaluationOrder, IsomorphicEvaluationDomainFactory,
204		MultilinearPoly,
205	};
206	use groestl_crypto::Groestl256;
207	use rand::{prelude::StdRng, SeedableRng};
208
209	use super::*;
210	use crate::{
211		fiat_shamir::{CanSample, HasherChallenger},
212		protocols::{
213			sumcheck::{
214				batch_verify,
215				prove::{batch_prove, zerocheck, RegularSumcheckProver, UnivariateZerocheck},
216			},
217			test_utils::{generate_zero_product_multilinears, TestProductComposition},
218		},
219		transcript::ProverTranscript,
220		transparent::eq_ind::EqIndPartialEval,
221		witness::MultilinearWitness,
222	};
223
224	fn make_regular_sumcheck_prover_for_zerocheck<'a, F, FDomain, P, Composition, M, Backend>(
225		multilinears: Vec<M>,
226		zero_claims: impl IntoIterator<Item = Composition>,
227		challenges: &[F],
228		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
229		switchover_fn: impl Fn(usize) -> usize,
230		backend: &'a Backend,
231	) -> RegularSumcheckProver<
232		'a,
233		FDomain,
234		P,
235		ExtraProduct<Composition>,
236		MultilinearWitness<'static, P>,
237		Backend,
238	>
239	where
240		F: Field,
241		FDomain: Field,
242		P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<P>,
243		Composition: CompositionPoly<P>,
244		M: MultilinearPoly<P> + Send + Sync + 'static,
245		Backend: ComputationBackend,
246	{
247		let eq_ind = EqIndPartialEval::new(challenges)
248			.multilinear_extension::<P, _>(backend)
249			.unwrap();
250
251		let multilinears = multilinears
252			.into_iter()
253			.map(|multilin| Arc::new(multilin) as Arc<dyn MultilinearPoly<_> + Send + Sync>)
254			.chain([eq_ind.specialize_arc_dyn()])
255			.collect();
256
257		let composite_sum_claims = zero_claims
258			.into_iter()
259			.map(|composition| CompositeSumClaim {
260				composition: ExtraProduct { inner: composition },
261				sum: F::ZERO,
262			});
263		RegularSumcheckProver::new(
264			EvaluationOrder::LowToHigh,
265			multilinears,
266			composite_sum_claims,
267			evaluation_domain_factory,
268			switchover_fn,
269			backend,
270		)
271		.unwrap()
272	}
273
274	fn test_compare_prover_with_reference(
275		n_vars: usize,
276		n_multilinears: usize,
277		switchover_rd: usize,
278	) {
279		type P = PackedBinaryField1x128b;
280		type FBase = BinaryField32b;
281		type FDomain = BinaryField8b;
282		let mut rng = StdRng::seed_from_u64(0);
283
284		// Setup ZC Witness
285		let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
286			&mut rng,
287			n_vars,
288			n_multilinears,
289		);
290
291		let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
292		zerocheck::validate_witness(&multilins, &binding).unwrap();
293
294		let mut prove_transcript_1 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
295		let backend = make_portable_backend();
296		let challenges = prove_transcript_1.sample_vec(n_vars);
297
298		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
299		let reference_prover = make_regular_sumcheck_prover_for_zerocheck::<_, FDomain, _, _, _, _>(
300			multilins.clone(),
301			binding.into_iter().map(|(_, composition)| composition),
302			&challenges,
303			domain_factory.clone(),
304			|_| switchover_rd,
305			&backend,
306		);
307
308		let BatchSumcheckOutput {
309			challenges: sumcheck_challenges_1,
310			multilinear_evals: multilinear_evals_1,
311		} = batch_prove(vec![reference_prover], &mut prove_transcript_1).unwrap();
312
313		let composition = TestProductComposition::new(n_multilinears);
314		let optimized_prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _>::new(
315			multilins,
316			[("test_product".into(), composition.clone(), composition)],
317			&challenges,
318			domain_factory,
319			|_| switchover_rd,
320			&backend,
321		)
322		.unwrap()
323		.into_regular_zerocheck()
324		.unwrap();
325
326		let mut prove_transcript_2 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
327		let _: Vec<BinaryField128b> = prove_transcript_2.sample_vec(n_vars);
328		let BatchSumcheckOutput {
329			challenges: sumcheck_challenges_2,
330			multilinear_evals: multilinear_evals_2,
331		} = batch_prove(vec![optimized_prover], &mut prove_transcript_2).unwrap();
332
333		assert_eq!(prove_transcript_1.finalize(), prove_transcript_2.finalize());
334		assert_eq!(multilinear_evals_1, multilinear_evals_2);
335		assert_eq!(sumcheck_challenges_1, sumcheck_challenges_2);
336	}
337
338	fn test_prove_verify_product_constraint_helper(
339		n_vars: usize,
340		n_multilinears: usize,
341		switchover_rd: usize,
342	) {
343		type P = PackedBinaryField1x128b;
344		type FBase = BinaryField32b;
345		type FE = BinaryField128b;
346		type FDomain = BinaryField8b;
347		let mut rng = StdRng::seed_from_u64(0);
348
349		let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
350			&mut rng,
351			n_vars,
352			n_multilinears,
353		);
354
355		let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
356		zerocheck::validate_witness(&multilins, &binding).unwrap();
357
358		let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
359		let challenges = prove_transcript.sample_vec(n_vars);
360
361		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
362		let backend = make_portable_backend();
363
364		let composition = TestProductComposition::new(n_multilinears);
365		let prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _>::new(
366			multilins.clone(),
367			[("test_product".into(), composition.clone(), composition)],
368			&challenges,
369			domain_factory,
370			|_| switchover_rd,
371			&backend,
372		)
373		.unwrap()
374		.into_regular_zerocheck()
375		.unwrap();
376
377		let prove_output = batch_prove(vec![prover], &mut prove_transcript).unwrap();
378
379		let claim = ZerocheckClaim::new(
380			n_vars,
381			n_multilinears,
382			vec![TestProductComposition::new(n_multilinears)],
383		)
384		.unwrap();
385		let zerocheck_claims = [claim];
386		let BatchSumcheckOutput {
387			challenges: prover_eval_point,
388			multilinear_evals: prover_multilinear_evals,
389		} = verify_sumcheck_outputs(
390			&zerocheck_claims,
391			&challenges,
392			prove_output,
393			// prover_sumcheck_multilinear_evals,
394			// &prover_sumcheck_challenges,
395		)
396		.unwrap();
397
398		let prover_sample = CanSample::<FE>::sample(&mut prove_transcript);
399		let mut verify_transcript = prove_transcript.into_verifier();
400		let _: Vec<BinaryField128b> = verify_transcript.sample_vec(n_vars);
401
402		let sumcheck_claims = reduce_to_sumchecks(&zerocheck_claims).unwrap();
403		let verifier_output =
404			batch_verify(EvaluationOrder::LowToHigh, &sumcheck_claims, &mut verify_transcript)
405				.unwrap();
406
407		let BatchSumcheckOutput {
408			challenges: verifier_eval_point,
409			multilinear_evals: verifier_multilinear_evals,
410		} = verify_sumcheck_outputs(&zerocheck_claims, &challenges, verifier_output).unwrap();
411
412		// Check that challengers are in the same state
413		assert_eq!(prover_sample, CanSample::<FE>::sample(&mut verify_transcript));
414		verify_transcript.finalize().unwrap();
415
416		assert_eq!(prover_eval_point, verifier_eval_point);
417		assert_eq!(prover_multilinear_evals, verifier_multilinear_evals);
418
419		assert_eq!(verifier_multilinear_evals.len(), 1);
420		assert_eq!(verifier_multilinear_evals[0].len(), n_multilinears);
421
422		// Verify the reduced multilinear evaluations are correct
423		let multilin_query = backend.multilinear_query(&verifier_eval_point).unwrap();
424		for (multilinear, &expected) in iter::zip(multilins, verifier_multilinear_evals[0].iter()) {
425			assert_eq!(multilinear.evaluate(multilin_query.to_ref()).unwrap(), expected);
426		}
427	}
428
429	#[test]
430	fn test_compare_zerocheck_prover_to_regular_sumcheck() {
431		for n_vars in 2..8 {
432			for n_multilinears in 1..5 {
433				for switchover_rd in 1..=n_vars / 2 {
434					test_compare_prover_with_reference(n_vars, n_multilinears, switchover_rd);
435				}
436			}
437		}
438	}
439
440	#[test]
441	fn test_prove_verify_product_basic() {
442		for n_vars in 2..8 {
443			for n_multilinears in 1..5 {
444				for switchover_rd in 1..=n_vars / 2 {
445					test_prove_verify_product_constraint_helper(
446						n_vars,
447						n_multilinears,
448						switchover_rd,
449					);
450				}
451			}
452		}
453	}
454}