1use binius_field::{util::eq, Field, PackedField};
4use binius_math::{ArithCircuit, CompositionPoly};
5use binius_utils::bail;
6use getset::CopyGetters;
7use itertools::{izip, Either};
8
9use super::{
10 common::{CompositeSumClaim, SumcheckClaim},
11 error::{Error, VerificationError},
12};
13use crate::protocols::sumcheck::BatchSumcheckOutput;
14
15#[derive(Debug, Clone, CopyGetters)]
21pub struct EqIndSumcheckClaim<F: Field, Composition> {
22 #[getset(get_copy = "pub")]
23 n_vars: usize,
24 #[getset(get_copy = "pub")]
25 n_multilinears: usize,
26 eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
27}
28
29impl<F: Field, Composition> EqIndSumcheckClaim<F, Composition>
30where
31 Composition: CompositionPoly<F>,
32{
33 pub fn new(
40 n_vars: usize,
41 n_multilinears: usize,
42 eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
43 ) -> Result<Self, Error> {
44 for CompositeSumClaim {
45 ref composition, ..
46 } in &eq_ind_composite_sums
47 {
48 if composition.n_vars() != n_multilinears {
49 bail!(Error::InvalidComposition {
50 actual: composition.n_vars(),
51 expected: n_multilinears,
52 });
53 }
54 }
55 Ok(Self {
56 n_vars,
57 n_multilinears,
58 eq_ind_composite_sums,
59 })
60 }
61
62 pub fn max_individual_degree(&self) -> usize {
64 self.eq_ind_composite_sums
65 .iter()
66 .map(|composite_sum| composite_sum.composition.degree())
67 .max()
68 .unwrap_or(0)
69 }
70
71 pub fn eq_ind_composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
72 &self.eq_ind_composite_sums
73 }
74}
75
76pub fn reduce_to_regular_sumchecks<F: Field, Composition: CompositionPoly<F>>(
81 claims: &[EqIndSumcheckClaim<F, Composition>],
82) -> Result<Vec<SumcheckClaim<F, ExtraProduct<&Composition>>>, Error> {
83 claims
84 .iter()
85 .map(|eq_ind_sumcheck_claim| {
86 let EqIndSumcheckClaim {
87 n_vars,
88 n_multilinears,
89 eq_ind_composite_sums,
90 ..
91 } = eq_ind_sumcheck_claim;
92 SumcheckClaim::new(
93 *n_vars,
94 *n_multilinears + 1,
95 eq_ind_composite_sums
96 .iter()
97 .map(|composite_sum| CompositeSumClaim {
98 composition: ExtraProduct {
99 inner: &composite_sum.composition,
100 },
101 sum: composite_sum.sum,
102 })
103 .collect(),
104 )
105 })
106 .collect()
107}
108
/// How a slice of [`EqIndSumcheckClaim`]s is sorted by their number of variables.
///
/// [`verify_sumcheck_outputs`] needs to visit claims in non-descending `n_vars` order. This
/// value tells it whether to walk the given slice forwards or backwards to do so.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClaimsSortingOrder {
	/// Claims are sorted by non-decreasing `n_vars`; they are traversed front to back.
	AscendingVars,
	/// Claims are sorted by non-increasing `n_vars`; they are traversed back to front.
	DescendingVars,
}
114
115pub fn verify_sumcheck_outputs<F: Field, Composition: CompositionPoly<F>>(
124 sorting_order: ClaimsSortingOrder,
125 claims: &[EqIndSumcheckClaim<F, Composition>],
126 eq_ind_challenges: &[F],
127 sumcheck_output: BatchSumcheckOutput<F>,
128) -> Result<BatchSumcheckOutput<F>, Error> {
129 let BatchSumcheckOutput {
130 challenges: sumcheck_challenges,
131 mut multilinear_evals,
132 } = sumcheck_output;
133
134 if multilinear_evals.len() != claims.len() {
135 bail!(VerificationError::NumberOfFinalEvaluations);
136 }
137
138 let claims_evals_inner = izip!(claims, &mut multilinear_evals);
140 let claims_evals_non_desc = match sorting_order {
141 ClaimsSortingOrder::AscendingVars => Either::Left(claims_evals_inner),
142 ClaimsSortingOrder::DescendingVars => Either::Right(claims_evals_inner.rev()),
143 };
144
145 if eq_ind_challenges.len() != sumcheck_challenges.len() {
146 bail!(VerificationError::NumberOfRounds);
147 }
148
149 let mut eq_ind_eval = F::ONE;
151 let mut last_n_vars = 0;
152 for (claim, multilinear_evals) in claims_evals_non_desc {
153 if claim.n_multilinears() + 1 != multilinear_evals.len() {
154 bail!(VerificationError::NumberOfMultilinearEvals);
155 }
156
157 if claim.n_vars() < last_n_vars {
158 bail!(Error::ClaimsOutOfOrder);
159 }
160
161 while last_n_vars < claim.n_vars() && last_n_vars < sumcheck_challenges.len() {
162 let sumcheck_challenge =
164 sumcheck_challenges[sumcheck_challenges.len() - 1 - last_n_vars];
165 let eq_ind_challenge = eq_ind_challenges[eq_ind_challenges.len() - 1 - last_n_vars];
166 eq_ind_eval *= eq(sumcheck_challenge, eq_ind_challenge);
167 last_n_vars += 1;
168 }
169
170 let multilinear_evals_last = multilinear_evals
171 .pop()
172 .expect("checked above that multilinear_evals length is at least 1");
173 if eq_ind_eval != multilinear_evals_last {
174 return Err(VerificationError::IncorrectEqIndEvaluation.into());
175 }
176 }
177
178 Ok(BatchSumcheckOutput {
179 challenges: sumcheck_challenges,
180 multilinear_evals,
181 })
182}
183
184#[derive(Debug, Clone)]
185pub struct ExtraProduct<Composition> {
186 pub inner: Composition,
187}
188
189impl<P, Composition> CompositionPoly<P> for ExtraProduct<Composition>
190where
191 P: PackedField,
192 Composition: CompositionPoly<P>,
193{
194 fn n_vars(&self) -> usize {
195 self.inner.n_vars() + 1
196 }
197
198 fn degree(&self) -> usize {
199 self.inner.degree() + 1
200 }
201
202 fn expression(&self) -> ArithCircuit<P::Scalar> {
203 self.inner.expression() * ArithCircuit::var(self.inner.n_vars())
204 }
205
206 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
207 let n_vars = self.n_vars();
208 if query.len() != n_vars {
209 bail!(binius_math::Error::IncorrectQuerySize { expected: n_vars });
210 }
211
212 let inner_eval = self.inner.evaluate(&query[..n_vars - 1])?;
213 Ok(inner_eval * query[n_vars - 1])
214 }
215
216 fn binary_tower_level(&self) -> usize {
217 self.inner.binary_tower_level()
218 }
219}
220
#[cfg(test)]
mod tests {
	use std::{iter, sync::Arc};

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier256b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		packed::set_packed_slice,
		underlier::UnderlierType,
		BinaryField128b, BinaryField32b, BinaryField8b, ExtensionField, Field,
		PackedBinaryField1x128b, PackedExtension, PackedField, PackedFieldIndexable,
		PackedSubfield, RepackedExtension, TowerField,
	};
	use binius_hal::{
		make_portable_backend, ComputationBackend, ComputationBackendExt, SumcheckMultilinear,
	};
	use binius_hash::groestl::Groestl256;
	use binius_math::{
		CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
		IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
		MultilinearQuery,
	};
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use crate::{
		composition::BivariateProduct,
		fiat_shamir::{CanSample, HasherChallenger},
		protocols::{
			sumcheck::{
				self,
				eq_ind::{ClaimsSortingOrder, ExtraProduct},
				immediate_switchover_heuristic,
				prove::{
					eq_ind::{ConstEvalSuffix, EqIndSumcheckProverBuilder},
					RegularSumcheckProver,
				},
				BatchSumcheckOutput, CompositeSumClaim, EqIndSumcheckClaim,
			},
			test_utils::{
				generate_zero_product_multilinears, AddOneComposition, TestProductComposition,
			},
		},
		transcript::ProverTranscript,
		transparent::eq_ind::EqIndPartialEval,
		witness::MultilinearWitness,
	};

	/// Runs the `a * b + 1` eq-ind prove/verify round trip for many non-zero prefix lengths of
	/// `a`, under both evaluation orders.
	fn test_prove_verify_bivariate_product_helper<U, F, FDomain>(n_vars: usize)
	where
		U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
		F: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let max_nonzero_prefix = 1 << n_vars;

		// Prefix lengths: 0, every power of two up to 2^n_vars, plus a few random ones.
		let mut nonzero_prefixes = vec![0];

		for i in 1..=n_vars {
			nonzero_prefixes.push(1 << i);
		}

		let mut rng = StdRng::seed_from_u64(0);
		for _ in 0..n_vars + 5 {
			nonzero_prefixes.push(rng.gen_range(1..max_nonzero_prefix));
		}

		for nonzero_prefix in nonzero_prefixes {
			for evaluation_order in [EvaluationOrder::LowToHigh, EvaluationOrder::HighToLow] {
				test_prove_verify_bivariate_product_helper_under_evaluation_order::<U, F, FDomain>(
					evaluation_order,
					n_vars,
					nonzero_prefix,
				);
			}
		}
	}

	/// Proves and verifies an eq-ind sumcheck of `a * b + 1`, where `a` is zero outside its
	/// first `nonzero_prefix` hypercube points.
	fn test_prove_verify_bivariate_product_helper_under_evaluation_order<U, F, FDomain>(
		evaluation_order: EvaluationOrder,
		n_vars: usize,
		nonzero_prefix: usize,
	) where
		U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
		F: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let packed_len = 1 << n_vars.saturating_sub(PackedType::<U, F>::LOG_WIDTH);
		let mut a_column = (0..packed_len)
			.map(|_| PackedType::<U, F>::random(&mut rng))
			.collect::<Vec<_>>();
		let b_column = (0..packed_len)
			.map(|_| PackedType::<U, F>::random(&mut rng))
			.collect::<Vec<_>>();
		let mut ab1_column = iter::zip(&a_column, &b_column)
			.map(|(&a, &b)| a * b + PackedType::<U, F>::one())
			.collect::<Vec<_>>();

		// Past the non-zero prefix `a` is zero, so `a * b + 1` is one there.
		for i in nonzero_prefix..1 << n_vars {
			set_packed_slice(&mut a_column, i, F::ZERO);
			set_packed_slice(&mut ab1_column, i, F::ONE);
		}

		let a_mle =
			MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&a_column).unwrap());
		let b_mle =
			MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&b_column).unwrap());
		let ab1_mle =
			MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&ab1_column).unwrap());

		// The claimed sum is the multilinear extension of `a * b + 1` at the eq-ind point.
		let eq_ind_challenges = (0..n_vars).map(|_| F::random(&mut rng)).collect::<Vec<_>>();
		let sum = ab1_mle
			.evaluate(MultilinearQuery::expand(&eq_ind_challenges).to_ref())
			.unwrap();

		let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();

		let backend = make_portable_backend();
		let evaluation_domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();

		let composition = AddOneComposition::new(BivariateProduct {});

		let composite_claim = CompositeSumClaim { sum, composition };

		let prover = EqIndSumcheckProverBuilder::with_switchover(
			vec![a_mle, b_mle],
			immediate_switchover_heuristic,
			&backend,
		)
		.unwrap()
		.with_const_suffixes(&[(F::ZERO, (1 << n_vars) - nonzero_prefix), (F::ZERO, 0)])
		.unwrap()
		.build(
			evaluation_order,
			&eq_ind_challenges,
			[composite_claim.clone()],
			evaluation_domain_factory,
		)
		.unwrap();

		// On the zero suffix of `a` the composition evaluates to 0 * b + 1 = 1.
		let (_, const_eval_suffix) = prover.compositions().first().unwrap();
		assert_eq!(
			*const_eval_suffix,
			ConstEvalSuffix {
				suffix: (1 << n_vars) - nonzero_prefix,
				value: F::ONE,
				value_at_inf: F::ZERO
			}
		);

		let _sumcheck_proof_output =
			sumcheck::prove::batch_prove(vec![prover], &mut transcript).unwrap();

		let mut verifier_transcript = transcript.into_verifier();

		let eq_ind_sumcheck_verifier_claim =
			EqIndSumcheckClaim::new(n_vars, 2, vec![composite_claim]).unwrap();
		let eq_ind_sumcheck_verifier_claims = [eq_ind_sumcheck_verifier_claim];
		let regular_sumcheck_verifier_claims =
			sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_verifier_claims)
				.unwrap();

		let _sumcheck_verify_output = sumcheck::batch_verify(
			evaluation_order,
			&regular_sumcheck_verifier_claims,
			&mut verifier_transcript,
		)
		.unwrap();
	}

	#[test]
	fn test_eq_ind_sumcheck_prove_verify_128b() {
		let n_vars = 8;

		test_prove_verify_bivariate_product_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
		>(n_vars);
	}

	#[test]
	fn test_eq_ind_sumcheck_prove_verify_256b() {
		let n_vars = 8;

		test_prove_verify_bivariate_product_helper::<
			OptimalUnderlier256b,
			BinaryField128b,
			BinaryField8b,
		>(n_vars);
	}

	#[test]
	fn test_eq_ind_sumcheck_prove_verify_512b() {
		let n_vars = 8;

		test_prove_verify_bivariate_product_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
		>(n_vars);
	}

	/// Builds a reference prover that emulates an eq-ind sumcheck with a regular sumcheck.
	///
	/// It appends the materialized eq-indicator multilinear for `challenges` as the last
	/// multilinear and wraps every composition in [`ExtraProduct`].
	fn make_regular_sumcheck_prover_for_eq_ind_sumcheck<
		'a,
		'b,
		F,
		FDomain,
		P,
		Composition,
		M,
		Backend,
	>(
		multilinears: Vec<M>,
		claims: &'b [CompositeSumClaim<F, Composition>],
		challenges: &[F],
		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
		switchover_fn: impl Fn(usize) -> usize,
		backend: &'a Backend,
	) -> RegularSumcheckProver<
		'a,
		FDomain,
		P,
		ExtraProduct<&'b Composition>,
		MultilinearWitness<'static, P>,
		Backend,
	>
	where
		F: Field,
		FDomain: Field,
		P: PackedField<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<P>,
		Composition: CompositionPoly<P>,
		M: MultilinearPoly<P> + Send + Sync + 'static,
		Backend: ComputationBackend,
	{
		let eq_ind = EqIndPartialEval::new(challenges)
			.multilinear_extension::<P, _>(backend)
			.unwrap();

		let multilinears = multilinears
			.into_iter()
			.map(|multilin| Arc::new(multilin) as Arc<dyn MultilinearPoly<_> + Send + Sync>)
			.chain([eq_ind.specialize_arc_dyn()])
			.collect();

		let composite_sum_claims =
			claims
				.iter()
				.map(|CompositeSumClaim { composition, sum }| CompositeSumClaim {
					composition: ExtraProduct { inner: composition },
					sum: *sum,
				});
		RegularSumcheckProver::new(
			EvaluationOrder::HighToLow,
			multilinears,
			composite_sum_claims,
			evaluation_domain_factory,
			switchover_fn,
			backend,
		)
		.unwrap()
	}

	/// Checks that the eq-ind prover produces the same transcript, evaluations, and challenges
	/// as the regular-sumcheck reference prover.
	fn test_compare_prover_with_reference(
		n_vars: usize,
		n_multilinears: usize,
		switchover_rd: usize,
	) {
		type P = PackedBinaryField1x128b;
		type FBase = BinaryField32b;
		type FDomain = BinaryField8b;
		let mut rng = StdRng::seed_from_u64(0);

		let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
			&mut rng,
			n_vars,
			n_multilinears,
		);

		let mut prove_transcript_1 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let backend = make_portable_backend();
		let challenges = prove_transcript_1.sample_vec(n_vars);

		let composite_claim = CompositeSumClaim {
			composition: TestProductComposition::new(n_multilinears),
			sum: Field::ZERO,
		};

		let composite_claims = [composite_claim];

		let switchover_fn = |_| switchover_rd;

		let sumcheck_multilinears = multilins
			.iter()
			.cloned()
			.map(|multilin| SumcheckMultilinear::transparent(multilin, &switchover_fn))
			.collect::<Vec<_>>();

		sumcheck::prove::eq_ind::validate_witness(
			n_vars,
			&sumcheck_multilinears,
			&challenges,
			composite_claims.clone(),
		)
		.unwrap();

		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let reference_prover =
			make_regular_sumcheck_prover_for_eq_ind_sumcheck::<_, FDomain, _, _, _, _>(
				multilins.clone(),
				&composite_claims,
				&challenges,
				domain_factory.clone(),
				|_| switchover_rd,
				&backend,
			);

		let BatchSumcheckOutput {
			challenges: sumcheck_challenges_1,
			multilinear_evals: multilinear_evals_1,
		} = sumcheck::batch_prove(vec![reference_prover], &mut prove_transcript_1).unwrap();

		let optimized_prover =
			EqIndSumcheckProverBuilder::with_switchover(multilins, switchover_fn, &backend)
				.unwrap()
				.build::<FDomain, _>(
					EvaluationOrder::HighToLow,
					&challenges,
					composite_claims,
					domain_factory,
				)
				.unwrap();

		// Sample the same eq-ind challenges as the first transcript so the two stay in sync.
		let mut prove_transcript_2 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let _: Vec<BinaryField128b> = prove_transcript_2.sample_vec(n_vars);
		let BatchSumcheckOutput {
			challenges: sumcheck_challenges_2,
			multilinear_evals: multilinear_evals_2,
		} = sumcheck::batch_prove(vec![optimized_prover], &mut prove_transcript_2).unwrap();

		assert_eq!(prove_transcript_1.finalize(), prove_transcript_2.finalize());
		assert_eq!(multilinear_evals_1, multilinear_evals_2);
		assert_eq!(sumcheck_challenges_1, sumcheck_challenges_2);
	}

	/// Full prove/verify round trip for a zero product constraint, including
	/// [`sumcheck::eq_ind::verify_sumcheck_outputs`] on both sides and a direct check of the
	/// final multilinear evaluations.
	fn test_prove_verify_product_constraint_helper(
		n_vars: usize,
		n_multilinears: usize,
		switchover_rd: usize,
	) {
		type P = PackedBinaryField1x128b;
		type FBase = BinaryField32b;
		type FE = BinaryField128b;
		type FDomain = BinaryField8b;
		let mut rng = StdRng::seed_from_u64(0);

		let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
			&mut rng,
			n_vars,
			n_multilinears,
		);

		let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let challenges = prove_transcript.sample_vec(n_vars);

		let composite_claim = CompositeSumClaim {
			composition: TestProductComposition::new(n_multilinears),
			sum: Field::ZERO,
		};

		let composite_claims = vec![composite_claim];

		let switchover_fn = |_| switchover_rd;

		let sumcheck_multilinears = multilins
			.iter()
			.cloned()
			.map(|multilin| SumcheckMultilinear::transparent(multilin, &switchover_fn))
			.collect::<Vec<_>>();

		sumcheck::prove::eq_ind::validate_witness(
			n_vars,
			&sumcheck_multilinears,
			&challenges,
			composite_claims.clone(),
		)
		.unwrap();

		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let backend = make_portable_backend();

		let prover =
			EqIndSumcheckProverBuilder::with_switchover(multilins.clone(), switchover_fn, &backend)
				.unwrap()
				.build::<FDomain, _>(
					EvaluationOrder::HighToLow,
					&challenges,
					composite_claims.clone(),
					domain_factory,
				)
				.unwrap();

		let prove_output =
			sumcheck::prove::batch_prove(vec![prover], &mut prove_transcript).unwrap();

		let eq_ind_sumcheck_claim =
			EqIndSumcheckClaim::new(n_vars, n_multilinears, composite_claims).unwrap();
		let eq_ind_sumcheck_claims = vec![eq_ind_sumcheck_claim];

		let BatchSumcheckOutput {
			challenges: prover_eval_point,
			multilinear_evals: prover_multilinear_evals,
		} = sumcheck::eq_ind::verify_sumcheck_outputs(
			ClaimsSortingOrder::AscendingVars,
			&eq_ind_sumcheck_claims,
			&challenges,
			prove_output,
		)
		.unwrap();

		let prover_sample = CanSample::<FE>::sample(&mut prove_transcript);
		let mut verify_transcript = prove_transcript.into_verifier();
		// Consume the eq-ind challenges the prover sampled before proving.
		let _: Vec<BinaryField128b> = verify_transcript.sample_vec(n_vars);

		let regular_sumcheck_claims =
			sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims).unwrap();

		let verifier_output = sumcheck::batch_verify(
			EvaluationOrder::HighToLow,
			&regular_sumcheck_claims,
			&mut verify_transcript,
		)
		.unwrap();

		let BatchSumcheckOutput {
			challenges: verifier_eval_point,
			multilinear_evals: verifier_multilinear_evals,
		} = sumcheck::eq_ind::verify_sumcheck_outputs(
			ClaimsSortingOrder::AscendingVars,
			&eq_ind_sumcheck_claims,
			&challenges,
			verifier_output,
		)
		.unwrap();

		assert_eq!(prover_sample, CanSample::<FE>::sample(&mut verify_transcript));
		verify_transcript.finalize().unwrap();

		assert_eq!(prover_eval_point, verifier_eval_point);
		assert_eq!(prover_multilinear_evals, verifier_multilinear_evals);

		// After stripping, only the claim's own multilinear evaluations remain.
		assert_eq!(verifier_multilinear_evals.len(), 1);
		assert_eq!(verifier_multilinear_evals[0].len(), n_multilinears);

		let multilin_query = backend.multilinear_query(&verifier_eval_point).unwrap();
		for (multilinear, &expected) in iter::zip(multilins, verifier_multilinear_evals[0].iter()) {
			assert_eq!(multilinear.evaluate(multilin_query.to_ref()).unwrap(), expected);
		}
	}

	#[test]
	fn test_compare_eq_ind_prover_to_regular_sumcheck() {
		for n_vars in 2..8 {
			for n_multilinears in 1..5 {
				for switchover_rd in 1..=n_vars / 2 {
					test_compare_prover_with_reference(n_vars, n_multilinears, switchover_rd);
				}
			}
		}
	}

	#[test]
	fn test_prove_verify_product_basic() {
		for n_vars in 2..8 {
			for n_multilinears in 1..5 {
				for switchover_rd in 1..=n_vars / 2 {
					test_prove_verify_product_constraint_helper(
						n_vars,
						n_multilinears,
						switchover_rd,
					);
				}
			}
		}
	}
}