1use std::marker::PhantomData;
4
5use binius_field::{util::eq, Field, PackedField};
6use binius_math::{ArithExpr, CompositionPoly};
7use binius_utils::{bail, sorting::is_sorted_ascending};
8use getset::CopyGetters;
9
10use super::error::{Error, VerificationError};
11use crate::protocols::sumcheck::{BatchSumcheckOutput, CompositeSumClaim, SumcheckClaim};
12
13#[derive(Debug, CopyGetters)]
14pub struct ZerocheckClaim<F: Field, Composition> {
15 #[getset(get_copy = "pub")]
16 n_vars: usize,
17 #[getset(get_copy = "pub")]
18 n_multilinears: usize,
19 composite_zeros: Vec<Composition>,
20 _marker: PhantomData<F>,
21}
22
23impl<F: Field, Composition> ZerocheckClaim<F, Composition>
24where
25 Composition: CompositionPoly<F>,
26{
27 pub fn new(
28 n_vars: usize,
29 n_multilinears: usize,
30 composite_zeros: Vec<Composition>,
31 ) -> Result<Self, Error> {
32 for composition in &composite_zeros {
33 if composition.n_vars() != n_multilinears {
34 bail!(Error::InvalidComposition {
35 actual: composition.n_vars(),
36 expected: n_multilinears,
37 });
38 }
39 }
40 Ok(Self {
41 n_vars,
42 n_multilinears,
43 composite_zeros,
44 _marker: PhantomData,
45 })
46 }
47
48 pub fn max_individual_degree(&self) -> usize {
50 self.composite_zeros
51 .iter()
52 .map(|composite_zero| composite_zero.degree())
53 .max()
54 .unwrap_or(0)
55 }
56
57 pub fn composite_zeros(&self) -> &[Composition] {
58 &self.composite_zeros
59 }
60}
61
62pub fn reduce_to_sumchecks<F: Field, Composition: CompositionPoly<F>>(
64 claims: &[ZerocheckClaim<F, Composition>],
65) -> Result<Vec<SumcheckClaim<F, ExtraProduct<&Composition>>>, Error> {
66 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
68 bail!(Error::ClaimsOutOfOrder);
69 }
70
71 claims
72 .iter()
73 .map(|zerocheck_claim| {
74 let ZerocheckClaim {
75 n_vars,
76 n_multilinears,
77 composite_zeros,
78 ..
79 } = zerocheck_claim;
80 SumcheckClaim::new(
81 *n_vars,
82 *n_multilinears + 1,
83 composite_zeros
84 .iter()
85 .map(|composition| CompositeSumClaim {
86 composition: ExtraProduct { inner: composition },
87 sum: F::ZERO,
88 })
89 .collect(),
90 )
91 })
92 .collect()
93}
94
95pub fn verify_sumcheck_outputs<F: Field, Composition: CompositionPoly<F>>(
104 claims: &[ZerocheckClaim<F, Composition>],
105 zerocheck_challenges: &[F],
106 sumcheck_output: BatchSumcheckOutput<F>,
107) -> Result<BatchSumcheckOutput<F>, Error> {
108 let BatchSumcheckOutput {
109 challenges: sumcheck_challenges,
110 mut multilinear_evals,
111 } = sumcheck_output;
112
113 assert_eq!(multilinear_evals.len(), claims.len());
114
115 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
117 bail!(Error::ClaimsOutOfOrder);
118 }
119
120 let max_n_vars = claims
121 .first()
122 .map(|claim| claim.n_vars())
123 .unwrap_or_default();
124
125 assert!(sumcheck_challenges.len() <= max_n_vars);
126 assert_eq!(zerocheck_challenges.len(), sumcheck_challenges.len());
127
128 let mut eq_ind_eval = F::ONE;
129 let mut last_n_vars = 0;
130 for (claim, multilinear_evals) in claims.iter().zip(multilinear_evals.iter_mut()).rev() {
131 assert_eq!(claim.n_multilinears() + 1, multilinear_evals.len());
132
133 while last_n_vars < claim.n_vars() && last_n_vars < sumcheck_challenges.len() {
134 let sumcheck_challenge =
135 sumcheck_challenges[sumcheck_challenges.len() - 1 - last_n_vars];
136 let zerocheck_challenge =
137 zerocheck_challenges[zerocheck_challenges.len() - 1 - last_n_vars];
138 eq_ind_eval *= eq(sumcheck_challenge, zerocheck_challenge);
139 last_n_vars += 1;
140 }
141
142 let multilinear_evals_last = multilinear_evals
143 .pop()
144 .expect("checked above that multilinear_evals length is at least 1");
145 if eq_ind_eval != multilinear_evals_last {
146 return Err(VerificationError::IncorrectEqIndEvaluation.into());
147 }
148 }
149
150 Ok(BatchSumcheckOutput {
151 challenges: sumcheck_challenges,
152 multilinear_evals,
153 })
154}
155
/// Wraps a composition `C` with `n` inputs as `C(x_0, ..., x_{n-1}) * x_n`, i.e.
/// multiplies it by one extra, trailing variable.
#[derive(Debug)]
pub struct ExtraProduct<Composition> {
    /// The wrapped composition; its inputs come first in the query.
    pub inner: Composition,
}
160
161impl<P, Composition> CompositionPoly<P> for ExtraProduct<Composition>
162where
163 P: PackedField,
164 Composition: CompositionPoly<P>,
165{
166 fn n_vars(&self) -> usize {
167 self.inner.n_vars() + 1
168 }
169
170 fn degree(&self) -> usize {
171 self.inner.degree() + 1
172 }
173
174 fn expression(&self) -> ArithExpr<P::Scalar> {
175 self.inner.expression() * ArithExpr::Var(self.inner.n_vars())
176 }
177
178 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
179 let n_vars = self.n_vars();
180 if query.len() != n_vars {
181 bail!(binius_math::Error::IncorrectQuerySize { expected: n_vars });
182 }
183
184 let inner_eval = self.inner.evaluate(&query[..n_vars - 1])?;
185 Ok(inner_eval * query[n_vars - 1])
186 }
187
188 fn binary_tower_level(&self) -> usize {
189 self.inner.binary_tower_level()
190 }
191}
192
#[cfg(test)]
mod tests {
    use std::{iter, sync::Arc};

