1use std::marker::PhantomData;
4
5use binius_field::Field;
6use binius_math::CompositionPoly;
7use binius_utils::{bail, sorting::is_sorted_ascending};
8use getset::CopyGetters;
9
10use super::error::Error;
11use crate::protocols::sumcheck::{eq_ind::EqIndSumcheckClaim, CompositeSumClaim};
12
/// A zerocheck claim: a list of compositions over `n_multilinears` multilinear polynomials in
/// `n_vars` variables, each of which is claimed to evaluate to zero.
///
/// NOTE(review): "evaluate to zero" presumably means on every point of the `n_vars`-dimensional
/// boolean hypercube (the usual zerocheck statement) — confirm against the protocol docs.
#[derive(Debug, CopyGetters)]
pub struct ZerocheckClaim<F: Field, Composition> {
    /// Number of variables of the multilinears the compositions are applied to.
    #[getset(get_copy = "pub")]
    n_vars: usize,
    /// Number of multilinears; every composition's `n_vars()` must equal this (checked in `new`).
    #[getset(get_copy = "pub")]
    n_multilinears: usize,
    /// The compositions that are claimed to vanish.
    composite_zeros: Vec<Composition>,
    /// Ties the field type `F` to the claim without storing a value of it.
    _marker: PhantomData<F>,
}
22
23impl<F: Field, Composition> ZerocheckClaim<F, Composition>
24where
25 Composition: CompositionPoly<F>,
26{
27 pub fn new(
28 n_vars: usize,
29 n_multilinears: usize,
30 composite_zeros: Vec<Composition>,
31 ) -> Result<Self, Error> {
32 for composition in &composite_zeros {
33 if composition.n_vars() != n_multilinears {
34 bail!(Error::InvalidComposition {
35 actual: composition.n_vars(),
36 expected: n_multilinears,
37 });
38 }
39 }
40 Ok(Self {
41 n_vars,
42 n_multilinears,
43 composite_zeros,
44 _marker: PhantomData,
45 })
46 }
47
48 pub fn max_individual_degree(&self) -> usize {
50 self.composite_zeros
51 .iter()
52 .map(|composite_zero| composite_zero.degree())
53 .max()
54 .unwrap_or(0)
55 }
56
57 pub fn composite_zeros(&self) -> &[Composition] {
58 &self.composite_zeros
59 }
60}
61
62pub fn reduce_to_eq_ind_sumchecks<F: Field, Composition: CompositionPoly<F>>(
63 claims: &[ZerocheckClaim<F, Composition>],
64) -> Result<Vec<EqIndSumcheckClaim<F, &Composition>>, Error> {
65 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
67 bail!(Error::ClaimsOutOfOrder);
68 }
69
70 claims
71 .iter()
72 .map(|zerocheck_claim| {
73 let &ZerocheckClaim {
74 n_vars,
75 n_multilinears,
76 ref composite_zeros,
77 ..
78 } = zerocheck_claim;
79 EqIndSumcheckClaim::new(
80 n_vars,
81 n_multilinears,
82 composite_zeros
83 .iter()
84 .map(|composition| CompositeSumClaim {
85 composition,
86 sum: F::ZERO,
87 })
88 .collect(),
89 )
90 })
91 .collect()
92}
93
#[cfg(test)]
mod tests {
    use std::{iter, sync::Arc};

