// binius_core/protocols/sumcheck/zerocheck.rs

// Copyright 2024-2025 Irreducible Inc.

3use std::marker::PhantomData;
4
5use binius_field::Field;
6use binius_math::CompositionPoly;
7use binius_utils::{bail, sorting::is_sorted_ascending};
8use getset::CopyGetters;
9
10use super::error::Error;
11use crate::protocols::sumcheck::{eq_ind::EqIndSumcheckClaim, CompositeSumClaim};
12
/// A zerocheck claim: a set of composite polynomials over the same `n_multilinears`
/// multilinears, each of which is asserted to be zero.
///
/// [`reduce_to_eq_ind_sumchecks`] turns every composition into an eq-indicator sumcheck whose
/// claimed sum is zero. NOTE(review): the vanishing is presumably over the `n_vars`-dimensional
/// boolean hypercube — confirm against the prover/verifier.
#[derive(Debug, CopyGetters)]
pub struct ZerocheckClaim<F: Field, Composition> {
	/// Number of variables the claim is over.
	#[getset(get_copy = "pub")]
	n_vars: usize,
	/// Number of multilinears; every composition must take exactly this many inputs
	/// (enforced by `ZerocheckClaim::new`).
	#[getset(get_copy = "pub")]
	n_multilinears: usize,
	/// The composite polynomials claimed to be zero.
	composite_zeros: Vec<Composition>,
	// `F` is otherwise only used in bounds, so a marker field is needed to carry it.
	_marker: PhantomData<F>,
}
22
23impl<F: Field, Composition> ZerocheckClaim<F, Composition>
24where
25	Composition: CompositionPoly<F>,
26{
27	pub fn new(
28		n_vars: usize,
29		n_multilinears: usize,
30		composite_zeros: Vec<Composition>,
31	) -> Result<Self, Error> {
32		for composition in &composite_zeros {
33			if composition.n_vars() != n_multilinears {
34				bail!(Error::InvalidComposition {
35					actual: composition.n_vars(),
36					expected: n_multilinears,
37				});
38			}
39		}
40		Ok(Self {
41			n_vars,
42			n_multilinears,
43			composite_zeros,
44			_marker: PhantomData,
45		})
46	}
47
48	/// Returns the maximum individual degree of all composite polynomials.
49	pub fn max_individual_degree(&self) -> usize {
50		self.composite_zeros
51			.iter()
52			.map(|composite_zero| composite_zero.degree())
53			.max()
54			.unwrap_or(0)
55	}
56
57	pub fn composite_zeros(&self) -> &[Composition] {
58		&self.composite_zeros
59	}
60}
61
62pub fn reduce_to_eq_ind_sumchecks<F: Field, Composition: CompositionPoly<F>>(
63	claims: &[ZerocheckClaim<F, Composition>],
64) -> Result<Vec<EqIndSumcheckClaim<F, &Composition>>, Error> {
65	// Check that the claims are in descending order by n_vars
66	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
67		bail!(Error::ClaimsOutOfOrder);
68	}
69
70	claims
71		.iter()
72		.map(|zerocheck_claim| {
73			let &ZerocheckClaim {
74				n_vars,
75				n_multilinears,
76				ref composite_zeros,
77				..
78			} = zerocheck_claim;
79			EqIndSumcheckClaim::new(
80				n_vars,
81				n_multilinears,
82				composite_zeros
83					.iter()
84					.map(|composition| CompositeSumClaim {
85						composition,
86						sum: F::ZERO,
87					})
88					.collect(),
89			)
90		})
91		.collect()
92}
93
// Tests comparing the zerocheck prover against a reference regular-sumcheck prover, and
// round-tripping zerocheck proofs through the eq-ind reductions and the batch verifier.
#[cfg(test)]
mod tests {
	use std::{iter, sync::Arc};

