1use binius_field::{Field, PackedField, util::eq};
4use binius_math::{ArithCircuit, CompositionPoly};
5use binius_utils::bail;
6use getset::CopyGetters;
7use itertools::{Either, izip};
8
9use super::{
10 common::{CompositeSumClaim, SumcheckClaim},
11 error::{Error, VerificationError},
12};
13use crate::protocols::sumcheck::BatchSumcheckOutput;
14
/// A group of composite sum claims that share a variable count and a multilinear count.
///
/// NOTE(review): judging by the type name and by how `reduce_to_regular_sumchecks` consumes it,
/// each claimed sum presumably carries an implicit equality-indicator factor; the struct itself
/// stores no eq-ind challenges — confirm against the protocol description.
#[derive(Debug, Clone, CopyGetters)]
pub struct EqIndSumcheckClaim<F: Field, Composition> {
    /// Number of variables of the claim; exposed through a generated `pub` copy getter.
    #[getset(get_copy = "pub")]
    n_vars: usize,
    /// Number of inputs every composition must accept (enforced by `new`); exposed through a
    /// generated `pub` copy getter.
    #[getset(get_copy = "pub")]
    n_multilinears: usize,
    /// The composite sum claims of this group; read through `eq_ind_composite_sums()`.
    eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
}
28
29impl<F: Field, Composition> EqIndSumcheckClaim<F, Composition>
30where
31 Composition: CompositionPoly<F>,
32{
33 pub fn new(
40 n_vars: usize,
41 n_multilinears: usize,
42 eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
43 ) -> Result<Self, Error> {
44 for CompositeSumClaim { composition, .. } in &eq_ind_composite_sums {
45 if composition.n_vars() != n_multilinears {
46 bail!(Error::InvalidComposition {
47 actual: composition.n_vars(),
48 expected: n_multilinears,
49 });
50 }
51 }
52 Ok(Self {
53 n_vars,
54 n_multilinears,
55 eq_ind_composite_sums,
56 })
57 }
58
59 pub fn max_individual_degree(&self) -> usize {
61 self.eq_ind_composite_sums
62 .iter()
63 .map(|composite_sum| composite_sum.composition.degree())
64 .max()
65 .unwrap_or(0)
66 }
67
68 pub fn eq_ind_composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
69 &self.eq_ind_composite_sums
70 }
71}
72
73pub fn reduce_to_regular_sumchecks<F: Field, Composition: CompositionPoly<F>>(
79 claims: &[EqIndSumcheckClaim<F, Composition>],
80) -> Result<Vec<SumcheckClaim<F, ExtraProduct<&Composition>>>, Error> {
81 claims
82 .iter()
83 .map(|eq_ind_sumcheck_claim| {
84 let EqIndSumcheckClaim {
85 n_vars,
86 n_multilinears,
87 eq_ind_composite_sums,
88 ..
89 } = eq_ind_sumcheck_claim;
90 SumcheckClaim::new(
91 *n_vars,
92 *n_multilinears + 1,
93 eq_ind_composite_sums
94 .iter()
95 .map(|composite_sum| CompositeSumClaim {
96 composition: ExtraProduct {
97 inner: &composite_sum.composition,
98 },
99 sum: composite_sum.sum,
100 })
101 .collect(),
102 )
103 })
104 .collect()
105}
106
/// Order, by number of variables, in which claims are passed to `verify_sumcheck_outputs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClaimsSortingOrder {
    /// Claims are listed with non-decreasing `n_vars`.
    AscendingVars,
    /// Claims are listed with non-increasing `n_vars`; the verifier walks them in reverse.
    DescendingVars,
}
112
113pub fn verify_sumcheck_outputs<F: Field, Composition: CompositionPoly<F>>(
122 sorting_order: ClaimsSortingOrder,
123 claims: &[EqIndSumcheckClaim<F, Composition>],
124 eq_ind_challenges: &[F],
125 sumcheck_output: BatchSumcheckOutput<F>,
126) -> Result<BatchSumcheckOutput<F>, Error> {
127 let BatchSumcheckOutput {
128 challenges: sumcheck_challenges,
129 mut multilinear_evals,
130 } = sumcheck_output;
131
132 if multilinear_evals.len() != claims.len() {
133 bail!(VerificationError::NumberOfFinalEvaluations);
134 }
135
136 let claims_evals_inner = izip!(claims, &mut multilinear_evals);
138 let claims_evals_non_desc = match sorting_order {
139 ClaimsSortingOrder::AscendingVars => Either::Left(claims_evals_inner),
140 ClaimsSortingOrder::DescendingVars => Either::Right(claims_evals_inner.rev()),
141 };
142
143 if eq_ind_challenges.len() != sumcheck_challenges.len() {
144 bail!(VerificationError::NumberOfRounds);
145 }
146
147 let mut eq_ind_eval = F::ONE;
150 let mut last_n_vars = 0;
151 for (claim, multilinear_evals) in claims_evals_non_desc {
152 if claim.n_multilinears() + 1 != multilinear_evals.len() {
153 bail!(VerificationError::NumberOfMultilinearEvals);
154 }
155
156 if claim.n_vars() < last_n_vars {
157 bail!(Error::ClaimsOutOfOrder);
158 }
159
160 while last_n_vars < claim.n_vars() && last_n_vars < sumcheck_challenges.len() {
161 let sumcheck_challenge =
163 sumcheck_challenges[sumcheck_challenges.len() - 1 - last_n_vars];
164 let eq_ind_challenge = eq_ind_challenges[eq_ind_challenges.len() - 1 - last_n_vars];
165 eq_ind_eval *= eq(sumcheck_challenge, eq_ind_challenge);
166 last_n_vars += 1;
167 }
168
169 let multilinear_evals_last = multilinear_evals
170 .pop()
171 .expect("checked above that multilinear_evals length is at least 1");
172 if eq_ind_eval != multilinear_evals_last {
173 return Err(VerificationError::IncorrectEqIndEvaluation.into());
174 }
175 }
176
177 Ok(BatchSumcheckOutput {
178 challenges: sumcheck_challenges,
179 multilinear_evals,
180 })
181}
182
/// Composition wrapper that multiplies an inner composition by one additional trailing variable.
///
/// Its `CompositionPoly` implementation reports one more variable and one more degree than
/// `inner`, and evaluates as `inner(query[..n - 1]) * query[n - 1]`.
#[derive(Debug, Clone)]
pub struct ExtraProduct<Composition> {
    /// The wrapped composition, evaluated on every query element except the last.
    pub inner: Composition,
}
187
188impl<P, Composition> CompositionPoly<P> for ExtraProduct<Composition>
189where
190 P: PackedField,
191 Composition: CompositionPoly<P>,
192{
193 fn n_vars(&self) -> usize {
194 self.inner.n_vars() + 1
195 }
196
197 fn degree(&self) -> usize {
198 self.inner.degree() + 1
199 }
200
201 fn expression(&self) -> ArithCircuit<P::Scalar> {
202 self.inner.expression() * ArithCircuit::var(self.inner.n_vars())
203 }
204
205 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
206 let n_vars = self.n_vars();
207 if query.len() != n_vars {
208 bail!(binius_math::Error::IncorrectQuerySize {
209 expected: n_vars,
210 actual: query.len()
211 });
212 }
213
214 let inner_eval = self.inner.evaluate(&query[..n_vars - 1])?;
215 Ok(inner_eval * query[n_vars - 1])
216 }
217
218 fn binary_tower_level(&self) -> usize {
219 self.inner.binary_tower_level()
220 }
221}
222
223#[cfg(test)]
224mod tests {
225 use std::{iter, sync::Arc};
226
227 use binius_field::{
228 BinaryField8b, BinaryField32b, BinaryField128b, ExtensionField, Field,
229 PackedBinaryField1x128b, PackedExtension, PackedField, PackedFieldIndexable,
230 PackedSubfield, RepackedExtension, TowerField,
231 arch::{OptimalUnderlier128b, OptimalUnderlier256b, OptimalUnderlier512b},
232 as_packed_field::{PackScalar, PackedType},
233 packed::set_packed_slice,
234 underlier::UnderlierType,
235 };
236 use binius_hal::{
237 ComputationBackend, ComputationBackendExt, SumcheckMultilinear, make_portable_backend,
238 };
239 use binius_hash::groestl::Groestl256;
240 use binius_math::{
241 CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
242 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly,
243 MultilinearQuery,
244 };
245 use rand::{Rng, SeedableRng, rngs::StdRng};
246
247 use crate::{
248 composition::BivariateProduct,
249 fiat_shamir::{CanSample, HasherChallenger},
250 protocols::{
251 sumcheck::{
252 self, BatchSumcheckOutput, CompositeSumClaim, EqIndSumcheckClaim,
253 eq_ind::{ClaimsSortingOrder, ExtraProduct},
254 immediate_switchover_heuristic,
255 prove::{
256 RegularSumcheckProver,
257 eq_ind::{ConstEvalSuffix, EqIndSumcheckProverBuilder},
258 },
259 },
260 test_utils::{
261 AddOneComposition, TestProductComposition, generate_zero_product_multilinears,
262 },
263 },
264 transcript::ProverTranscript,
265 transparent::eq_ind::EqIndPartialEval,
266 witness::MultilinearWitness,
267 };
268
269 fn test_prove_verify_bivariate_product_helper<U, F, FDomain>(n_vars: usize)
270 where
271 U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
272 F: TowerField + ExtensionField<FDomain>,
273 FDomain: TowerField,
274 PackedType<U, F>: PackedFieldIndexable,
275 {
276 let max_nonzero_prefix = 1 << n_vars;
277 let mut nonzero_prefixes = vec![0];
278
279 for i in 1..=n_vars {
280 nonzero_prefixes.push(1 << i);
281 }
282
283 let mut rng = StdRng::seed_from_u64(0);
284 for _ in 0..n_vars + 5 {
285 nonzero_prefixes.push(rng.gen_range(1..max_nonzero_prefix));
286 }
287
288 for nonzero_prefix in nonzero_prefixes {
289 for evaluation_order in [EvaluationOrder::LowToHigh, EvaluationOrder::HighToLow] {
290 test_prove_verify_bivariate_product_helper_under_evaluation_order::<U, F, FDomain>(
291 evaluation_order,
292 n_vars,
293 nonzero_prefix,
294 );
295 }
296 }
297 }
298
299 fn test_prove_verify_bivariate_product_helper_under_evaluation_order<U, F, FDomain>(
300 evaluation_order: EvaluationOrder,
301 n_vars: usize,
302 nonzero_prefix: usize,
303 ) where
304 U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
305 F: TowerField + ExtensionField<FDomain>,
306 FDomain: TowerField,
307 PackedType<U, F>: PackedFieldIndexable,
308 {
309 let mut rng = StdRng::seed_from_u64(0);
310
311 let packed_len = 1 << n_vars.saturating_sub(PackedType::<U, F>::LOG_WIDTH);
312 let mut a_column = (0..packed_len)
313 .map(|_| PackedType::<U, F>::random(&mut rng))
314 .collect::<Vec<_>>();
315 let b_column = (0..packed_len)
316 .map(|_| PackedType::<U, F>::random(&mut rng))
317 .collect::<Vec<_>>();
318 let mut ab1_column = iter::zip(&a_column, &b_column)
319 .map(|(&a, &b)| a * b + PackedType::<U, F>::one())
320 .collect::<Vec<_>>();
321
322 for i in nonzero_prefix..1 << n_vars {
323 set_packed_slice(&mut a_column, i, F::ZERO);
324 set_packed_slice(&mut ab1_column, i, F::ONE);
325 }
326
327 let a_mle =
328 MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&a_column).unwrap());
329 let b_mle =
330 MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&b_column).unwrap());
331 let ab1_mle =
332 MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&ab1_column).unwrap());
333
334 let eq_ind_challenges = (0..n_vars).map(|_| F::random(&mut rng)).collect::<Vec<_>>();
335 let sum = ab1_mle
336 .evaluate(MultilinearQuery::expand(&eq_ind_challenges).to_ref())
337 .unwrap();
338
339 let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
340
341 let backend = make_portable_backend();
342 let evaluation_domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();
343
344 let composition = AddOneComposition::new(BivariateProduct {});
345
346 let composite_claim = CompositeSumClaim { sum, composition };
347
348 let prover = EqIndSumcheckProverBuilder::with_switchover(
349 vec![a_mle, b_mle],
350 immediate_switchover_heuristic,
351 &backend,
352 )
353 .unwrap()
354 .with_const_suffixes(&[(F::ZERO, (1 << n_vars) - nonzero_prefix), (F::ZERO, 0)])
355 .unwrap()
356 .build(
357 evaluation_order,
358 &eq_ind_challenges,
359 [composite_claim.clone()],
360 evaluation_domain_factory,
361 )
362 .unwrap();
363
364 let (_, const_eval_suffix) = prover.compositions().first().unwrap();
365 assert_eq!(
366 *const_eval_suffix,
367 ConstEvalSuffix {
368 suffix: (1 << n_vars) - nonzero_prefix,
369 value: F::ONE,
370 value_at_inf: F::ZERO
371 }
372 );
373
374 let _sumcheck_proof_output =
375 sumcheck::prove::batch_prove(vec![prover], &mut transcript).unwrap();
376
377 let mut verifier_transcript = transcript.into_verifier();
378
379 let eq_ind_sumcheck_verifier_claim =
380 EqIndSumcheckClaim::new(n_vars, 2, vec![composite_claim]).unwrap();
381 let eq_ind_sumcheck_verifier_claims = [eq_ind_sumcheck_verifier_claim];
382 let regular_sumcheck_verifier_claims =
383 sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_verifier_claims)
384 .unwrap();
385
386 let _sumcheck_verify_output = sumcheck::batch_verify(
387 evaluation_order,
388 ®ular_sumcheck_verifier_claims,
389 &mut verifier_transcript,
390 )
391 .unwrap();
392 }
393
394 #[test]
395 fn test_eq_ind_sumcheck_prove_verify_128b() {
396 let n_vars = 8;
397
398 test_prove_verify_bivariate_product_helper::<
399 OptimalUnderlier128b,
400 BinaryField128b,
401 BinaryField8b,
402 >(n_vars);
403 }
404
405 #[test]
406 fn test_eq_ind_sumcheck_prove_verify_256b() {
407 let n_vars = 8;
408
409 test_prove_verify_bivariate_product_helper::<
412 OptimalUnderlier256b,
413 BinaryField128b,
414 BinaryField8b,
415 >(n_vars);
416 }
417
418 #[test]
419 fn test_eq_ind_sumcheck_prove_verify_512b() {
420 let n_vars = 8;
421
422 test_prove_verify_bivariate_product_helper::<
425 OptimalUnderlier512b,
426 BinaryField128b,
427 BinaryField8b,
428 >(n_vars);
429 }
430
431 fn make_regular_sumcheck_prover_for_eq_ind_sumcheck<
432 'a,
433 'b,
434 F,
435 FDomain,
436 P,
437 Composition,
438 M,
439 Backend,
440 >(
441 multilinears: Vec<M>,
442 claims: &'b [CompositeSumClaim<F, Composition>],
443 challenges: &[F],
444 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
445 switchover_fn: impl Fn(usize) -> usize,
446 backend: &'a Backend,
447 ) -> RegularSumcheckProver<
448 'a,
449 FDomain,
450 P,
451 ExtraProduct<&'b Composition>,
452 MultilinearWitness<'static, P>,
453 Backend,
454 >
455 where
456 F: Field,
457 FDomain: Field,
458 P: PackedField<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<P>,
459 Composition: CompositionPoly<P>,
460 M: MultilinearPoly<P> + Send + Sync + 'static,
461 Backend: ComputationBackend,
462 {
463 let eq_ind = EqIndPartialEval::new(challenges)
464 .multilinear_extension::<P, _>(backend)
465 .unwrap();
466
467 let multilinears = multilinears
468 .into_iter()
469 .map(|multilin| Arc::new(multilin) as Arc<dyn MultilinearPoly<_> + Send + Sync>)
470 .chain([eq_ind.specialize_arc_dyn()])
471 .collect();
472
473 let composite_sum_claims =
474 claims
475 .iter()
476 .map(|CompositeSumClaim { composition, sum }| CompositeSumClaim {
477 composition: ExtraProduct { inner: composition },
478 sum: *sum,
479 });
480 RegularSumcheckProver::new(
481 EvaluationOrder::HighToLow,
482 multilinears,
483 composite_sum_claims,
484 evaluation_domain_factory,
485 switchover_fn,
486 backend,
487 )
488 .unwrap()
489 }
490
491 fn test_compare_prover_with_reference(
492 n_vars: usize,
493 n_multilinears: usize,
494 switchover_rd: usize,
495 ) {
496 type P = PackedBinaryField1x128b;
497 type FBase = BinaryField32b;
498 type FDomain = BinaryField8b;
499 let mut rng = StdRng::seed_from_u64(0);
500
501 let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
503 &mut rng,
504 n_vars,
505 n_multilinears,
506 );
507
508 let mut prove_transcript_1 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
509 let backend = make_portable_backend();
510 let challenges = prove_transcript_1.sample_vec(n_vars);
511
512 let composite_claim = CompositeSumClaim {
513 composition: TestProductComposition::new(n_multilinears),
514 sum: Field::ZERO,
515 };
516
517 let composite_claims = [composite_claim];
518
519 let switchover_fn = |_| switchover_rd;
520
521 let sumcheck_multilinears = multilins
522 .iter()
523 .cloned()
524 .map(|multilin| SumcheckMultilinear::transparent(multilin, &switchover_fn))
525 .collect::<Vec<_>>();
526
527 sumcheck::prove::eq_ind::validate_witness(
528 n_vars,
529 &sumcheck_multilinears,
530 &challenges,
531 composite_claims.clone(),
532 )
533 .unwrap();
534
535 let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
536 let reference_prover =
537 make_regular_sumcheck_prover_for_eq_ind_sumcheck::<_, FDomain, _, _, _, _>(
538 multilins.clone(),
539 &composite_claims,
540 &challenges,
541 domain_factory.clone(),
542 |_| switchover_rd,
543 &backend,
544 );
545
546 let BatchSumcheckOutput {
547 challenges: sumcheck_challenges_1,
548 multilinear_evals: multilinear_evals_1,
549 } = sumcheck::batch_prove(vec![reference_prover], &mut prove_transcript_1).unwrap();
550
551 let optimized_prover =
552 EqIndSumcheckProverBuilder::with_switchover(multilins, switchover_fn, &backend)
553 .unwrap()
554 .build::<FDomain, _>(
555 EvaluationOrder::HighToLow,
556 &challenges,
557 composite_claims,
558 domain_factory,
559 )
560 .unwrap();
561
562 let mut prove_transcript_2 = ProverTranscript::<HasherChallenger<Groestl256>>::new();
563 let _: Vec<BinaryField128b> = prove_transcript_2.sample_vec(n_vars);
564 let BatchSumcheckOutput {
565 challenges: sumcheck_challenges_2,
566 multilinear_evals: multilinear_evals_2,
567 } = sumcheck::batch_prove(vec![optimized_prover], &mut prove_transcript_2).unwrap();
568
569 assert_eq!(prove_transcript_1.finalize(), prove_transcript_2.finalize());
570 assert_eq!(multilinear_evals_1, multilinear_evals_2);
571 assert_eq!(sumcheck_challenges_1, sumcheck_challenges_2);
572 }
573
574 fn test_prove_verify_product_constraint_helper(
575 n_vars: usize,
576 n_multilinears: usize,
577 switchover_rd: usize,
578 ) {
579 type P = PackedBinaryField1x128b;
580 type FBase = BinaryField32b;
581 type FE = BinaryField128b;
582 type FDomain = BinaryField8b;
583 let mut rng = StdRng::seed_from_u64(0);
584
585 let multilins = generate_zero_product_multilinears::<PackedSubfield<P, FBase>, P>(
586 &mut rng,
587 n_vars,
588 n_multilinears,
589 );
590
591 let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
592 let challenges = prove_transcript.sample_vec(n_vars);
593
594 let composite_claim = CompositeSumClaim {
595 composition: TestProductComposition::new(n_multilinears),
596 sum: Field::ZERO,
597 };
598
599 let composite_claims = vec![composite_claim];
600
601 let switchover_fn = |_| switchover_rd;
602
603 let sumcheck_multilinears = multilins
604 .iter()
605 .cloned()
606 .map(|multilin| SumcheckMultilinear::transparent(multilin, &switchover_fn))
607 .collect::<Vec<_>>();
608
609 sumcheck::prove::eq_ind::validate_witness(
610 n_vars,
611 &sumcheck_multilinears,
612 &challenges,
613 composite_claims.clone(),
614 )
615 .unwrap();
616
617 let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
618 let backend = make_portable_backend();
619
620 let prover =
621 EqIndSumcheckProverBuilder::with_switchover(multilins.clone(), switchover_fn, &backend)
622 .unwrap()
623 .build::<FDomain, _>(
624 EvaluationOrder::HighToLow,
625 &challenges,
626 composite_claims.clone(),
627 domain_factory,
628 )
629 .unwrap();
630
631 let prove_output =
632 sumcheck::prove::batch_prove(vec![prover], &mut prove_transcript).unwrap();
633
634 let eq_ind_sumcheck_claim =
635 EqIndSumcheckClaim::new(n_vars, n_multilinears, composite_claims).unwrap();
636 let eq_ind_sumcheck_claims = vec![eq_ind_sumcheck_claim];
637
638 let BatchSumcheckOutput {
639 challenges: prover_eval_point,
640 multilinear_evals: prover_multilinear_evals,
641 } = sumcheck::eq_ind::verify_sumcheck_outputs(
642 ClaimsSortingOrder::AscendingVars,
643 &eq_ind_sumcheck_claims,
644 &challenges,
645 prove_output,
646 )
647 .unwrap();
648
649 let prover_sample = CanSample::<FE>::sample(&mut prove_transcript);
650 let mut verify_transcript = prove_transcript.into_verifier();
651 let _: Vec<BinaryField128b> = verify_transcript.sample_vec(n_vars);
652
653 let regular_sumcheck_claims =
654 sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims).unwrap();
655
656 let verifier_output = sumcheck::batch_verify(
657 EvaluationOrder::HighToLow,
658 ®ular_sumcheck_claims,
659 &mut verify_transcript,
660 )
661 .unwrap();
662
663 let BatchSumcheckOutput {
664 challenges: verifier_eval_point,
665 multilinear_evals: verifier_multilinear_evals,
666 } = sumcheck::eq_ind::verify_sumcheck_outputs(
667 ClaimsSortingOrder::AscendingVars,
668 &eq_ind_sumcheck_claims,
669 &challenges,
670 verifier_output,
671 )
672 .unwrap();
673
674 assert_eq!(prover_sample, CanSample::<FE>::sample(&mut verify_transcript));
676 verify_transcript.finalize().unwrap();
677
678 assert_eq!(prover_eval_point, verifier_eval_point);
679 assert_eq!(prover_multilinear_evals, verifier_multilinear_evals);
680
681 assert_eq!(verifier_multilinear_evals.len(), 1);
682 assert_eq!(verifier_multilinear_evals[0].len(), n_multilinears);
683
684 let multilin_query = backend.multilinear_query(&verifier_eval_point).unwrap();
686 for (multilinear, &expected) in iter::zip(multilins, verifier_multilinear_evals[0].iter()) {
687 assert_eq!(multilinear.evaluate(multilin_query.to_ref()).unwrap(), expected);
688 }
689 }
690
691 #[test]
692 fn test_compare_eq_ind_prover_to_regular_sumcheck() {
693 for n_vars in 2..8 {
694 for n_multilinears in 1..5 {
695 for switchover_rd in 1..=n_vars / 2 {
696 test_compare_prover_with_reference(n_vars, n_multilinears, switchover_rd);
697 }
698 }
699 }
700 }
701
702 #[test]
703 fn test_prove_verify_product_basic() {
704 for n_vars in 2..8 {
705 for n_multilinears in 1..5 {
706 for switchover_rd in 1..=n_vars / 2 {
707 test_prove_verify_product_constraint_helper(
708 n_vars,
709 n_multilinears,
710 switchover_rd,
711 );
712 }
713 }
714 }
715 }
716}