    use binius_field::{
        BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension,
        PackedFieldIndexable, PackedSubfield, RepackedExtension,
    };
    use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt};
    use binius_math::{
        EvaluationDomainFactory, EvaluationOrder, IsomorphicEvaluationDomainFactory,
        MultilinearPoly,
    };
    use groestl_crypto::Groestl256;
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::{
        fiat_shamir::{CanSample, HasherChallenger},
        protocols::{
            sumcheck::{
                batch_verify,
                prove::{batch_prove, zerocheck, RegularSumcheckProver, UnivariateZerocheck},
            },
            test_utils::{generate_zero_product_multilinears, TestProductComposition},
        },
        transcript::ProverTranscript,
        transparent::eq_ind::EqIndPartialEval,
        witness::MultilinearWitness,
    };

    /// Builds a reference prover for a zerocheck instance from the generic
    /// `RegularSumcheckProver`.
    ///
    /// The eq-indicator multilinear for `challenges` is appended after
    /// `multilinears`. Each composition in `zero_claims` is wrapped in
    /// `ExtraProduct`, so it is multiplied by that last multilinear, and its sum is
    /// claimed to be zero. Panics (via `unwrap`) if the eq-indicator extension or
    /// the prover cannot be constructed.
    fn make_regular_sumcheck_prover_for_zerocheck<'a, F, FDomain, P, Composition, M, Backend>(
        multilinears: Vec<M>,
        zero_claims: impl IntoIterator<Item = Composition>,
        challenges: &[F],
        evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
        switchover_fn: impl Fn(usize) -> usize,
        backend: &'a Backend,
    ) -> RegularSumcheckProver<
        'a,
        FDomain,
        P,
        ExtraProduct<Composition>,
        MultilinearWitness<'static, P>,
        Backend,
    >
    where
        F: Field,
        FDomain: Field,
        P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<P>,
        Composition: CompositionPoly<P>,
        M: MultilinearPoly<P> + Send + Sync + 'static,
        Backend: ComputationBackend,
    {
        let eq_ind = EqIndPartialEval::new(challenges)
            .multilinear_extension::<P, _>(backend)
            .unwrap();

        // The eq-indicator must be the last multilinear: `ExtraProduct` multiplies
        // by the variable that follows the inner composition's inputs.
        let multilinears = multilinears
            .into_iter()
            .map(|multilin| Arc::new(multilin) as Arc<dyn MultilinearPoly<_> + Send + Sync>)
            .chain([eq_ind.specialize_arc_dyn()])
            .collect();

        let composite_sum_claims = zero_claims
            .into_iter()
            .map(|composition| CompositeSumClaim {
                composition: ExtraProduct { inner: composition },
                sum: F::ZERO,
            });
        RegularSumcheckProver::new(
            EvaluationOrder::LowToHigh,
            multilinears,
            composite_sum_claims,
            evaluation_domain_factory,
            switchover_fn,
            backend,
        )
        .unwrap()
    }

    /// Checks that the specialized `UnivariateZerocheck` prover (after
    /// `into_regular_zerocheck`) matches the reference prover built above.
    ///
    /// Both provers run on the same random zero-product witness, and the test
    /// asserts equal final transcripts, equal multilinear evaluations and equal
    /// sumcheck challenges.
    fn test_compare_prover_with_reference(
        n_vars: usize,
        n_multilinears: usize,
        switchover_rd: usize,
    ) {
        type P = PackedBinaryField1x128b;
        type FBase = BinaryField32b;
        type FDomain = BinaryField8b;
        let mut rng = StdRng::seed_from_u64(0);

        let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
            &mut rng,
            n_vars,
            n_multilinears,
        );

        // Sanity-check that the witness actually satisfies the product constraint.
        let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
        zerocheck::validate_witness(&multilins, &binding).unwrap();

        let mut prove_transcript_1 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let backend = make_portable_backend();
        // Zerocheck challenges; the element type is inferred from later use.
        let challenges = prove_transcript_1.sample_vec(n_vars);

        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let reference_prover = make_regular_sumcheck_prover_for_zerocheck::<_, FDomain, _, _, _, _>(
            multilins.clone(),
            binding.into_iter().map(|(_, composition)| composition),
            &challenges,
            domain_factory.clone(),
            |_| switchover_rd,
            &backend,
        );

        let BatchSumcheckOutput {
            challenges: sumcheck_challenges_1,
            multilinear_evals: multilinear_evals_1,
        } = batch_prove(vec![reference_prover], &mut prove_transcript_1).unwrap();

        let composition = TestProductComposition::new(n_multilinears);
        let optimized_prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _>::new(
            multilins,
            [("test_product".into(), composition.clone(), composition)],
            &challenges,
            domain_factory,
            |_| switchover_rd,
            &backend,
        )
        .unwrap()
        .into_regular_zerocheck()
        .unwrap();

        // Replay the same challenge sampling so the second transcript starts in the
        // same state as the first one did before proving.
        let mut prove_transcript_2 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let _: Vec<BinaryField128b> = prove_transcript_2.sample_vec(n_vars);
        let BatchSumcheckOutput {
            challenges: sumcheck_challenges_2,
            multilinear_evals: multilinear_evals_2,
        } = batch_prove(vec![optimized_prover], &mut prove_transcript_2).unwrap();

        assert_eq!(prove_transcript_1.finalize(), prove_transcript_2.finalize());
        assert_eq!(multilinear_evals_1, multilinear_evals_2);
        assert_eq!(sumcheck_challenges_1, sumcheck_challenges_2);
    }

    /// End-to-end prove/verify round trip for a single product zerocheck claim.
    ///
    /// The test:
    /// 1. proves with the specialized prover;
    /// 2. verifies via `reduce_to_sumchecks`, `batch_verify` and
    ///    `verify_sumcheck_outputs`;
    /// 3. checks that prover and verifier agree and that their transcripts stay in
    ///    sync;
    /// 4. checks each reported evaluation against a direct evaluation of the witness
    ///    multilinear.
    fn test_prove_verify_product_constraint_helper(
        n_vars: usize,
        n_multilinears: usize,
        switchover_rd: usize,
    ) {
        type P = PackedBinaryField1x128b;
        type FBase = BinaryField32b;
        type FE = BinaryField128b;
        type FDomain = BinaryField8b;
        let mut rng = StdRng::seed_from_u64(0);

        let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
            &mut rng,
            n_vars,
            n_multilinears,
        );

        let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
        zerocheck::validate_witness(&multilins, &binding).unwrap();

        let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let challenges = prove_transcript.sample_vec(n_vars);

        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let backend = make_portable_backend();

        let composition = TestProductComposition::new(n_multilinears);
        let prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _>::new(
            multilins.clone(),
            [("test_product".into(), composition.clone(), composition)],
            &challenges,
            domain_factory,
            |_| switchover_rd,
            &backend,
        )
        .unwrap()
        .into_regular_zerocheck()
        .unwrap();

        let prove_output = batch_prove(vec![prover], &mut prove_transcript).unwrap();

        let claim = ZerocheckClaim::new(
            n_vars,
            n_multilinears,
            vec![TestProductComposition::new(n_multilinears)],
        )
        .unwrap();
        let zerocheck_claims = [claim];
        // Run the prover's own output through the same eq-indicator check and
        // stripping that the verifier applies, so the two outputs are comparable.
        let BatchSumcheckOutput {
            challenges: prover_eval_point,
            multilinear_evals: prover_multilinear_evals,
        } = verify_sumcheck_outputs(
            &zerocheck_claims,
            &challenges,
            prove_output,
        )
        .unwrap();

        // One extra sample taken after proving; the verifier must reproduce it.
        let prover_sample = CanSample::<FE>::sample(&mut prove_transcript);
        let mut verify_transcript = prove_transcript.into_verifier();
        // Replay the zerocheck challenge sampling done by the prover above.
        let _: Vec<BinaryField128b> = verify_transcript.sample_vec(n_vars);

        let sumcheck_claims = reduce_to_sumchecks(&zerocheck_claims).unwrap();
        let verifier_output =
            batch_verify(EvaluationOrder::LowToHigh, &sumcheck_claims, &mut verify_transcript)
                .unwrap();

        let BatchSumcheckOutput {
            challenges: verifier_eval_point,
            multilinear_evals: verifier_multilinear_evals,
        } = verify_sumcheck_outputs(&zerocheck_claims, &challenges, verifier_output).unwrap();

        assert_eq!(prover_sample, CanSample::<FE>::sample(&mut verify_transcript));
        verify_transcript.finalize().unwrap();

        assert_eq!(prover_eval_point, verifier_eval_point);
        assert_eq!(prover_multilinear_evals, verifier_multilinear_evals);

        // After the eq-indicator entry was popped, only the witness evaluations remain.
        assert_eq!(verifier_multilinear_evals.len(), 1);
        assert_eq!(verifier_multilinear_evals[0].len(), n_multilinears);

        // Cross-check each claimed evaluation against a direct evaluation.
        let multilin_query = backend.multilinear_query(&verifier_eval_point).unwrap();
        for (multilinear, &expected) in iter::zip(multilins, verifier_multilinear_evals[0].iter()) {
            assert_eq!(multilinear.evaluate(multilin_query.to_ref()).unwrap(), expected);
        }
    }

    // Sweeps problem sizes, multilinear counts and switchover rounds.
    #[test]
    fn test_compare_zerocheck_prover_to_regular_sumcheck() {
        for n_vars in 2..8 {
            for n_multilinears in 1..5 {
                for switchover_rd in 1..=n_vars / 2 {
                    test_compare_prover_with_reference(n_vars, n_multilinears, switchover_rd);
                }
            }
        }
    }

    // Same parameter sweep as above, for the full prove/verify round trip.
    #[test]
    fn test_prove_verify_product_basic() {
        for n_vars in 2..8 {
            for n_multilinears in 1..5 {
                for switchover_rd in 1..=n_vars / 2 {
                    test_prove_verify_product_constraint_helper(
                        n_vars,
                        n_multilinears,
                        switchover_rd,
                    );
                }
            }
        }
    }
}