	use binius_field::{
		BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension,
		PackedFieldIndexable, PackedSubfield, RepackedExtension,
	};
	use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt};
	use binius_hash::groestl::Groestl256;
	use binius_math::{
		EvaluationDomainFactory, EvaluationOrder, IsomorphicEvaluationDomainFactory,
		MultilinearPoly,
	};
	use rand::{prelude::StdRng, SeedableRng};

	use super::*;
	use crate::{
		fiat_shamir::{CanSample, HasherChallenger},
		protocols::{
			sumcheck::{
				batch_verify,
				eq_ind::{reduce_to_regular_sumchecks, verify_sumcheck_outputs, ExtraProduct},
				prove::{batch_prove, zerocheck, RegularSumcheckProver, UnivariateZerocheck},
				zerocheck::reduce_to_eq_ind_sumchecks,
				BatchSumcheckOutput,
			},
			test_utils::{generate_zero_product_multilinears, TestProductComposition},
		},
		transcript::ProverTranscript,
		transparent::eq_ind::EqIndPartialEval,
		witness::MultilinearWitness,
	};

	/// Builds a plain `RegularSumcheckProver` that acts as a reference zerocheck prover.
	///
	/// The multilinear extension of `EqIndPartialEval::new(challenges)` is appended after
	/// `multilinears`. Each composition in `zero_claims` is wrapped in `ExtraProduct`, and every
	/// resulting composite claim asserts a sum of zero. NOTE(review): `ExtraProduct` presumably
	/// multiplies the inner composition by that trailing eq-indicator column — confirm in
	/// `eq_ind`.
	///
	/// Panics (via `unwrap`) if the eq-indicator extension or the prover cannot be built.
	fn make_regular_sumcheck_prover_for_zerocheck<'a, F, FDomain, P, Composition, M, Backend>(
		multilinears: Vec<M>,
		zero_claims: impl IntoIterator<Item = Composition>,
		challenges: &[F],
		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
		switchover_fn: impl Fn(usize) -> usize,
		backend: &'a Backend,
	) -> RegularSumcheckProver<
		'a,
		FDomain,
		P,
		ExtraProduct<Composition>,
		MultilinearWitness<'static, P>,
		Backend,
	>
	where
		F: Field,
		FDomain: Field,
		P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<P>,
		Composition: CompositionPoly<P>,
		M: MultilinearPoly<P> + Send + Sync + 'static,
		Backend: ComputationBackend,
	{
		let eq_ind = EqIndPartialEval::new(challenges)
			.multilinear_extension::<P, _>(backend)
			.unwrap();

		// Type-erase the witness multilinears and append the eq-indicator as the last column.
		let multilinears = multilinears
			.into_iter()
			.map(|multilin| Arc::new(multilin) as Arc<dyn MultilinearPoly<_> + Send + Sync>)
			.chain([eq_ind.specialize_arc_dyn()])
			.collect();

		// Every composite claim asserts a sum of zero, as a zerocheck requires.
		let composite_sum_claims = zero_claims
			.into_iter()
			.map(|composition| CompositeSumClaim {
				composition: ExtraProduct { inner: composition },
				sum: F::ZERO,
			});
		RegularSumcheckProver::new(
			EvaluationOrder::LowToHigh,
			multilinears,
			composite_sum_claims,
			evaluation_domain_factory,
			switchover_fn,
			backend,
		)
		.unwrap()
	}

	/// Checks that `UnivariateZerocheck`, converted with `into_regular_zerocheck`, matches the
	/// reference prover from `make_regular_sumcheck_prover_for_zerocheck`. Both must produce
	/// identical transcripts, multilinear evaluations and sumcheck challenges.
	fn test_compare_prover_with_reference(
		n_vars: usize,
		n_multilinears: usize,
		switchover_rd: usize,
	) {
		type P = PackedBinaryField1x128b;
		type FBase = BinaryField32b;
		type FDomain = BinaryField8b;
		let mut rng = StdRng::seed_from_u64(0);

		// Set up the zerocheck witness. `validate_witness` is expected to confirm that the
		// product composition is zero on it.
		let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
			&mut rng,
			n_vars,
			n_multilinears,
		);

		let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
		zerocheck::validate_witness(&multilins, &binding).unwrap();

		let mut prove_transcript_1 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let backend = make_portable_backend();
		let challenges = prove_transcript_1.sample_vec(n_vars);

		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let reference_prover = make_regular_sumcheck_prover_for_zerocheck::<_, FDomain, _, _, _, _>(
			multilins.clone(),
			binding.into_iter().map(|(_, composition)| composition),
			&challenges,
			domain_factory.clone(),
			|_| switchover_rd,
			&backend,
		);

		let BatchSumcheckOutput {
			challenges: sumcheck_challenges_1,
			multilinear_evals: multilinear_evals_1,
		} = batch_prove(vec![reference_prover], &mut prove_transcript_1).unwrap();

		let composition = TestProductComposition::new(n_multilinears);
		let optimized_prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _, _, _>::new(
			multilins,
			[("test_product".into(), composition.clone(), composition)],
			&challenges,
			domain_factory,
			|_| switchover_rd,
			&backend,
		)
		.unwrap()
		.into_regular_zerocheck()
		.unwrap();

		// Sample and discard the same `n_vars` challenges that transcript 1 sampled, so both
		// transcripts are in the same state before proving. The `challenges` taken from
		// transcript 1 are reused above, which presumes fresh transcripts sample identically.
		let mut prove_transcript_2 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let _: Vec<BinaryField128b> = prove_transcript_2.sample_vec(n_vars);
		let BatchSumcheckOutput {
			challenges: sumcheck_challenges_2,
			multilinear_evals: multilinear_evals_2,
		} = batch_prove(vec![optimized_prover], &mut prove_transcript_2).unwrap();

		assert_eq!(prove_transcript_1.finalize(), prove_transcript_2.finalize());
		assert_eq!(multilinear_evals_1, multilinear_evals_2);
		assert_eq!(sumcheck_challenges_1, sumcheck_challenges_2);
	}

	/// Runs a full prove/verify round trip for a single product zerocheck:
	/// - proves with `UnivariateZerocheck` (via `into_regular_zerocheck`);
	/// - verifies through `reduce_to_eq_ind_sumchecks`, `reduce_to_regular_sumchecks` and
	///   `batch_verify`;
	/// - checks that prover and verifier agree on the evaluation point, the multilinear
	///   evaluations and the final transcript state;
	/// - checks the reduced evaluations against direct evaluation of the witness multilinears.
	fn test_prove_verify_product_constraint_helper(
		n_vars: usize,
		n_multilinears: usize,
		switchover_rd: usize,
	) {
		type P = PackedBinaryField1x128b;
		type FBase = BinaryField32b;
		type FE = BinaryField128b;
		type FDomain = BinaryField8b;
		let mut rng = StdRng::seed_from_u64(0);

		let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
			&mut rng,
			n_vars,
			n_multilinears,
		);

		let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
		zerocheck::validate_witness(&multilins, &binding).unwrap();

		// The zerocheck challenges come first in the transcript; the verifier mirrors this below.
		let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let challenges = prove_transcript.sample_vec(n_vars);

		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let backend = make_portable_backend();

		let composition = TestProductComposition::new(n_multilinears);
		let prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _, _, _>::new(
			multilins.clone(),
			[("test_product".into(), composition.clone(), composition)],
			&challenges,
			domain_factory,
			|_| switchover_rd,
			&backend,
		)
		.unwrap()
		.into_regular_zerocheck()
		.unwrap();

		let prove_output = batch_prove(vec![prover], &mut prove_transcript).unwrap();

		let claim = ZerocheckClaim::new(
			n_vars,
			n_multilinears,
			vec![TestProductComposition::new(n_multilinears)],
		)
		.unwrap();
		let zerocheck_claims = [claim];
		let eq_ind_sumcheck_claims = reduce_to_eq_ind_sumchecks(&zerocheck_claims).unwrap();

		let BatchSumcheckOutput {
			challenges: prover_eval_point,
			multilinear_evals: prover_multilinear_evals,
		} = verify_sumcheck_outputs(
			&eq_ind_sumcheck_claims,
			&challenges,
			prove_output,
		)
		.unwrap();

		// Draw one extra value after proving; the verifier must draw the same value once it has
		// consumed the whole proof.
		let prover_sample = CanSample::<FE>::sample(&mut prove_transcript);
		let mut verify_transcript = prove_transcript.into_verifier();
		// Mirror the prover: sample and discard the initial `n_vars` zerocheck challenges.
		let _: Vec<BinaryField128b> = verify_transcript.sample_vec(n_vars);

		let regular_sumcheck_claims = reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims).unwrap();

		let verifier_output = batch_verify(
			EvaluationOrder::LowToHigh,
			&regular_sumcheck_claims,
			&mut verify_transcript,
		)
		.unwrap();

		let BatchSumcheckOutput {
			challenges: verifier_eval_point,
			multilinear_evals: verifier_multilinear_evals,
		} = verify_sumcheck_outputs(&eq_ind_sumcheck_claims, &challenges, verifier_output).unwrap();

		// Check that the prover and verifier challengers are in the same state.
		assert_eq!(prover_sample, CanSample::<FE>::sample(&mut verify_transcript));
		verify_transcript.finalize().unwrap();

		assert_eq!(prover_eval_point, verifier_eval_point);
		assert_eq!(prover_multilinear_evals, verifier_multilinear_evals);

		// One batched claim, with one evaluation per witness multilinear.
		assert_eq!(verifier_multilinear_evals.len(), 1);
		assert_eq!(verifier_multilinear_evals[0].len(), n_multilinears);

		// Verify the reduced multilinear evaluations by evaluating the witness directly.
		let multilin_query = backend.multilinear_query(&verifier_eval_point).unwrap();
		for (multilinear, &expected) in iter::zip(multilins, verifier_multilinear_evals[0].iter()) {
			assert_eq!(multilinear.evaluate(multilin_query.to_ref()).unwrap(), expected);
		}
	}

	/// Sweeps small sizes and switchover rounds, comparing the optimized zerocheck prover with
	/// the reference regular-sumcheck prover.
	#[test]
	fn test_compare_zerocheck_prover_to_regular_sumcheck() {
		for n_vars in 2..8 {
			for n_multilinears in 1..5 {
				for switchover_rd in 1..=n_vars / 2 {
					test_compare_prover_with_reference(n_vars, n_multilinears, switchover_rd);
				}
			}
		}
	}

	/// Sweeps the same parameter grid, running full prove/verify round trips.
	#[test]
	fn test_prove_verify_product_basic() {
		for n_vars in 2..8 {
			for n_multilinears in 1..5 {
				for switchover_rd in 1..=n_vars / 2 {
					test_prove_verify_product_constraint_helper(
						n_vars,
						n_multilinears,
						switchover_rd,
					);
				}
			}
		}
	}
}