    use binius_field::{
        BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension,
        PackedFieldIndexable, PackedSubfield, RepackedExtension,
    };
    use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt};
    use binius_hash::groestl::Groestl256;
    use binius_math::{
        EvaluationDomainFactory, EvaluationOrder, IsomorphicEvaluationDomainFactory,
        MultilinearPoly,
    };
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::{
        fiat_shamir::{CanSample, HasherChallenger},
        protocols::{
            sumcheck::{
                batch_verify,
                eq_ind::{reduce_to_regular_sumchecks, verify_sumcheck_outputs, ExtraProduct},
                prove::{batch_prove, zerocheck, RegularSumcheckProver, UnivariateZerocheck},
                zerocheck::reduce_to_eq_ind_sumchecks,
                BatchSumcheckOutput,
            },
            test_utils::{generate_zero_product_multilinears, TestProductComposition},
        },
        transcript::ProverTranscript,
        transparent::eq_ind::EqIndPartialEval,
        witness::MultilinearWitness,
    };

    /// Builds a reference [`RegularSumcheckProver`] for a zerocheck instance.
    ///
    /// The multilinear extension of `EqIndPartialEval::new(challenges)` is appended after
    /// `multilinears`. Every composition in `zero_claims` is wrapped in [`ExtraProduct`] with a
    /// claimed sum of `F::ZERO`.
    ///
    /// NOTE(review): `ExtraProduct` presumably multiplies the inner composition by the trailing
    /// eq-indicator multilinear. Confirm against its definition.
    ///
    /// # Panics
    ///
    /// Panics if building the eq-indicator extension or the prover fails. This is acceptable
    /// because the function is a test helper.
    fn make_regular_sumcheck_prover_for_zerocheck<'a, F, FDomain, P, Composition, M, Backend>(
        multilinears: Vec<M>,
        zero_claims: impl IntoIterator<Item = Composition>,
        challenges: &[F],
        evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
        switchover_fn: impl Fn(usize) -> usize,
        backend: &'a Backend,
    ) -> RegularSumcheckProver<
        'a,
        FDomain,
        P,
        ExtraProduct<Composition>,
        MultilinearWitness<'static, P>,
        Backend,
    >
    where
        F: Field,
        FDomain: Field,
        P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<P>,
        Composition: CompositionPoly<P>,
        M: MultilinearPoly<P> + Send + Sync + 'static,
        Backend: ComputationBackend,
    {
        let eq_ind = EqIndPartialEval::new(challenges)
            .multilinear_extension::<P, _>(backend)
            .unwrap();

        // The eq-indicator goes last, after all witness multilinears.
        let multilinears = multilinears
            .into_iter()
            .map(|multilin| Arc::new(multilin) as Arc<dyn MultilinearPoly<_> + Send + Sync>)
            .chain([eq_ind.specialize_arc_dyn()])
            .collect();

        let composite_sum_claims = zero_claims
            .into_iter()
            .map(|composition| CompositeSumClaim {
                composition: ExtraProduct { inner: composition },
                sum: F::ZERO,
            });
        RegularSumcheckProver::new(
            EvaluationOrder::LowToHigh,
            multilinears,
            composite_sum_claims,
            evaluation_domain_factory,
            switchover_fn,
            backend,
        )
        .unwrap()
    }

    /// Checks `UnivariateZerocheck`, converted with `into_regular_zerocheck`, against the
    /// reference prover from [`make_regular_sumcheck_prover_for_zerocheck`].
    ///
    /// Both provers must produce identical transcripts, sumcheck challenges and multilinear
    /// evaluations.
    fn test_compare_prover_with_reference(
        n_vars: usize,
        n_multilinears: usize,
        switchover_rd: usize,
    ) {
        type P = PackedBinaryField1x128b;
        type FBase = BinaryField32b;
        type FDomain = BinaryField8b;
        let mut rng = StdRng::seed_from_u64(0);

        let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
            &mut rng,
            n_vars,
            n_multilinears,
        );

        let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
        zerocheck::validate_witness(&multilins, &binding).unwrap();

        let mut prove_transcript_1 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let backend = make_portable_backend();
        let challenges = prove_transcript_1.sample_vec(n_vars);

        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let reference_prover = make_regular_sumcheck_prover_for_zerocheck::<_, FDomain, _, _, _, _>(
            multilins.clone(),
            binding.into_iter().map(|(_, composition)| composition),
            &challenges,
            domain_factory.clone(),
            |_| switchover_rd,
            &backend,
        );

        let BatchSumcheckOutput {
            challenges: sumcheck_challenges_1,
            multilinear_evals: multilinear_evals_1,
        } = batch_prove(vec![reference_prover], &mut prove_transcript_1).unwrap();

        let composition = TestProductComposition::new(n_multilinears);
        let optimized_prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _, _, _>::new(
            multilins,
            [("test_product".into(), composition.clone(), composition)],
            &challenges,
            domain_factory,
            |_| switchover_rd,
            &backend,
        )
        .unwrap()
        .into_regular_zerocheck()
        .unwrap();

        // Sample (and discard) the same number of challenges as transcript 1 so that both
        // transcripts are in the same state before proving.
        let mut prove_transcript_2 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let _: Vec<BinaryField128b> = prove_transcript_2.sample_vec(n_vars);
        let BatchSumcheckOutput {
            challenges: sumcheck_challenges_2,
            multilinear_evals: multilinear_evals_2,
        } = batch_prove(vec![optimized_prover], &mut prove_transcript_2).unwrap();

        assert_eq!(prove_transcript_1.finalize(), prove_transcript_2.finalize());
        assert_eq!(multilinear_evals_1, multilinear_evals_2);
        assert_eq!(sumcheck_challenges_1, sumcheck_challenges_2);
    }

    /// End-to-end prove/verify round trip for a single product-composition zerocheck.
    ///
    /// Proves with `UnivariateZerocheck` and verifies via `reduce_to_eq_ind_sumchecks`,
    /// `reduce_to_regular_sumchecks` and `batch_verify`. It then checks that prover and
    /// verifier agree on the evaluation point and claimed evaluations, and that those
    /// evaluations match a direct evaluation of each multilinear.
    fn test_prove_verify_product_constraint_helper(
        n_vars: usize,
        n_multilinears: usize,
        switchover_rd: usize,
    ) {
        type P = PackedBinaryField1x128b;
        type FBase = BinaryField32b;
        type FE = BinaryField128b;
        type FDomain = BinaryField8b;
        let mut rng = StdRng::seed_from_u64(0);

        let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
            &mut rng,
            n_vars,
            n_multilinears,
        );

        let binding = [("test_product".into(), TestProductComposition::new(n_multilinears))];
        zerocheck::validate_witness(&multilins, &binding).unwrap();

        let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let challenges = prove_transcript.sample_vec(n_vars);

        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let backend = make_portable_backend();

        let composition = TestProductComposition::new(n_multilinears);
        let prover = UnivariateZerocheck::<FDomain, FBase, P, _, _, _, _, _, _>::new(
            multilins.clone(),
            [("test_product".into(), composition.clone(), composition)],
            &challenges,
            domain_factory,
            |_| switchover_rd,
            &backend,
        )
        .unwrap()
        .into_regular_zerocheck()
        .unwrap();

        let prove_output = batch_prove(vec![prover], &mut prove_transcript).unwrap();

        let claim = ZerocheckClaim::new(
            n_vars,
            n_multilinears,
            vec![TestProductComposition::new(n_multilinears)],
        )
        .unwrap();
        let zerocheck_claims = [claim];
        let eq_ind_sumcheck_claims = reduce_to_eq_ind_sumchecks(&zerocheck_claims).unwrap();

        let BatchSumcheckOutput {
            challenges: prover_eval_point,
            multilinear_evals: prover_multilinear_evals,
        } = verify_sumcheck_outputs(
            &eq_ind_sumcheck_claims,
            &challenges,
            prove_output,
        )
        .unwrap();

        // Draw one extra sample on the prover side so it can be compared with the verifier's
        // transcript after verification.
        let prover_sample = CanSample::<FE>::sample(&mut prove_transcript);
        let mut verify_transcript = prove_transcript.into_verifier();
        let _: Vec<BinaryField128b> = verify_transcript.sample_vec(n_vars);

        let regular_sumcheck_claims = reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims).unwrap();

        let verifier_output = batch_verify(
            EvaluationOrder::LowToHigh,
            &regular_sumcheck_claims,
            &mut verify_transcript,
        )
        .unwrap();

        let BatchSumcheckOutput {
            challenges: verifier_eval_point,
            multilinear_evals: verifier_multilinear_evals,
        } = verify_sumcheck_outputs(&eq_ind_sumcheck_claims, &challenges, verifier_output).unwrap();

        assert_eq!(prover_sample, CanSample::<FE>::sample(&mut verify_transcript));
        verify_transcript.finalize().unwrap();

        assert_eq!(prover_eval_point, verifier_eval_point);
        assert_eq!(prover_multilinear_evals, verifier_multilinear_evals);

        assert_eq!(verifier_multilinear_evals.len(), 1);
        assert_eq!(verifier_multilinear_evals[0].len(), n_multilinears);

        // The verifier's claimed evaluations must match direct evaluation of the witness.
        let multilin_query = backend.multilinear_query(&verifier_eval_point).unwrap();
        for (multilinear, &expected) in iter::zip(multilins, verifier_multilinear_evals[0].iter()) {
            assert_eq!(multilinear.evaluate(multilin_query.to_ref()).unwrap(), expected);
        }
    }

    #[test]
    fn test_compare_zerocheck_prover_to_regular_sumcheck() {
        for n_vars in 2..8 {
            for n_multilinears in 1..5 {
                for switchover_rd in 1..=n_vars / 2 {
                    test_compare_prover_with_reference(n_vars, n_multilinears, switchover_rd);
                }
            }
        }
    }

    #[test]
    fn test_prove_verify_product_basic() {
        for n_vars in 2..8 {
            for n_multilinears in 1..5 {
                for switchover_rd in 1..=n_vars / 2 {
                    test_prove_verify_product_constraint_helper(
                        n_vars,
                        n_multilinears,
                        switchover_rd,
                    );
                }
            }
        }
    }
}