1use std::{
4 marker::PhantomData,
5 ops::{Mul, MulAssign},
6};
7
8use binius_field::{ExtensionField, Field, PackedField, TowerField, packed::set_packed_slice};
9use binius_hal::{ComputationBackendExt, make_portable_backend};
10use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearExtension};
11use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
12use bytemuck::zeroed_vec;
13use getset::CopyGetters;
14use itertools::izip;
15
16use super::error::Error;
17use crate::{
18 composition::{BivariateProduct, IndexComposition},
19 polynomial::Error as PolynomialError,
20 protocols::sumcheck::{
21 BatchSumcheckOutput, CompositeSumClaim, SumcheckClaim, VerificationError,
22 eq_ind::EqIndSumcheckClaim,
23 },
24};
25
26#[derive(Debug, CopyGetters)]
27pub struct ZerocheckClaim<F: Field, Composition> {
28 #[getset(get_copy = "pub")]
29 n_vars: usize,
30 #[getset(get_copy = "pub")]
31 n_multilinears: usize,
32 composite_zeros: Vec<Composition>,
33 _marker: PhantomData<F>,
34}
35
36impl<F: Field, Composition> ZerocheckClaim<F, Composition>
37where
38 Composition: CompositionPoly<F>,
39{
40 pub fn new(
41 n_vars: usize,
42 n_multilinears: usize,
43 composite_zeros: Vec<Composition>,
44 ) -> Result<Self, Error> {
45 for composition in &composite_zeros {
46 if composition.n_vars() != n_multilinears {
47 bail!(Error::InvalidComposition {
48 actual: composition.n_vars(),
49 expected: n_multilinears,
50 });
51 }
52 }
53 Ok(Self {
54 n_vars,
55 n_multilinears,
56 composite_zeros,
57 _marker: PhantomData,
58 })
59 }
60
61 pub fn max_individual_degree(&self) -> usize {
63 self.composite_zeros
64 .iter()
65 .map(|composite_zero| composite_zero.degree())
66 .max()
67 .unwrap_or(0)
68 }
69
70 pub fn composite_zeros(&self) -> &[Composition] {
71 &self.composite_zeros
72 }
73}
74
75#[derive(Clone, Debug)]
81pub struct ZerocheckRoundEvals<F: Field> {
82 pub evals: Vec<F>,
83}
84
85impl<F: Field> ZerocheckRoundEvals<F> {
86 pub fn zeros(len: usize) -> Self {
88 Self {
89 evals: vec![F::ZERO; len],
90 }
91 }
92
93 pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
96 if self.evals.len() != rhs.evals.len() {
97 bail!(Error::LagrangeRoundEvalsSizeMismatch);
98 }
99
100 for (lhs, rhs) in izip!(&mut self.evals, &rhs.evals) {
101 *lhs += rhs;
102 }
103
104 Ok(())
105 }
106}
107
108impl<F: Field> Mul<F> for ZerocheckRoundEvals<F> {
109 type Output = Self;
110
111 fn mul(mut self, rhs: F) -> Self::Output {
112 self *= rhs;
113 self
114 }
115}
116
117impl<F: Field> MulAssign<F> for ZerocheckRoundEvals<F> {
118 fn mul_assign(&mut self, rhs: F) {
119 for eval in &mut self.evals {
120 *eval *= rhs;
121 }
122 }
123}
124
/// Size of the evaluation domain for a univariate-skip round polynomial.
///
/// Equals `composition_degree * 2^skip_rounds`, computed as a left shift.
pub const fn domain_size(composition_degree: usize, skip_rounds: usize) -> usize {
	composition_degree << skip_rounds
}

