1use std::{
4 marker::PhantomData,
5 ops::{Mul, MulAssign},
6};
7
8use binius_field::{packed::set_packed_slice, ExtensionField, Field, PackedField, TowerField};
9use binius_hal::{make_portable_backend, ComputationBackendExt};
10use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearExtension};
11use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
12use bytemuck::zeroed_vec;
13use getset::CopyGetters;
14use itertools::izip;
15
16use super::error::Error;
17use crate::{
18 composition::{BivariateProduct, IndexComposition},
19 polynomial::Error as PolynomialError,
20 protocols::sumcheck::{
21 eq_ind::EqIndSumcheckClaim, BatchSumcheckOutput, CompositeSumClaim, SumcheckClaim,
22 VerificationError,
23 },
24};
25
26#[derive(Debug, CopyGetters)]
27pub struct ZerocheckClaim<F: Field, Composition> {
28 #[getset(get_copy = "pub")]
29 n_vars: usize,
30 #[getset(get_copy = "pub")]
31 n_multilinears: usize,
32 composite_zeros: Vec<Composition>,
33 _marker: PhantomData<F>,
34}
35
36impl<F: Field, Composition> ZerocheckClaim<F, Composition>
37where
38 Composition: CompositionPoly<F>,
39{
40 pub fn new(
41 n_vars: usize,
42 n_multilinears: usize,
43 composite_zeros: Vec<Composition>,
44 ) -> Result<Self, Error> {
45 for composition in &composite_zeros {
46 if composition.n_vars() != n_multilinears {
47 bail!(Error::InvalidComposition {
48 actual: composition.n_vars(),
49 expected: n_multilinears,
50 });
51 }
52 }
53 Ok(Self {
54 n_vars,
55 n_multilinears,
56 composite_zeros,
57 _marker: PhantomData,
58 })
59 }
60
61 pub fn max_individual_degree(&self) -> usize {
63 self.composite_zeros
64 .iter()
65 .map(|composite_zero| composite_zero.degree())
66 .max()
67 .unwrap_or(0)
68 }
69
70 pub fn composite_zeros(&self) -> &[Composition] {
71 &self.composite_zeros
72 }
73}
74
75#[derive(Clone, Debug)]
81pub struct ZerocheckRoundEvals<F: Field> {
82 pub evals: Vec<F>,
83}
84
85impl<F: Field> ZerocheckRoundEvals<F> {
86 pub fn zeros(len: usize) -> Self {
88 Self {
89 evals: vec![F::ZERO; len],
90 }
91 }
92
93 pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
96 if self.evals.len() != rhs.evals.len() {
97 bail!(Error::LagrangeRoundEvalsSizeMismatch);
98 }
99
100 for (lhs, rhs) in izip!(&mut self.evals, &rhs.evals) {
101 *lhs += rhs;
102 }
103
104 Ok(())
105 }
106}
107
108impl<F: Field> Mul<F> for ZerocheckRoundEvals<F> {
109 type Output = Self;
110
111 fn mul(mut self, rhs: F) -> Self::Output {
112 self *= rhs;
113 self
114 }
115}
116
117impl<F: Field> MulAssign<F> for ZerocheckRoundEvals<F> {
118 fn mul_assign(&mut self, rhs: F) {
119 for eval in &mut self.evals {
120 *eval *= rhs;
121 }
122 }
123}
124
/// Returns `composition_degree << skip_rounds`, i.e. `composition_degree * 2^skip_rounds`.
///
/// Judging by the name, this is the evaluation-domain size used when `skip_rounds` rounds are
/// skipped for a composition of degree `composition_degree` — confirm against callers.
pub const fn domain_size(composition_degree: usize, skip_rounds: usize) -> usize {
    let points_per_subcube = composition_degree;
    points_per_subcube << skip_rounds
}
134
/// Returns `(composition_degree - 1) << skip_rounds`, treating a zero degree as contributing
/// nothing (the subtraction saturates at `0`).
pub const fn extrapolated_scalars_count(composition_degree: usize, skip_rounds: usize) -> usize {
    let extra_per_subcube = match composition_degree {
        0 => 0,
        degree => degree - 1,
    };
    extra_per_subcube << skip_rounds
}
139
140pub struct BatchZerocheckOutput<F: Field> {
142 pub skipped_challenges: Vec<F>,
145 pub unskipped_challenges: Vec<F>,
148 pub concat_multilinear_evals: Vec<F>,
150}
151
152pub fn reduce_to_eq_ind_sumchecks<F: Field, Composition: CompositionPoly<F>>(
157 skip_rounds: usize,
158 claims: &[ZerocheckClaim<F, Composition>],
159) -> Result<Vec<EqIndSumcheckClaim<F, &Composition>>, Error> {
160 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
162 bail!(Error::ClaimsOutOfOrder);
163 }
164
165 claims
166 .iter()
167 .map(|zerocheck_claim| {
168 let &ZerocheckClaim {
169 n_vars,
170 n_multilinears,
171 ref composite_zeros,
172 ..
173 } = zerocheck_claim;
174 EqIndSumcheckClaim::new(
175 n_vars.saturating_sub(skip_rounds),
176 n_multilinears,
177 composite_zeros
178 .iter()
179 .map(|composition| CompositeSumClaim {
180 composition,
181 sum: F::ZERO,
182 })
183 .collect(),
184 )
185 })
186 .collect()
187}
188
189pub fn univariatizing_reduction_claim<F: Field>(
198 skip_rounds: usize,
199 univariatized_multilinear_evals: &[impl AsRef<[F]>],
200) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
201 let n_multilinears = univariatized_multilinear_evals
202 .iter()
203 .map(|claim_evals| claim_evals.as_ref().len())
204 .sum();
205
206 let composite_sums = univariatized_multilinear_evals
209 .iter()
210 .flat_map(|claim_evals| claim_evals.as_ref())
211 .enumerate()
212 .map(|(i, &univariatized_multilinear_eval)| {
213 let composition =
214 IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
215 .expect("index composition indice correct by construction");
216
217 CompositeSumClaim {
218 composition,
219 sum: univariatized_multilinear_eval,
220 }
221 })
222 .collect();
223
224 SumcheckClaim::new(skip_rounds, n_multilinears + 1, composite_sums)
225}
226
227pub fn verify_reduction_sumcheck_output<F>(
233 claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
234 skip_rounds: usize,
235 univariate_challenge: F,
236 reduction_sumcheck_output: BatchSumcheckOutput<F>,
237) -> Result<BatchSumcheckOutput<F>, Error>
238where
239 F: TowerField,
240{
241 let BatchSumcheckOutput {
242 challenges: reduction_sumcheck_challenges,
243 mut multilinear_evals,
244 } = reduction_sumcheck_output;
245
246 if claim.n_vars() != skip_rounds {
248 bail!(Error::IncorrectUnivariatizingReductionClaims);
249 }
250
251 if reduction_sumcheck_challenges.len() != skip_rounds || multilinear_evals.len() != 1 {
253 bail!(Error::IncorrectUnivariatizingReductionSumcheck);
254 }
255
256 let subspace = BinarySubspace::<F::Canonical>::with_dim(skip_rounds)?.isomorphic::<F>();
258 let evaluation_domain =
259 EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
260
261 let lagrange_mle =
262 lagrange_evals_multilinear_extension::<F, F, F>(&evaluation_domain, univariate_challenge)?;
263
264 let query = make_portable_backend().multilinear_query::<F>(&reduction_sumcheck_challenges)?;
265 let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
266
267 let first_claim_multilinear_evals = multilinear_evals
268 .first_mut()
269 .expect("exactly one claim in reduction sumcheck");
270
271 let multilinear_evals_last_eval = first_claim_multilinear_evals
273 .pop()
274 .ok_or(VerificationError::NumberOfFinalEvaluations)?;
275
276 if multilinear_evals_last_eval != expected_last_eval {
277 bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
278 }
279
280 let output = BatchSumcheckOutput {
281 challenges: reduction_sumcheck_challenges,
282 multilinear_evals,
283 };
284
285 Ok(output)
286}
287
288pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
291 evaluation_domain: &EvaluationDomain<FDomain>,
292 univariate_challenge: F,
293) -> Result<MultilinearExtension<P>, PolynomialError>
294where
295 FDomain: Field,
296 F: Field + ExtensionField<FDomain>,
297 P: PackedField<Scalar = F>,
298{
299 let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
300
301 let n_vars = log2_strict_usize(lagrange_evals.len());
302 let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
303
304 for (i, &lagrange_eval) in lagrange_evals.iter().enumerate() {
305 set_packed_slice(&mut packed, i, lagrange_eval);
306 }
307
308 Ok(MultilinearExtension::new(n_vars, packed)?)
309}
310
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        underlier::{UnderlierType, WithUnderlier},
        AESTowerField128b, AESTowerField16b, AESTowerField8b, BinaryField128b, BinaryField16b,
        BinaryField8b, ByteSlicedAES64x128b,
    };
    use binius_hal::make_portable_backend;
    use binius_hash::groestl::Groestl256;
    use binius_math::IsomorphicEvaluationDomainFactory;
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::{
        composition::ProductComposition,
        fiat_shamir::{CanSample, HasherChallenger},
        polynomial::CompositionScalarAdapter,
        protocols::{
            sumcheck::{self, prove::ZerocheckProverImpl},
            test_utils::generate_zero_product_multilinears,
        },
        transcript::ProverTranscript,
    };

    /// End-to-end batched zerocheck test: proves and verifies a batch of claims for every
    /// `skip_rounds` in `0..=max_n_vars`, then checks that prover and verifier outputs agree.
    fn test_zerocheck_end_to_end_helper<U, F, FDomain, FBase, FWitness>()
    where
        U: UnderlierType
            + PackScalar<F>
            + PackScalar<FBase>
            + PackScalar<FDomain>
            + PackScalar<FWitness>,
        F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
        FBase: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField,
        FWitness: Field,
    {
        let max_n_vars = 6;
        let n_multilinears = 9;

        let backend = make_portable_backend();
        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let mut rng = StdRng::seed_from_u64(0);

        // Product compositions over the disjoint index sets [0, 1], [2, 3, 4] and [5, 6, 7, 8]
        // of the 9 multilinears.
        let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
        let triple =
            Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
        let quad =
            Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());

        // Prover-side view of each composition: a name plus the composition viewed over both
        // `FBase`-packed and `F`-packed fields.
        let prover_compositions = [
            (
                "pair".into(),
                pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                pair.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
            ),
            (
                "triple".into(),
                triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                triple.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
            ),
            (
                "quad".into(),
                quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                quad.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
            ),
        ];

        // Scalar (`F`) adapters of the same compositions, used to build the `ZerocheckClaim`s.
        let prover_adapter_compositions = [
            CompositionScalarAdapter::new(pair as Arc<dyn CompositionPoly<F>>),
            CompositionScalarAdapter::new(triple as Arc<dyn CompositionPoly<F>>),
            CompositionScalarAdapter::new(quad as Arc<dyn CompositionPoly<F>>),
        ];

        for skip_rounds in 0..=max_n_vars {
            let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();

            let prover_zerocheck_challenges: Vec<F> = proof.sample_vec(max_n_vars - skip_rounds);

            // One claim/prover per `n_vars` in 1..=max_n_vars, in ascending order.
            let mut zerocheck_claims = Vec::new();
            let mut zerocheck_provers = Vec::new();
            for n_vars in 1..=max_n_vars {
                // 2 + 3 + 4 = 9 multilinears, matching the three product compositions above.
                let mut multilinears = generate_zero_product_multilinears::<
                    PackedType<U, FWitness>,
                    PackedType<U, F>,
                >(&mut rng, n_vars, 2);
                multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
                multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

                let claim = ZerocheckClaim::<F, _>::new(
                    n_vars,
                    n_multilinears,
                    prover_adapter_compositions.to_vec(),
                )
                .unwrap();

                // The slice below keeps the trailing `n_vars.saturating_sub(skip_rounds)`
                // challenges of the `max_n_vars - skip_rounds` sampled above.
                let prover =
                    ZerocheckProverImpl::<FDomain, FBase, PackedType<U, F>, _, _, _, _, _>::new(
                        multilinears,
                        prover_compositions.to_vec(),
                        &prover_zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..],
                        domain_factory.clone(),
                        &backend,
                    )
                    .unwrap();

                zerocheck_claims.push(claim);
                zerocheck_provers.push(prover);
            }

            let prover_zerocheck_output =
                sumcheck::prove::batch_prove_zerocheck::<F, FDomain, PackedType<U, F>, _, _>(
                    zerocheck_provers,
                    skip_rounds,
                    &mut proof,
                )
                .unwrap();

            let mut verifier_proof = proof.into_verifier();

            let verifier_zerocheck_output = sumcheck::batch_verify_zerocheck(
                &zerocheck_claims,
                skip_rounds,
                &mut verifier_proof,
            )
            .unwrap();

            verifier_proof.finalize().unwrap();

            // Prover and verifier must derive identical challenges and evaluations.
            assert_eq!(
                prover_zerocheck_output.skipped_challenges,
                verifier_zerocheck_output.skipped_challenges
            );
            assert_eq!(
                prover_zerocheck_output.unskipped_challenges,
                verifier_zerocheck_output.unskipped_challenges
            );
            assert_eq!(
                prover_zerocheck_output.concat_multilinear_evals,
                verifier_zerocheck_output.concat_multilinear_evals,
            );
        }
    }

    #[test]
    fn test_zerocheck_end_to_end_basic() {
        test_zerocheck_end_to_end_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            BinaryField16b,
            BinaryField16b,
            BinaryField8b,
        >()
    }

    #[test]
    fn test_zerocheck_end_to_end_with_nontrivial_packing() {
        // Same fields as the basic test but with a 512-bit underlier, so packed types hold
        // more scalars per element.
        test_zerocheck_end_to_end_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            BinaryField16b,
            BinaryField16b,
            BinaryField8b,
        >()
    }

    #[test]
    fn test_zerocheck_end_to_end_bytesliced() {
        // AES tower fields over the byte-sliced underlier.
        test_zerocheck_end_to_end_helper::<
            <ByteSlicedAES64x128b as WithUnderlier>::Underlier,
            AESTowerField128b,
            AESTowerField16b,
            AESTowerField16b,
            AESTowerField8b,
        >()
    }
}