1use std::{collections::HashMap, iter::repeat_n};
4
5use binius_field::{
6 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField,
7 packed::{get_packed_slice, get_packed_slice_checked},
8 recast_packed_mut,
9 util::inner_product_unchecked,
10};
11use binius_hal::ComputationBackend;
12use binius_math::{
13 BinarySubspace, CompositionPoly, Error as MathError, EvaluationDomain, MLEDirectAdapter,
14 MultilinearPoly, RowsBatchRef,
15};
16use binius_maybe_rayon::prelude::*;
17use binius_ntt::{
18 AdditiveNTT, NTTShape, OddInterpolate, SingleThreadedNTT, twiddle::TwiddleAccess,
19};
20use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
21use bytemuck::zeroed_vec;
22use itertools::izip;
23use stackalloc::stackalloc_with_iter;
24use tracing::instrument;
25
26use crate::{
27 composition::{BivariateProduct, IndexComposition},
28 protocols::sumcheck::{
29 Error,
30 common::{equal_n_vars_check, small_field_embedding_degree_check},
31 prove::{
32 RegularSumcheckProver,
33 logging::{ExpandQueryData, UnivariateSkipCalculateCoeffsData},
34 },
35 zerocheck::{domain_size, extrapolated_scalars_count},
36 },
37};
38
39pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
40 'a,
41 FDomain,
42 P,
43 IndexComposition<BivariateProduct, 2>,
44 MLEDirectAdapter<P>,
45 Backend,
46>;
47
/// Scratch buffers for one rayon fold inside [`zerocheck_univariate_evals`].
///
/// The `try_fold` identity closure creates one instance per fold. That instance is then
/// reused for every subcube the fold visits, so these buffers are allocated once per fold
/// rather than once per subcube.
#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
	/// Packed subcube evaluations of the multilinear currently being processed.
	/// `subcube_evals` overwrites it for each multilinear.
	evals: Vec<P>,
	/// One buffer per multilinear, holding the NTT extrapolation of the subcube onto the extra
	/// cosets, packed over the `FBase` subfield.
	extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
	/// Output buffer for `batch_evaluate` of the current composition on the extrapolated
	/// cosets. It is sized for the largest composition degree.
	composition_evals: Vec<PackedSubfield<P, FBase>>,
	/// One accumulator per composition for the eq-weighted round evaluations, laid out as one
	/// packed chunk per extrapolated coset. It accumulates across all subcubes of this fold.
	packed_round_evals: Vec<Vec<P>>,
}
60
61impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
62 fn new(
63 n_multilinears: usize,
64 skip_rounds: usize,
65 log_batch: usize,
66 log_embedding_degree: usize,
67 composition_degrees: impl Iterator<Item = usize> + Clone,
68 ) -> Self {
69 let subcube_vars = skip_rounds + log_batch;
70 let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
71 let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
72 composition_max_degree,
73 skip_rounds,
74 log_batch,
75 );
76
77 let evals =
78 zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));
79
80 let extrapolated_evals = (0..n_multilinears)
81 .map(|_| zeroed_vec(extrapolated_packed_pbase_len))
82 .collect();
83
84 let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);
85
86 let packed_round_evals = composition_degrees
87 .map(|composition_degree| {
88 zeroed_vec(extrapolated_evals_packed_len::<P>(composition_degree, skip_rounds, 0))
89 })
90 .collect();
91
92 Self {
93 evals,
94 extrapolated_evals,
95 composition_evals,
96 packed_round_evals,
97 }
98 }
99}
100
/// Output of [`zerocheck_univariate_evals`].
///
/// It holds the univariate round-polynomial evaluations for every composition, plus the state
/// that [`ZerocheckUnivariateEvalsOutput::fold`] needs to fold them with a challenge.
#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	/// One vector per composition: round-polynomial evaluations on the max-domain points that
	/// follow the first `2^skip_rounds` points. Those leading points are treated as zero and
	/// are not stored.
	pub round_evals: Vec<Vec<F>>,
	/// Number of variables handled by the univariate skip.
	skip_rounds: usize,
	/// Number of variables left after the skipped rounds (`n_vars - skip_rounds`).
	remaining_rounds: usize,
	/// Size of the evaluation domain that `round_evals` were extrapolated to.
	max_domain_size: usize,
	/// Full tensor expansion of the zerocheck challenges, from `tensor_product_full_query`.
	partial_eq_ind_evals: Backend::Vec<P>,
}
114
/// Result of [`ZerocheckUnivariateEvalsOutput::fold`].
pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	/// Number of variables left after the skipped rounds (passed through unchanged).
	pub remaining_rounds: usize,
	/// Lagrange basis of the `2^skip_rounds`-point subcube domain, evaluated at the challenge.
	pub subcube_lagrange_coeffs: Vec<F>,
	/// One value per composition: its round polynomial evaluated at the challenge.
	pub claimed_sums: Vec<F>,
	/// Expanded zerocheck equality-indicator evaluations (passed through unchanged).
	pub partial_eq_ind_evals: Backend::Vec<P>,
}
126
127impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
128where
129 F: Field,
130 P: PackedField<Scalar = F>,
131 Backend: ComputationBackend,
132{
133 #[instrument(
135 skip_all,
136 name = "ZerocheckUnivariateEvalsOutput::fold",
137 level = "debug"
138 )]
139 pub fn fold<FDomain>(
140 self,
141 challenge: F,
142 ) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
143 where
144 FDomain: TowerField,
145 F: ExtensionField<FDomain>,
146 {
147 let Self {
148 round_evals,
149 skip_rounds,
150 remaining_rounds,
151 max_domain_size,
152 partial_eq_ind_evals,
153 } = self;
154
155 let max_dim = log2_ceil_usize(max_domain_size);
158 let subspace =
159 BinarySubspace::<FDomain::Canonical>::with_dim(max_dim)?.isomorphic::<FDomain>();
160 let max_domain = EvaluationDomain::from_points(
161 subspace.iter().take(max_domain_size).collect::<Vec<_>>(),
162 false,
163 )?;
164
165 let subcube_lagrange_coeffs = EvaluationDomain::from_points(
167 subspace.reduce_dim(skip_rounds)?.iter().collect::<Vec<_>>(),
168 false,
169 )?
170 .lagrange_evals(challenge);
171
172 let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);
174
175 let claimed_sums = round_evals
176 .into_iter()
177 .map(|evals| {
178 inner_product_unchecked::<F, F>(
179 evals,
180 round_evals_lagrange_coeffs[1 << skip_rounds..]
181 .iter()
182 .copied(),
183 )
184 })
185 .collect();
186
187 Ok(ZerocheckUnivariateFoldResult {
188 remaining_rounds,
189 subcube_lagrange_coeffs,
190 claimed_sums,
191 partial_eq_ind_evals,
192 })
193 }
194}
195
/// Computes the univariate-skip round evaluations for a batch of zerocheck compositions.
///
/// The first `skip_rounds` variables of every multilinear are viewed as a univariate domain
/// of `2^skip_rounds` points. For each composition and each max-domain point `z` past the
/// first `2^skip_rounds`, the result holds
/// `sum_x eq(zerocheck_challenges, x) * composition(multilinears extrapolated to z at row x)`,
/// where `x` ranges over the remaining `n_vars - skip_rounds` variables.
///
/// Outline:
/// 1. Validate the inputs and expand `zerocheck_challenges` into equality-indicator evaluations.
/// 2. In parallel over subcubes of `skip_rounds + log_batch` variables: fetch each
///    multilinear's subcube, extrapolate it onto extra cosets with an FDomain NTT,
///    batch-evaluate every composition, and accumulate the eq-weighted products.
/// 3. Sum the per-thread accumulators, then extend the evaluations to `max_domain_size`
///    points with [`extrapolate_round_evals`].
///
/// # Errors
///
/// - [`Error::TooManySkippedRounds`] if `skip_rounds > n_vars`.
/// - [`Error::IncorrectZerocheckChallengesLength`] if `zerocheck_challenges.len()` is not
///   `n_vars - skip_rounds`.
/// - [`Error::LagrangeDomainTooSmall`] if `max_domain_size` is below
///   `domain_size(max composition degree, skip_rounds)`.
/// - `DomainSizeTooLarge` (math error) if `max_domain_size` needs more bits than `FDomain`.
/// - Errors propagated from the multilinear checks, backend, subcube fetches, NTTs and
///   composition evaluation.
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
	multilinears: &[M],
	compositions: &[Composition],
	zerocheck_challenges: &[F],
	skip_rounds: usize,
	max_domain_size: usize,
	backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
	FDomain: TowerField,
	FBase: ExtensionField<FDomain>,
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FBase> + PackedExtension<FDomain>,
	Composition: CompositionPoly<PackedSubfield<P, FBase>>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	let n_vars = equal_n_vars_check(multilinears)?;
	let n_multilinears = multilinears.len();

	if skip_rounds > n_vars {
		bail!(Error::TooManySkippedRounds);
	}

	let remaining_rounds = n_vars - skip_rounds;
	if zerocheck_challenges.len() != remaining_rounds {
		bail!(Error::IncorrectZerocheckChallengesLength);
	}

	small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

	let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
	let composition_degrees = compositions.iter().map(|composition| composition.degree());
	let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

	if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
		bail!(Error::LagrangeDomainTooSmall);
	}

	// Each FBase scalar is 2^log_extension_degree_base_domain FDomain scalars. This is used
	// as the interleaving stride (`log_x`) of the FDomain NTT.
	let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;

	// The NTT domain must cover max_domain_size points and must fit inside FDomain.
	let min_domain_bits = log2_ceil_usize(max_domain_size);
	if min_domain_bits > FDomain::N_BITS {
		bail!(MathError::DomainSizeTooLarge);
	}

	let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
		.expect("FDomain cardinality checked before")
		.precompute_twiddles();

	// Caps the number of variables per subcube. Up to 2^log_batch rows of the skipped domain
	// are processed together in each parallel work item.
	const MAX_SUBCUBE_VARS: usize = 12;
	let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);

	// Tensor-expand the zerocheck challenges into equality-indicator evaluations.
	let dimensions_data = ExpandQueryData::new(zerocheck_challenges);
	let expand_span = tracing::debug_span!(
		"[task] Expand Query",
		phase = "zerocheck",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();
	let partial_eq_ind_evals: <Backend as ComputationBackend>::Vec<P> =
		backend.tensor_product_full_query(zerocheck_challenges)?;
	drop(expand_span);

	// One entry per composition: the number of packed base elements of extrapolated evals it
	// needs. A degree-d composition uses only the first d - 1 cosets.
	let pbase_prefix_lens = composition_degrees
		.clone()
		.map(|composition_degree| {
			extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
				composition_degree,
				skip_rounds,
				log_batch,
			)
		})
		.collect::<Vec<_>>();
	let dimensions_data =
		UnivariateSkipCalculateCoeffsData::new(n_vars, skip_rounds, n_multilinears, log_batch);
	let coeffs_span = tracing::debug_span!(
		"[task] Univariate Skip Calculate coeffs",
		phase = "zerocheck",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	let subcube_vars = log_batch + skip_rounds;
	let log_subcube_count = n_vars - subcube_vars;

	// Chunk sizes, in packed elements, of one coset of round evals (P) and of one coset of
	// composition evals (packed FBase).
	let p_coset_round_evals_len = 1 << skip_rounds.saturating_sub(P::LOG_WIDTH);
	let pbase_coset_composition_evals_len =
		1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree);

	let staggered_round_evals = (0..1 << log_subcube_count)
		.into_par_iter()
		.try_fold(
			|| {
				ParFoldStates::<FBase, P>::new(
					n_multilinears,
					skip_rounds,
					log_batch,
					log_embedding_degree,
					composition_degrees.clone(),
				)
			},
			|mut par_fold_states, subcube_index| -> Result<_, Error> {
				let ParFoldStates {
					evals,
					extrapolated_evals,
					composition_evals,
					packed_round_evals,
				} = &mut par_fold_states;

				// Extrapolate every multilinear's subcube onto the extra cosets.
				for (multilinear, extrapolated_evals) in
					izip!(multilinears, extrapolated_evals.iter_mut())
				{
					multilinear.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals,
					)?;

					// Reinterpret the packed data over FDomain so the FDomain NTT can run on it.
					let evals_base = <P as PackedExtension<FBase>>::cast_bases_mut(evals);
					let evals_domain = recast_packed_mut::<P, FBase, FDomain>(evals_base);
					let extrapolated_evals_domain =
						recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

					ntt_extrapolate(
						&fdomain_ntt,
						skip_rounds,
						log_extension_degree_base_domain,
						log_batch,
						evals_domain,
						extrapolated_evals_domain,
					)?
				}

				// Evaluate each composition on its needed prefix of the extrapolated cosets,
				// then fold the eq-weighted results into that composition's accumulator.
				for (composition, packed_round_evals, &pbase_prefix_len) in
					izip!(compositions, packed_round_evals, &pbase_prefix_lens)
				{
					let extrapolated_evals_iter = extrapolated_evals
						.iter()
						.map(|evals| &evals[..pbase_prefix_len]);

					stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
						let batch_query = RowsBatchRef::new(batch_query, pbase_prefix_len);

						composition.batch_evaluate(
							&batch_query,
							&mut composition_evals[..pbase_prefix_len],
						)
					})?;

					for (packed_round_evals_coset, composition_evals_coset) in izip!(
						packed_round_evals.chunks_exact_mut(p_coset_round_evals_len,),
						composition_evals.chunks_exact(pbase_coset_composition_evals_len)
					) {
						spread_product::<_, FBase>(
							packed_round_evals_coset,
							composition_evals_coset,
							&partial_eq_ind_evals,
							subcube_index,
							skip_rounds,
							log_batch,
						);
					}
				}

				Ok(par_fold_states)
			},
		)
		// Unpack each fold's accumulators into scalars, keeping the 2^skip_rounds meaningful
		// values of each coset and dropping the packing padding.
		.map(|states| -> Result<_, Error> {
			let scalar_round_evals = izip!(composition_degrees.clone(), states?.packed_round_evals)
				.map(|(composition_degree, packed_round_evals)| {
					let mut composition_round_evals = Vec::with_capacity(
						extrapolated_scalars_count(composition_degree, skip_rounds),
					);

					for packed_round_evals_coset in
						packed_round_evals.chunks_exact(p_coset_round_evals_len)
					{
						let coset_scalars = packed_round_evals_coset
							.iter()
							.flat_map(|packed| packed.iter())
							.take(1 << skip_rounds);

						composition_round_evals.extend(coset_scalars);
					}

					composition_round_evals
				})
				.collect::<Vec<_>>();

			Ok(scalar_round_evals)
		})
		// Sum the per-fold results element-wise.
		.try_reduce(
			|| {
				composition_degrees
					.clone()
					.map(|composition_degree| {
						zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
					})
					.collect()
			},
			|lhs, rhs| -> Result<_, Error> {
				let round_evals_sum = izip!(lhs, rhs)
					.map(|(mut lhs_vals, rhs_vals)| {
						debug_assert_eq!(lhs_vals.len(), rhs_vals.len());
						for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
							*lhs_val += rhs_val;
						}
						lhs_vals
					})
					.collect();

				Ok(round_evals_sum)
			},
		)?;

	// Extend from the composition-degree-sized cosets to the full max_domain_size domain.
	let round_evals =
		extrapolate_round_evals(&fdomain_ntt, staggered_round_evals, skip_rounds, max_domain_size)?;
	drop(coeffs_span);

	Ok(ZerocheckUnivariateEvalsOutput {
		round_evals,
		skip_rounds,
		remaining_rounds,
		max_domain_size,
		partial_eq_ind_evals,
	})
}
506
507fn spread_product<P, FBase>(
510 accum: &mut [P],
511 small: &[PackedSubfield<P, FBase>],
512 large: &[P],
513 subcube_index: usize,
514 log_n: usize,
515 log_batch: usize,
516) where
517 P: PackedExtension<FBase>,
518 FBase: Field,
519{
520 let log_embedding_degree = <P::Scalar as ExtensionField<FBase>>::LOG_DEGREE;
521 let pbase_log_width = P::LOG_WIDTH + log_embedding_degree;
522
523 debug_assert_eq!(accum.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH));
524 debug_assert_eq!(small.len(), 1 << (log_n + log_batch).saturating_sub(pbase_log_width));
525
526 if log_n >= P::LOG_WIDTH {
527 let mask = (1 << log_embedding_degree) - 1;
529 for batch_idx in 0..1 << log_batch {
530 let mult = get_packed_slice(large, subcube_index << log_batch | batch_idx);
531 let spread_large = P::cast_base(P::broadcast(mult));
532
533 for (block_idx, dest) in accum.iter_mut().enumerate() {
534 let block_offset = block_idx | batch_idx << (log_n - P::LOG_WIDTH);
535 let spread_small = small[block_offset >> log_embedding_degree]
536 .spread(P::LOG_WIDTH, block_offset & mask);
537 *dest += P::cast_ext(spread_large * spread_small);
538 }
539 }
540 } else {
541 for (outer_idx, dest) in accum.iter_mut().enumerate() {
545 *dest = P::from_fn(|inner_idx| {
546 if inner_idx >= 1 << log_n {
547 return P::Scalar::ZERO;
548 }
549 (0..1 << log_batch)
550 .map(|batch_idx| {
551 let large = get_packed_slice(large, subcube_index << log_batch | batch_idx);
552 let small = get_packed_slice_checked(
553 small,
554 batch_idx << log_n | outer_idx << P::LOG_WIDTH | inner_idx,
555 )
556 .unwrap_or_default();
557 large * small
558 })
559 .sum()
560 })
561 }
562 }
563}
564
565#[instrument(skip_all, level = "debug")]
571fn extrapolate_round_evals<F, FDomain, TA>(
572 ntt: &SingleThreadedNTT<FDomain, TA>,
573 mut round_evals: Vec<Vec<F>>,
574 skip_rounds: usize,
575 max_domain_size: usize,
576) -> Result<Vec<Vec<F>>, Error>
577where
578 F: BinaryField + ExtensionField<FDomain>,
579 FDomain: BinaryField,
580 TA: TwiddleAccess<FDomain>,
581{
582 let subspace_upcast = BinarySubspace::new_unchecked(
587 ntt.subspace(ntt.log_domain_size())
588 .basis()
589 .iter()
590 .copied()
591 .map(F::from)
592 .collect(),
593 );
594 let ntt = SingleThreadedNTT::with_subspace(&subspace_upcast)
595 .expect("ntt provided is valid; subspace is equivalent but upcast to F");
596
597 let mut odd_interpolates = HashMap::new();
599
600 for round_evals in &mut round_evals {
601 round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));
603
604 let n = round_evals.len();
605
606 let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
608 let ell = n.trailing_zeros() as usize;
609 assert!(ell >= skip_rounds);
610
611 let coset_bits = ntt.log_domain_size() - ell;
612 OddInterpolate::new(&ntt, n >> ell, ell, coset_bits)
613 .expect("domain large enough by construction")
614 });
615
616 odd_interpolate.inverse_transform(round_evals)?;
618
619 let next_log_n = ntt.log_domain_size();
621 round_evals.resize(1 << next_log_n, F::ZERO);
622
623 let shape = NTTShape {
624 log_y: next_log_n,
625 ..Default::default()
626 };
627 ntt.forward_transform(round_evals, shape, 0, 0, 0)?;
628
629 debug_assert!(
631 round_evals[..1 << skip_rounds]
632 .iter()
633 .all(|&coeff| coeff == F::ZERO)
634 );
635
636 round_evals.resize(max_domain_size, F::ZERO);
638 round_evals.drain(..1 << skip_rounds);
639 }
640
641 Ok(round_evals)
642}
643
644fn ntt_extrapolate<NTT, P>(
645 ntt: &NTT,
646 skip_rounds: usize,
647 log_stride_batch: usize,
648 log_batch: usize,
649 evals: &mut [P],
650 extrapolated_evals: &mut [P],
651) -> Result<(), Error>
652where
653 P: PackedField<Scalar: BinaryField>,
654 NTT: AdditiveNTT<P::Scalar>,
655{
656 let shape = NTTShape {
657 log_x: log_stride_batch,
658 log_y: skip_rounds,
659 log_z: log_batch,
660 };
661
662 let coset_bits = ntt.log_domain_size() - skip_rounds;
663
664 ntt.inverse_transform(evals, shape, 0, coset_bits, 0)?;
666
667 for (coset, extrapolated_chunk) in izip!(1.., extrapolated_evals.chunks_exact_mut(evals.len()))
669 {
670 extrapolated_chunk.copy_from_slice(evals);
673 ntt.forward_transform(extrapolated_chunk, shape, coset, coset_bits, 0)?;
674 }
675
676 Ok(())
677}
678
679const fn extrapolated_evals_packed_len<P: PackedField>(
680 composition_degree: usize,
681 skip_rounds: usize,
682 log_batch: usize,
683) -> usize {
684 composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
685}
686
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		BinaryField1b, BinaryField8b, BinaryField16b, BinaryField128b, ExtensionField, Field,
		PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::UnderlierType,
	};
	use binius_hal::make_portable_backend;
	use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearPoly};
	use binius_ntt::SingleThreadedNTT;
	use rand::{SeedableRng, prelude::StdRng};

	use crate::{
		composition::{IndexComposition, ProductComposition},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
			test_utils::generate_zero_product_multilinears,
		},
		transparent::eq_ind::EqIndPartialEval,
	};

	/// Checks `ntt_extrapolate` against direct Lagrange extrapolation over the
	/// `2^skip_rounds`-point subspace, for every batch row.
	///
	/// NOTE(review): because of `.take(1 << skip_rounds)`, only the first extrapolated coset
	/// is compared. Later cosets are not checked here.
	#[test]
	fn ntt_extrapolate_correctness() {
		type P = PackedBinaryField4x32b;
		type FDomain = BinaryField16b;
		// A 32-bit scalar is two 16-bit FDomain scalars, so the interleave stride is 2^1.
		let log_extension_degree_p_domain = 1;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
		let subspace = BinarySubspace::<FDomain>::with_dim(10).unwrap();
		let max_domain =
			EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();

		for skip_rounds in 0..5usize {
			let subsubspace = subspace.reduce_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subsubspace.iter().collect::<Vec<_>>(), false)
					.unwrap();
			for log_batch in 0..3usize {
				for composition_degree in 0..5usize {
					let subcube_vars = skip_rounds + log_batch;
					let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
					let interleaved_evals = (0..interleaved_len)
						.map(|_| P::random(&mut rng))
						.collect::<Vec<_>>();

					let extrapolated_scalars_cnt =
						composition_degree.saturating_sub(1) << skip_rounds;
					let extrapolated_ntts = composition_degree.saturating_sub(1);
					let extrapolated_len = extrapolated_ntts * interleaved_len;
					let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

					// ntt_extrapolate overwrites its input with coefficients, so work on a copy
					// and keep the original evaluations for the reference computation.
					let mut interleaved_evals_scratch = interleaved_evals.clone();

					let interleaved_evals_domain =
						P::cast_bases_mut(&mut interleaved_evals_scratch);
					let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

					super::ntt_extrapolate(
						&ntt,
						skip_rounds,
						log_extension_degree_p_domain,
						log_batch,
						interleaved_evals_domain,
						extrapolated_evals_domain,
					)
					.unwrap();

					let interleaved_scalars =
						&P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
					let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
						[..extrapolated_scalars_cnt << log_batch];

					for batch_idx in 0..1 << log_batch {
						let values =
							&interleaved_scalars[batch_idx << skip_rounds..][..1 << skip_rounds];

						for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
							[..extrapolated_scalars_cnt]
							.iter()
							.take(1 << skip_rounds)
							.enumerate()
						{
							let extrapolated = domain.extrapolate(values, point.into()).unwrap();
							let expected = extrapolated_scalars[batch_idx << skip_rounds | i];
							assert_eq!(extrapolated, expected);
						}
					}
				}
			}
		}
	}

	#[test]
	fn zerocheck_univariate_evals_invariants_basic() {
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	#[test]
	fn zerocheck_univariate_evals_with_nontrivial_packing() {
		// A 512-bit underlier packs several 128-bit scalars per element. This exercises the
		// packing paths where skip_rounds can be smaller than P::LOG_WIDTH.
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	/// Compares `zerocheck_univariate_evals` against a naive computation.
	///
	/// For each output point, every multilinear's skipped-variable slice is extrapolated to
	/// that point and each composition is evaluated. The result is weighted by the zerocheck
	/// eq indicator and summed over all remaining hypercube indices.
	///
	/// NOTE(review): with `n_vars = 7` and the 12-variable subcube cap in
	/// `zerocheck_univariate_evals`, every call processes exactly one subcube. Accumulation
	/// across several subcubes is therefore not exercised by this helper.
	fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<BinaryField1b>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField + From<u8>,
		PackedType<U, FBase>: PackedFieldIndexable,
		PackedType<U, FDomain>: PackedFieldIndexable,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 7;
		let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

		// Three groups of multilinears (2, 3 and 4 of them, 9 in total), one group per
		// composition below.
		let mut multilinears = generate_zero_product_multilinears::<
			PackedType<U, BinaryField1b>,
			PackedType<U, F>,
		>(&mut rng, n_vars, 2);
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

		let compositions = [
			Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
		];

		let backend = make_portable_backend();
		let zerocheck_challenges = (0..n_vars)
			.map(|_| <F as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		for skip_rounds in 0usize..=5 {
			let max_domain_size = domain_size(5, skip_rounds);
			let output =
				zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
					&multilinears,
					&compositions,
					&zerocheck_challenges[skip_rounds..],
					skip_rounds,
					max_domain_size,
					&backend,
				)
				.unwrap();

			let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
				.multilinear_extension::<F, _>(&backend)
				.unwrap();

			// Every composition's round_evals are expected to have 4 * 2^skip_rounds entries.
			let round_evals_len = 4usize << skip_rounds;
			assert!(
				output
					.round_evals
					.iter()
					.all(|round_evals| round_evals.len() == round_evals_len)
			);

			let compositions = compositions
				.iter()
				.cloned()
				.map(CompositionScalarAdapter::new)
				.collect::<Vec<_>>();

			let mut query = [FBase::ZERO; 9];
			let mut evals = vec![
				PackedType::<U, F>::zero();
				1 << skip_rounds.saturating_sub(
					log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
				)
			];
			let subspace = BinarySubspace::<FDomain>::with_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();
			for round_evals_index in 0..round_evals_len {
				// NOTE(review): assumes the k-th canonical subspace point equals
				// FDomain::from(k as u8). Indices stay below 256 here (at most 159).
				let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
				let mut composition_sums = vec![F::ZERO; compositions.len()];
				for subcube_index in 0..1 << (n_vars - skip_rounds) {
					for (query, multilinear) in query.iter_mut().zip(&multilinears) {
						multilinear
							.subcube_evals(
								skip_rounds,
								subcube_index,
								log_embedding_degree,
								&mut evals,
							)
							.unwrap();
						let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
							PackedExtension::<FBase>::cast_bases(&evals),
						)[..1 << skip_rounds];
						let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
						*query = extrapolated;
					}

					let eq_ind_factor = zerocheck_eq_ind
						.evaluate_on_hypercube(subcube_index)
						.unwrap();
					for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
						*sum += eq_ind_factor * composition.evaluate(&query).unwrap();
					}
				}

				let univariate_skip_composition_sums = output
					.round_evals
					.iter()
					.map(|round_evals| round_evals[round_evals_index])
					.collect::<Vec<_>>();
				assert_eq!(univariate_skip_composition_sums, composition_sums);
			}
		}
	}
}