/// Number of domain points whose values have to be extrapolated.
///
/// Equals `(composition_degree - 1) * 2^skip_rounds`, or `0` when `composition_degree` is `0`.
/// For a positive degree this is `domain_size(composition_degree, skip_rounds) - 2^skip_rounds`.
/// NOTE(review): the `2^skip_rounds` excluded points are presumably the skipped subspace, where an
/// honest zerocheck prover's values are already known to be zero. Confirm.
pub const fn extrapolated_scalars_count(composition_degree: usize, skip_rounds: usize) -> usize {
	// Shrink the degree by one (saturating at zero), then scale exactly like `domain_size`.
	domain_size(composition_degree.saturating_sub(1), skip_rounds)
}
139
140pub struct BatchZerocheckOutput<F: Field> {
142 pub skipped_challenges: Vec<F>,
145 pub unskipped_challenges: Vec<F>,
148 pub concat_multilinear_evals: Vec<F>,
150}
151
152pub fn reduce_to_eq_ind_sumchecks<F: Field, Composition: CompositionPoly<F>>(
158 skip_rounds: usize,
159 claims: &[ZerocheckClaim<F, Composition>],
160) -> Result<Vec<EqIndSumcheckClaim<F, &Composition>>, Error> {
161 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
163 bail!(Error::ClaimsOutOfOrder);
164 }
165
166 claims
167 .iter()
168 .map(|zerocheck_claim| {
169 let &ZerocheckClaim {
170 n_vars,
171 n_multilinears,
172 ref composite_zeros,
173 ..
174 } = zerocheck_claim;
175 EqIndSumcheckClaim::new(
176 n_vars.saturating_sub(skip_rounds),
177 n_multilinears,
178 composite_zeros
179 .iter()
180 .map(|composition| CompositeSumClaim {
181 composition,
182 sum: F::ZERO,
183 })
184 .collect(),
185 )
186 })
187 .collect()
188}
189
190pub fn univariatizing_reduction_claim<F: Field>(
200 skip_rounds: usize,
201 univariatized_multilinear_evals: &[impl AsRef<[F]>],
202) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
203 let n_multilinears = univariatized_multilinear_evals
204 .iter()
205 .map(|claim_evals| claim_evals.as_ref().len())
206 .sum();
207
208 let composite_sums = univariatized_multilinear_evals
211 .iter()
212 .flat_map(|claim_evals| claim_evals.as_ref())
213 .enumerate()
214 .map(|(i, &univariatized_multilinear_eval)| {
215 let composition =
216 IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
217 .expect("index composition indice correct by construction");
218
219 CompositeSumClaim {
220 composition,
221 sum: univariatized_multilinear_eval,
222 }
223 })
224 .collect();
225
226 SumcheckClaim::new(skip_rounds, n_multilinears + 1, composite_sums)
227}
228
229pub fn verify_reduction_sumcheck_output<F>(
236 claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
237 skip_rounds: usize,
238 univariate_challenge: F,
239 reduction_sumcheck_output: BatchSumcheckOutput<F>,
240) -> Result<BatchSumcheckOutput<F>, Error>
241where
242 F: TowerField,
243{
244 let BatchSumcheckOutput {
245 challenges: reduction_sumcheck_challenges,
246 mut multilinear_evals,
247 } = reduction_sumcheck_output;
248
249 if claim.n_vars() != skip_rounds {
251 bail!(Error::IncorrectUnivariatizingReductionClaims);
252 }
253
254 if reduction_sumcheck_challenges.len() != skip_rounds || multilinear_evals.len() != 1 {
256 bail!(Error::IncorrectUnivariatizingReductionSumcheck);
257 }
258
259 let subspace = BinarySubspace::<F::Canonical>::with_dim(skip_rounds)?.isomorphic::<F>();
261 let evaluation_domain =
262 EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
263
264 let lagrange_mle =
265 lagrange_evals_multilinear_extension::<F, F, F>(&evaluation_domain, univariate_challenge)?;
266
267 let query = make_portable_backend().multilinear_query::<F>(&reduction_sumcheck_challenges)?;
268 let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
269
270 let first_claim_multilinear_evals = multilinear_evals
271 .first_mut()
272 .expect("exactly one claim in reduction sumcheck");
273
274 let multilinear_evals_last_eval = first_claim_multilinear_evals
276 .pop()
277 .ok_or(VerificationError::NumberOfFinalEvaluations)?;
278
279 if multilinear_evals_last_eval != expected_last_eval {
280 bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
281 }
282
283 let output = BatchSumcheckOutput {
284 challenges: reduction_sumcheck_challenges,
285 multilinear_evals,
286 };
287
288 Ok(output)
289}
290
291pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
294 evaluation_domain: &EvaluationDomain<FDomain>,
295 univariate_challenge: F,
296) -> Result<MultilinearExtension<P>, PolynomialError>
297where
298 FDomain: Field,
299 F: Field + ExtensionField<FDomain>,
300 P: PackedField<Scalar = F>,
301{
302 let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
303
304 let n_vars = log2_strict_usize(lagrange_evals.len());
305 let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
306
307 for (i, &lagrange_eval) in lagrange_evals.iter().enumerate() {
308 set_packed_slice(&mut packed, i, lagrange_eval);
309 }
310
311 Ok(MultilinearExtension::new(n_vars, packed)?)
312}
313
/// End-to-end tests: run the batched zerocheck prover and verifier on random instances and check
/// that both sides derive the same challenges and multilinear evaluations.
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		AESTowerField8b, AESTowerField16b, AESTowerField128b, BinaryField8b, BinaryField16b,
		BinaryField128b, ByteSlicedAES64x128b,
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::{UnderlierType, WithUnderlier},
	};
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::Groestl256;
	use binius_math::IsomorphicEvaluationDomainFactory;
	use rand::{SeedableRng, prelude::StdRng};

	use super::*;
	use crate::{
		composition::ProductComposition,
		fiat_shamir::{CanSample, HasherChallenger},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::{self, prove::ZerocheckProverImpl},
			test_utils::generate_zero_product_multilinears,
		},
		transcript::ProverTranscript,
	};

	/// Proves and verifies a batch of zerocheck claims, one per `n_vars` in `1..=max_n_vars`.
	/// This is repeated for every `skip_rounds` in `0..=max_n_vars`, asserting that prover and
	/// verifier outputs agree.
	///
	/// Type parameters: `U` is the packing underlier, `F` the extension field, `FDomain` the
	/// evaluation-domain field, `FBase` the field of the base-field compositions, and `FWitness`
	/// the field the random multilinears are generated over.
	fn test_zerocheck_end_to_end_helper<U, F, FDomain, FBase, FWitness>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<FWitness>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		FWitness: Field,
	{
		let max_n_vars = 6;
		// 2 + 3 + 4 multilinears, one group per product composition below.
		let n_multilinears = 9;

		let backend = make_portable_backend();
		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		// Fixed seed so the test is deterministic.
		let mut rng = StdRng::seed_from_u64(0);

		// The literal `9` is the arity and must match `n_multilinears`. Indices 0-1 feed the pair
		// product, 2-4 the triple product and 5-8 the quad product.
		let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
		let triple =
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
		let quad =
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());

		// Prover-side compositions: a name, the composition over the packed base field, and the
		// same composition over the packed extension field.
		let prover_compositions = [
			(
				"pair".into(),
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
			(
				"triple".into(),
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
			(
				"quad".into(),
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
		];

		// Scalar-field versions of the same compositions, used to build the `ZerocheckClaim`s
		// that the verifier checks.
		let prover_adapter_compositions = [
			CompositionScalarAdapter::new(pair as Arc<dyn CompositionPoly<F>>),
			CompositionScalarAdapter::new(triple as Arc<dyn CompositionPoly<F>>),
			CompositionScalarAdapter::new(quad as Arc<dyn CompositionPoly<F>>),
		];

		for skip_rounds in 0..=max_n_vars {
			let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();

			// Zerocheck challenges for the non-skipped variables of the largest claim, sampled
			// from the transcript before anything is proven.
			let prover_zerocheck_challenges: Vec<F> = proof.sample_vec(max_n_vars - skip_rounds);

			let mut zerocheck_claims = Vec::new();
			let mut zerocheck_provers = Vec::new();
			// Claims are created in ascending `n_vars` order, as `reduce_to_eq_ind_sumchecks`
			// requires.
			for n_vars in 1..=max_n_vars {
				// Presumably groups of 2, 3 and 4 multilinears whose product vanishes on the
				// hypercube, matching the pair/triple/quad compositions. Confirm in `test_utils`.
				let mut multilinears = generate_zero_product_multilinears::<
					PackedType<U, FWitness>,
					PackedType<U, F>,
				>(&mut rng, n_vars, 2);
				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

				let claim = ZerocheckClaim::<F, _>::new(
					n_vars,
					n_multilinears,
					prover_adapter_compositions.to_vec(),
				)
				.unwrap();

				// The challenge slice starts at `max_n_vars - max(n_vars, skip_rounds)`. It is
				// therefore the suffix of length `n_vars.saturating_sub(skip_rounds)` (empty when
				// `n_vars <= skip_rounds`).
				let prover =
					ZerocheckProverImpl::<FDomain, FBase, PackedType<U, F>, _, _, _, _, _>::new(
						multilinears,
						prover_compositions.to_vec(),
						&prover_zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..],
						domain_factory.clone(),
						&backend,
					)
					.unwrap();

				zerocheck_claims.push(claim);
				zerocheck_provers.push(prover);
			}

			let prover_zerocheck_output =
				sumcheck::prove::batch_prove_zerocheck::<F, FDomain, PackedType<U, F>, _, _>(
					zerocheck_provers,
					skip_rounds,
					&mut proof,
				)
				.unwrap();

			// Replay the same transcript on the verifier side.
			let mut verifier_proof = proof.into_verifier();

			let verifier_zerocheck_output = sumcheck::batch_verify_zerocheck(
				&zerocheck_claims,
				skip_rounds,
				&mut verifier_proof,
			)
			.unwrap();

			// NOTE(review): presumably fails if the verifier left transcript data unread.
			verifier_proof.finalize().unwrap();

			assert_eq!(
				prover_zerocheck_output.skipped_challenges,
				verifier_zerocheck_output.skipped_challenges
			);
			assert_eq!(
				prover_zerocheck_output.unskipped_challenges,
				verifier_zerocheck_output.unskipped_challenges
			);
			assert_eq!(
				prover_zerocheck_output.concat_multilinear_evals,
				verifier_zerocheck_output.concat_multilinear_evals,
			);
		}
	}

	/// Binary tower fields with a 128-bit underlier.
	#[test]
	fn test_zerocheck_end_to_end_basic() {
		test_zerocheck_end_to_end_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField16b,
			BinaryField16b,
			BinaryField8b,
		>()
	}

	/// Same fields with a 512-bit underlier, so each packed value presumably holds several
	/// 128-bit elements.
	#[test]
	fn test_zerocheck_end_to_end_with_nontrivial_packing() {
		test_zerocheck_end_to_end_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField16b,
			BinaryField16b,
			BinaryField8b,
		>()
	}

	/// AES tower fields using the underlier of the byte-sliced `ByteSlicedAES64x128b` packing.
	#[test]
	fn test_zerocheck_end_to_end_bytesliced() {
		test_zerocheck_end_to_end_helper::<
			<ByteSlicedAES64x128b as WithUnderlier>::Underlier,
			AESTowerField128b,
			AESTowerField16b,
			AESTowerField16b,
			AESTowerField8b,
		>()
	}
}