1use std::{collections::HashMap, iter::repeat_n};
4
5use binius_field::{
6 packed::{get_packed_slice, get_packed_slice_checked},
7 recast_packed_mut,
8 util::inner_product_unchecked,
9 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField,
10};
11use binius_hal::ComputationBackend;
12use binius_math::{
13 BinarySubspace, CompositionPoly, Error as MathError, EvaluationDomain, MLEDirectAdapter,
14 MultilinearPoly, RowsBatchRef,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_ntt::{AdditiveNTT, NTTShape, OddInterpolate, SingleThreadedNTT};
18use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
19use bytemuck::zeroed_vec;
20use itertools::izip;
21use stackalloc::stackalloc_with_iter;
22use tracing::instrument;
23
24use crate::{
25 composition::{BivariateProduct, IndexComposition},
26 protocols::sumcheck::{
27 common::{equal_n_vars_check, small_field_embedding_degree_check},
28 prove::{
29 logging::{ExpandQueryData, UnivariateSkipCalculateCoeffsData},
30 RegularSumcheckProver,
31 },
32 zerocheck::{domain_size, extrapolated_scalars_count},
33 Error,
34 },
35};
36
/// A [`RegularSumcheckProver`] specialised to compositions that multiply two of the input
/// multilinears, selected by index (`IndexComposition<BivariateProduct, 2>`), with each
/// multilinear wrapped in an [`MLEDirectAdapter`].
///
/// NOTE(review): presumably used for the sumcheck that follows the univariate-skip round;
/// confirm against callers outside this file.
pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
    'a,
    FDomain,
    P,
    IndexComposition<BivariateProduct, 2>,
    MLEDirectAdapter<P>,
    Backend,
>;
45
/// Scratch buffers used by one parallel fold worker in `zerocheck_univariate_evals`.
///
/// An instance is created by the `try_fold` identity closure and then reused for every
/// subcube index folded into it, so the buffers are allocated once per fold split rather
/// than once per subcube.
#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
    /// Subcube evaluations of a single multilinear, filled by `MultilinearPoly::subcube_evals`.
    /// Holds `2^(subcube_vars - P::LOG_WIDTH - log_embedding_degree)` packed elements
    /// (exponent saturating at zero).
    evals: Vec<P>,
    /// One buffer per multilinear: `evals` extrapolated by `ntt_extrapolate` onto further
    /// NTT cosets, sized for the maximum composition degree.
    extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
    /// Evaluations of a single composition over the extrapolated multilinears; reused for
    /// every composition (only a degree-dependent prefix is written each time).
    composition_evals: Vec<PackedSubfield<P, FBase>>,
    /// One accumulator per composition holding the packed round evaluations, laid out as
    /// `composition_degree - 1` cosets.
    packed_round_evals: Vec<Vec<P>>,
}
57
58impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
59 fn new(
60 n_multilinears: usize,
61 skip_rounds: usize,
62 log_batch: usize,
63 log_embedding_degree: usize,
64 composition_degrees: impl Iterator<Item = usize> + Clone,
65 ) -> Self {
66 let subcube_vars = skip_rounds + log_batch;
67 let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
68 let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
69 composition_max_degree,
70 skip_rounds,
71 log_batch,
72 );
73
74 let evals =
75 zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));
76
77 let extrapolated_evals = (0..n_multilinears)
78 .map(|_| zeroed_vec(extrapolated_packed_pbase_len))
79 .collect();
80
81 let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);
82
83 let packed_round_evals = composition_degrees
84 .map(|composition_degree| {
85 zeroed_vec(extrapolated_evals_packed_len::<P>(composition_degree, skip_rounds, 0))
86 })
87 .collect();
88
89 Self {
90 evals,
91 extrapolated_evals,
92 composition_evals,
93 packed_round_evals,
94 }
95 }
96}
97
/// Output of [`zerocheck_univariate_evals`]: the univariate round evaluations, together with
/// the data needed to fold them at a verifier challenge via
/// [`ZerocheckUnivariateEvalsOutput::fold`].
#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
    F: Field,
    P: PackedField<Scalar = F>,
    Backend: ComputationBackend,
{
    /// One vector per composition: the round polynomial evaluated at the domain points with
    /// indices `2^skip_rounds..max_domain_size` (the first `2^skip_rounds` points are omitted).
    pub round_evals: Vec<Vec<F>>,
    /// Number of variables handled by the univariate skip.
    skip_rounds: usize,
    /// Number of variables left after the skipped ones (`n_vars - skip_rounds`).
    remaining_rounds: usize,
    /// Size of the univariate evaluation domain.
    max_domain_size: usize,
    /// Tensor expansion of the zerocheck challenges over the remaining variables.
    partial_eq_ind_evals: Backend::Vec<P>,
}
111
/// Result of folding the univariate-skip round at a challenge point
/// (see [`ZerocheckUnivariateEvalsOutput::fold`]).
pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
    F: Field,
    P: PackedField<Scalar = F>,
    Backend: ComputationBackend,
{
    /// Number of variables left after the skipped ones.
    pub remaining_rounds: usize,
    /// Lagrange basis polynomials of the `2^skip_rounds`-point subcube domain, evaluated at
    /// the challenge.
    pub subcube_lagrange_coeffs: Vec<F>,
    /// Per composition: the round polynomial evaluated at the challenge, obtained by
    /// Lagrange interpolation over the whole univariate domain.
    pub claimed_sums: Vec<F>,
    /// Tensor expansion of the zerocheck challenges, passed through unchanged.
    pub partial_eq_ind_evals: Backend::Vec<P>,
}
123
124impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
125where
126 F: Field,
127 P: PackedField<Scalar = F>,
128 Backend: ComputationBackend,
129{
130 #[instrument(
132 skip_all,
133 name = "ZerocheckUnivariateEvalsOutput::fold",
134 level = "debug"
135 )]
136 pub fn fold<FDomain>(
137 self,
138 challenge: F,
139 ) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
140 where
141 FDomain: TowerField,
142 F: ExtensionField<FDomain>,
143 {
144 let Self {
145 round_evals,
146 skip_rounds,
147 remaining_rounds,
148 max_domain_size,
149 partial_eq_ind_evals,
150 } = self;
151
152 let max_dim = log2_ceil_usize(max_domain_size);
155 let subspace =
156 BinarySubspace::<FDomain::Canonical>::with_dim(max_dim)?.isomorphic::<FDomain>();
157 let max_domain = EvaluationDomain::from_points(
158 subspace.iter().take(max_domain_size).collect::<Vec<_>>(),
159 false,
160 )?;
161
162 let subcube_lagrange_coeffs = EvaluationDomain::from_points(
164 subspace.reduce_dim(skip_rounds)?.iter().collect::<Vec<_>>(),
165 false,
166 )?
167 .lagrange_evals(challenge);
168
169 let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);
171
172 let claimed_sums = round_evals
173 .into_iter()
174 .map(|evals| {
175 inner_product_unchecked::<F, F>(
176 evals,
177 round_evals_lagrange_coeffs[1 << skip_rounds..]
178 .iter()
179 .copied(),
180 )
181 })
182 .collect();
183
184 Ok(ZerocheckUnivariateFoldResult {
185 remaining_rounds,
186 subcube_lagrange_coeffs,
187 claimed_sums,
188 partial_eq_ind_evals,
189 })
190 }
191}
192
/// Evaluates the univariate-skip round polynomials for a batch of zerocheck compositions.
///
/// The first `skip_rounds` variables of every multilinear are "univariatized": each
/// `2^skip_rounds`-point subcube is inverse-NTT'd over `FDomain` and re-evaluated on further
/// cosets (see `ntt_extrapolate`). Each composition is evaluated on the extrapolated values,
/// weighted by the tensor expansion of `zerocheck_challenges` over the remaining
/// `n_vars - skip_rounds` variables, and summed over all subcubes.
///
/// The returned `round_evals` contain, per composition, the round polynomial at domain points
/// `2^skip_rounds..max_domain_size`. The first `2^skip_rounds` points are omitted;
/// `extrapolate_round_evals` treats them as zero.
///
/// # Errors
/// - `Error::TooManySkippedRounds` if `skip_rounds` exceeds the number of variables.
/// - `Error::IncorrectZerocheckChallengesLength` if `zerocheck_challenges.len()` is not
///   `n_vars - skip_rounds`.
/// - `Error::LagrangeDomainTooSmall` if `max_domain_size` is below
///   `domain_size(max composition degree, skip_rounds)`.
/// - `MathError::DomainSizeTooLarge` if `log2_ceil(max_domain_size)` exceeds `FDomain::N_BITS`.
/// - Errors from the equal-`n_vars` / embedding-degree checks, the backend tensor expansion,
///   `subcube_evals`, the NTTs and composition evaluation.
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
    multilinears: &[M],
    compositions: &[Composition],
    zerocheck_challenges: &[F],
    skip_rounds: usize,
    max_domain_size: usize,
    backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
    FDomain: TowerField,
    FBase: ExtensionField<FDomain>,
    F: TowerField,
    P: PackedField<Scalar = F> + PackedExtension<FBase> + PackedExtension<FDomain>,
    Composition: CompositionPoly<PackedSubfield<P, FBase>>,
    M: MultilinearPoly<P> + Send + Sync,
    Backend: ComputationBackend,
{
    let n_vars = equal_n_vars_check(multilinears)?;
    let n_multilinears = multilinears.len();

    if skip_rounds > n_vars {
        bail!(Error::TooManySkippedRounds);
    }

    // One challenge per variable that is not univariatized.
    let remaining_rounds = n_vars - skip_rounds;
    if zerocheck_challenges.len() != remaining_rounds {
        bail!(Error::IncorrectZerocheckChallengesLength);
    }

    small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

    let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
    let composition_degrees = compositions.iter().map(|composition| composition.degree());
    let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

    if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
        bail!(Error::LagrangeDomainTooSmall);
    }

    // Strided-NTT batching factor: number of FDomain elements per FBase element (log2).
    let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;

    // FDomain must be large enough to host an NTT of the full domain size.
    let min_domain_bits = log2_ceil_usize(max_domain_size);
    if min_domain_bits > FDomain::N_BITS {
        bail!(MathError::DomainSizeTooLarge);
    }

    let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
        .expect("FDomain cardinality checked before")
        .precompute_twiddles();

    // Small subcubes are batched together (up to 2^MAX_SUBCUBE_VARS points) so each NTT
    // call covers more data.
    const MAX_SUBCUBE_VARS: usize = 12;
    let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);

    let dimensions_data = ExpandQueryData::new(zerocheck_challenges);
    let expand_span = tracing::debug_span!(
        "[task] Expand Query",
        phase = "zerocheck",
        perfetto_category = "task.main",
        ?dimensions_data,
    )
    .entered();
    // Tensor expansion over the non-skipped variables; each entry is the eq-indicator weight
    // of one univariatized subcube.
    let partial_eq_ind_evals: <Backend as ComputationBackend>::Vec<P> =
        backend.tensor_product_full_query(zerocheck_challenges)?;
    drop(expand_span);

    // Each composition only needs a degree-dependent prefix of the extrapolated evaluations.
    let pbase_prefix_lens = composition_degrees
        .clone()
        .map(|composition_degree| {
            extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
                composition_degree,
                skip_rounds,
                log_batch,
            )
        })
        .collect::<Vec<_>>();
    let dimensions_data =
        UnivariateSkipCalculateCoeffsData::new(n_vars, skip_rounds, n_multilinears, log_batch);
    let coeffs_span = tracing::debug_span!(
        "[task] Univariate Skip Calculate coeffs",
        phase = "zerocheck",
        perfetto_category = "task.main",
        ?dimensions_data,
    )
    .entered();

    let subcube_vars = log_batch + skip_rounds;
    let log_subcube_count = n_vars - subcube_vars;

    // Packed lengths of a single coset in the round-eval accumulator and composition evals.
    let p_coset_round_evals_len = 1 << skip_rounds.saturating_sub(P::LOG_WIDTH);
    let pbase_coset_composition_evals_len =
        1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree);

    // Round evals start at point 2^skip_rounds; the first 2^skip_rounds points are never
    // evaluated here (presumably because the zerocheck multilinears' compositions vanish there).
    let staggered_round_evals = (0..1 << log_subcube_count)
        .into_par_iter()
        .try_fold(
            || {
                ParFoldStates::<FBase, P>::new(
                    n_multilinears,
                    skip_rounds,
                    log_batch,
                    log_embedding_degree,
                    composition_degrees.clone(),
                )
            },
            |mut par_fold_states, subcube_index| -> Result<_, Error> {
                let ParFoldStates {
                    evals,
                    extrapolated_evals,
                    composition_evals,
                    packed_round_evals,
                } = &mut par_fold_states;

                // Extrapolate every multilinear's subcube onto the additional NTT cosets.
                for (multilinear, extrapolated_evals) in
                    izip!(multilinears, extrapolated_evals.iter_mut())
                {
                    multilinear.subcube_evals(
                        subcube_vars,
                        subcube_index,
                        log_embedding_degree,
                        evals,
                    )?;

                    // Reinterpret the buffers as FDomain elements so the FDomain NTT can run
                    // strided over the FBase components.
                    let evals_base = <P as PackedExtension<FBase>>::cast_bases_mut(evals);
                    let evals_domain = recast_packed_mut::<P, FBase, FDomain>(evals_base);
                    let extrapolated_evals_domain =
                        recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

                    ntt_extrapolate(
                        &fdomain_ntt,
                        skip_rounds,
                        log_extension_degree_base_domain,
                        log_batch,
                        evals_domain,
                        extrapolated_evals_domain,
                    )?
                }

                // Evaluate each composition on its prefix and accumulate eq-weighted sums.
                for (composition, packed_round_evals, &pbase_prefix_len) in
                    izip!(compositions, packed_round_evals, &pbase_prefix_lens)
                {
                    let extrapolated_evals_iter = extrapolated_evals
                        .iter()
                        .map(|evals| &evals[..pbase_prefix_len]);

                    stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
                        let batch_query = RowsBatchRef::new(batch_query, pbase_prefix_len);

                        composition.batch_evaluate(
                            &batch_query,
                            &mut composition_evals[..pbase_prefix_len],
                        )
                    })?;

                    // The zip stops at the composition's (degree - 1) cosets, so stale data
                    // past `pbase_prefix_len` is never read.
                    for (packed_round_evals_coset, composition_evals_coset) in izip!(
                        packed_round_evals.chunks_exact_mut(p_coset_round_evals_len,),
                        composition_evals.chunks_exact(pbase_coset_composition_evals_len)
                    ) {
                        spread_product::<_, FBase>(
                            packed_round_evals_coset,
                            composition_evals_coset,
                            &partial_eq_ind_evals,
                            subcube_index,
                            skip_rounds,
                            log_batch,
                        );
                    }
                }

                Ok(par_fold_states)
            },
        )
        // Unpack each worker's packed accumulators into scalar round evals.
        .map(|states| -> Result<_, Error> {
            let scalar_round_evals = izip!(composition_degrees.clone(), states?.packed_round_evals)
                .map(|(composition_degree, packed_round_evals)| {
                    let mut composition_round_evals = Vec::with_capacity(
                        extrapolated_scalars_count(composition_degree, skip_rounds),
                    );

                    // A coset may be padded up to a whole packed element; keep only
                    // its first 2^skip_rounds scalars.
                    for packed_round_evals_coset in
                        packed_round_evals.chunks_exact(p_coset_round_evals_len)
                    {
                        let coset_scalars = packed_round_evals_coset
                            .iter()
                            .flat_map(|packed| packed.iter())
                            .take(1 << skip_rounds);

                        composition_round_evals.extend(coset_scalars);
                    }

                    composition_round_evals
                })
                .collect::<Vec<_>>();

            Ok(scalar_round_evals)
        })
        // Sum the per-worker scalar round evals element-wise.
        .try_reduce(
            || {
                composition_degrees
                    .clone()
                    .map(|composition_degree| {
                        zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
                    })
                    .collect()
            },
            |lhs, rhs| -> Result<_, Error> {
                let round_evals_sum = izip!(lhs, rhs)
                    .map(|(mut lhs_vals, rhs_vals)| {
                        debug_assert_eq!(lhs_vals.len(), rhs_vals.len());
                        for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
                            *lhs_val += rhs_val;
                        }
                        lhs_vals
                    })
                    .collect();

                Ok(round_evals_sum)
            },
        )?;

    // Extend the per-degree evaluations to the full max_domain_size domain.
    let round_evals = extrapolate_round_evals(staggered_round_evals, skip_rounds, max_domain_size)?;
    drop(coeffs_span);

    Ok(ZerocheckUnivariateEvalsOutput {
        round_evals,
        skip_rounds,
        remaining_rounds,
        max_domain_size,
        partial_eq_ind_evals,
    })
}
496
497fn spread_product<P, FBase>(
500 accum: &mut [P],
501 small: &[PackedSubfield<P, FBase>],
502 large: &[P],
503 subcube_index: usize,
504 log_n: usize,
505 log_batch: usize,
506) where
507 P: PackedExtension<FBase>,
508 FBase: Field,
509{
510 let log_embedding_degree = <P::Scalar as ExtensionField<FBase>>::LOG_DEGREE;
511 let pbase_log_width = P::LOG_WIDTH + log_embedding_degree;
512
513 debug_assert_eq!(accum.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH));
514 debug_assert_eq!(small.len(), 1 << (log_n + log_batch).saturating_sub(pbase_log_width));
515
516 if log_n >= P::LOG_WIDTH {
517 let mask = (1 << log_embedding_degree) - 1;
519 for batch_idx in 0..1 << log_batch {
520 let mult = get_packed_slice(large, subcube_index << log_batch | batch_idx);
521 let spread_large = P::cast_base(P::broadcast(mult));
522
523 for (block_idx, dest) in accum.iter_mut().enumerate() {
524 let block_offset = block_idx | batch_idx << (log_n - P::LOG_WIDTH);
525 let spread_small = small[block_offset >> log_embedding_degree]
526 .spread(P::LOG_WIDTH, block_offset & mask);
527 *dest += P::cast_ext(spread_large * spread_small);
528 }
529 }
530 } else {
531 for (outer_idx, dest) in accum.iter_mut().enumerate() {
535 *dest = P::from_fn(|inner_idx| {
536 if inner_idx >= 1 << log_n {
537 return P::Scalar::ZERO;
538 }
539 (0..1 << log_batch)
540 .map(|batch_idx| {
541 let large = get_packed_slice(large, subcube_index << log_batch | batch_idx);
542 let small = get_packed_slice_checked(
543 small,
544 batch_idx << log_n | outer_idx << P::LOG_WIDTH | inner_idx,
545 )
546 .unwrap_or_default();
547 large * small
548 })
549 .sum()
550 })
551 }
552 }
553}
554
555#[instrument(skip_all, level = "debug")]
561fn extrapolate_round_evals<F: TowerField>(
562 mut round_evals: Vec<Vec<F>>,
563 skip_rounds: usize,
564 max_domain_size: usize,
565) -> Result<Vec<Vec<F>>, Error> {
566 let ntt = SingleThreadedNTT::with_canonical_field(log2_ceil_usize(max_domain_size))?;
569
570 let mut odd_interpolates = HashMap::new();
572
573 for round_evals in &mut round_evals {
574 round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));
576
577 let n = round_evals.len();
578
579 let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
581 let ell = n.trailing_zeros() as usize;
582 assert!(ell >= skip_rounds);
583
584 OddInterpolate::new(n >> ell, ell, ntt.twiddles())
585 .expect("domain large enough by construction")
586 });
587
588 odd_interpolate.inverse_transform(&ntt, round_evals)?;
590
591 let next_log_n = log2_ceil_usize(max_domain_size);
593 round_evals.resize(1 << next_log_n, F::ZERO);
594
595 let shape = NTTShape {
596 log_y: next_log_n,
597 ..Default::default()
598 };
599 ntt.forward_transform(round_evals, shape, 0, 0)?;
600
601 debug_assert!(round_evals[..1 << skip_rounds]
603 .iter()
604 .all(|&coeff| coeff == F::ZERO));
605
606 round_evals.resize(max_domain_size, F::ZERO);
608 round_evals.drain(..1 << skip_rounds);
609 }
610
611 Ok(round_evals)
612}
613
614fn ntt_extrapolate<NTT, P>(
615 ntt: &NTT,
616 skip_rounds: usize,
617 log_stride_batch: usize,
618 log_batch: usize,
619 evals: &mut [P],
620 extrapolated_evals: &mut [P],
621) -> Result<(), Error>
622where
623 P: PackedField<Scalar: BinaryField>,
624 NTT: AdditiveNTT<P::Scalar>,
625{
626 let shape = NTTShape {
627 log_x: log_stride_batch,
628 log_y: skip_rounds,
629 log_z: log_batch,
630 };
631
632 ntt.inverse_transform(evals, shape, 0, 0)?;
634
635 for (coset, extrapolated_chunk) in
637 izip!(1u32.., extrapolated_evals.chunks_exact_mut(evals.len()))
638 {
639 extrapolated_chunk.copy_from_slice(evals);
641 ntt.forward_transform(extrapolated_chunk, shape, coset, 0)?;
642 }
643
644 Ok(())
645}
646
647const fn extrapolated_evals_packed_len<P: PackedField>(
648 composition_degree: usize,
649 skip_rounds: usize,
650 log_batch: usize,
651) -> usize {
652 composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
653}
654
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        underlier::UnderlierType,
        BinaryField128b, BinaryField16b, BinaryField1b, BinaryField8b, ExtensionField, Field,
        PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
    };
    use binius_hal::make_portable_backend;
    use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearPoly};
    use binius_ntt::SingleThreadedNTT;
    use rand::{prelude::StdRng, SeedableRng};

    use crate::{
        composition::{IndexComposition, ProductComposition},
        polynomial::CompositionScalarAdapter,
        protocols::{
            sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
            test_utils::generate_zero_product_multilinears,
        },
        transparent::eq_ind::EqIndPartialEval,
    };

    /// Checks `ntt_extrapolate` against direct Lagrange extrapolation over the
    /// `skip_rounds`-dimensional subspace, for the first extrapolated coset of every batch entry.
    #[test]
    fn ntt_extrapolate_correctness() {
        type P = PackedBinaryField4x32b;
        type FDomain = BinaryField16b;
        // 32-bit scalars over a 16-bit domain: two FDomain elements per P scalar.
        let log_extension_degree_p_domain = 1;

        let mut rng = StdRng::seed_from_u64(0);
        let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
        let subspace = BinarySubspace::<FDomain>::with_dim(10).unwrap();
        let max_domain =
            EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();

        for skip_rounds in 0..5usize {
            let subsubspace = subspace.reduce_dim(skip_rounds).unwrap();
            let domain =
                EvaluationDomain::from_points(subsubspace.iter().collect::<Vec<_>>(), false)
                    .unwrap();
            for log_batch in 0..3usize {
                for composition_degree in 0..5usize {
                    let subcube_vars = skip_rounds + log_batch;
                    let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
                    let interleaved_evals = (0..interleaved_len)
                        .map(|_| P::random(&mut rng))
                        .collect::<Vec<_>>();

                    let extrapolated_scalars_cnt =
                        composition_degree.saturating_sub(1) << skip_rounds;
                    let extrapolated_ntts = composition_degree.saturating_sub(1);
                    let extrapolated_len = extrapolated_ntts * interleaved_len;
                    let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

                    // `ntt_extrapolate` transforms its input in place, so work on a copy and
                    // keep the original evaluations for the reference computation.
                    let mut interleaved_evals_scratch = interleaved_evals.clone();

                    let interleaved_evals_domain =
                        P::cast_bases_mut(&mut interleaved_evals_scratch);
                    let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

                    super::ntt_extrapolate(
                        &ntt,
                        skip_rounds,
                        log_extension_degree_p_domain,
                        log_batch,
                        interleaved_evals_domain,
                        extrapolated_evals_domain,
                    )
                    .unwrap();

                    let interleaved_scalars =
                        &P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
                    let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
                        [..extrapolated_scalars_cnt << log_batch];

                    for batch_idx in 0..1 << log_batch {
                        let values =
                            &interleaved_scalars[batch_idx << skip_rounds..][..1 << skip_rounds];

                        // Only the first extrapolated coset (points 2^skip_rounds..2^(skip_rounds+1))
                        // is compared here.
                        for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
                            [..extrapolated_scalars_cnt]
                            .iter()
                            .take(1 << skip_rounds)
                            .enumerate()
                        {
                            let extrapolated = domain.extrapolate(values, point.into()).unwrap();
                            let expected = extrapolated_scalars[batch_idx << skip_rounds | i];
                            assert_eq!(extrapolated, expected);
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn zerocheck_univariate_evals_invariants_basic() {
        zerocheck_univariate_evals_invariants_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            BinaryField8b,
            BinaryField16b,
        >()
    }

    /// Same as the basic test but with a 512-bit underlier, so `P` packs several F scalars.
    #[test]
    fn zerocheck_univariate_evals_with_nontrivial_packing() {
        zerocheck_univariate_evals_invariants_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            BinaryField8b,
            BinaryField16b,
        >()
    }

    /// Compares `zerocheck_univariate_evals` against a naive reference.
    ///
    /// For each round-eval point, the reference extrapolates every multilinear's subcube
    /// directly with `EvaluationDomain::extrapolate`, evaluates the compositions, and sums with
    /// eq-indicator weights. It uses `n_vars = 7`, which yields a single subcube (see
    /// `MAX_SUBCUBE_VARS`) in the optimized path.
    fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
    where
        U: UnderlierType
            + PackScalar<F>
            + PackScalar<FBase>
            + PackScalar<FDomain>
            + PackScalar<BinaryField1b>,
        F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
        FBase: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField + From<u8>,
        PackedType<U, FBase>: PackedFieldIndexable,
        PackedType<U, FDomain>: PackedFieldIndexable,
        PackedType<U, F>: PackedFieldIndexable,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 7;
        let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

        // Three groups of multilinears (2 + 3 + 4 = 9) whose products are zero on the hypercube.
        let mut multilinears = generate_zero_product_multilinears::<
            PackedType<U, BinaryField1b>,
            PackedType<U, F>,
        >(&mut rng, n_vars, 2);
        multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
        multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

        let compositions = [
            Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
            Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
            Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
        ];

        let backend = make_portable_backend();
        let zerocheck_challenges = (0..n_vars)
            .map(|_| <F as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        for skip_rounds in 0usize..=5 {
            let max_domain_size = domain_size(5, skip_rounds);
            let output =
                zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
                    &multilinears,
                    &compositions,
                    &zerocheck_challenges[skip_rounds..],
                    skip_rounds,
                    max_domain_size,
                    &backend,
                )
                .unwrap();

            let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
                .multilinear_extension::<F, _>(&backend)
                .unwrap();

            // Every composition's round evals have the same length for a fixed domain size.
            let round_evals_len = 4usize << skip_rounds;
            assert!(output
                .round_evals
                .iter()
                .all(|round_evals| round_evals.len() == round_evals_len));

            let compositions = compositions
                .iter()
                .cloned()
                .map(CompositionScalarAdapter::new)
                .collect::<Vec<_>>();

            let mut query = [FBase::ZERO; 9];
            let mut evals = vec![
                PackedType::<U, F>::zero();
                1 << skip_rounds.saturating_sub(
                    log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
                )
            ];
            let subspace = BinarySubspace::<FDomain>::with_dim(skip_rounds).unwrap();
            let domain =
                EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();
            for round_evals_index in 0..round_evals_len {
                // Round evals start at the domain point with index 2^skip_rounds.
                let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
                let mut composition_sums = vec![F::ZERO; compositions.len()];
                for subcube_index in 0..1 << (n_vars - skip_rounds) {
                    for (query, multilinear) in query.iter_mut().zip(&multilinears) {
                        multilinear
                            .subcube_evals(
                                skip_rounds,
                                subcube_index,
                                log_embedding_degree,
                                &mut evals,
                            )
                            .unwrap();
                        let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
                            PackedExtension::<FBase>::cast_bases(&evals),
                        )[..1 << skip_rounds];
                        let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
                        *query = extrapolated;
                    }

                    let eq_ind_factor = zerocheck_eq_ind
                        .evaluate_on_hypercube(subcube_index)
                        .unwrap();
                    for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
                        *sum += eq_ind_factor * composition.evaluate(&query).unwrap();
                    }
                }

                let univariate_skip_composition_sums = output
                    .round_evals
                    .iter()
                    .map(|round_evals| round_evals[round_evals_index])
                    .collect::<Vec<_>>();
                assert_eq!(univariate_skip_composition_sums, composition_sums);
            }
        }
    }
}