1use std::{collections::HashMap, iter::repeat_n};
4
5use binius_field::{
6 recast_packed_mut, util::inner_product_unchecked, BinaryField, ExtensionField, Field,
7 PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
8};
9use binius_hal::{ComputationBackend, ComputationBackendExt};
10use binius_math::{
11 CompositionPoly, Error as MathError, EvaluationDomainFactory, EvaluationOrder,
12 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly,
13};
14use binius_maybe_rayon::prelude::*;
15use binius_ntt::{AdditiveNTT, OddInterpolate, SingleThreadedNTT};
16use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
17use bytemuck::zeroed_vec;
18use itertools::izip;
19use stackalloc::stackalloc_with_iter;
20use tracing::instrument;
21use transpose::transpose;
22
23use crate::{
24 composition::{BivariateProduct, IndexComposition},
25 protocols::sumcheck::{
26 common::{
27 equal_n_vars_check, immediate_switchover_heuristic, small_field_embedding_degree_check,
28 },
29 prove::{common::fold_partial_eq_ind, RegularSumcheckProver},
30 univariate::{
31 lagrange_evals_multilinear_extension, univariatizing_reduction_composite_sum_claims,
32 },
33 univariate_zerocheck::{domain_size, extrapolated_scalars_count},
34 Error, VerificationError,
35 },
36};
37
38#[instrument(skip_all, level = "debug")]
40pub fn reduce_to_skipped_projection<F, P, M, Backend>(
41 multilinears: Vec<M>,
42 sumcheck_challenges: &[F],
43 backend: &'_ Backend,
44) -> Result<Vec<MLEDirectAdapter<P>>, Error>
45where
46 F: Field,
47 P: PackedFieldIndexable<Scalar = F>,
48 M: MultilinearPoly<P> + Send + Sync,
49 Backend: ComputationBackend,
50{
51 let n_vars = equal_n_vars_check(&multilinears)?;
52
53 if sumcheck_challenges.len() > n_vars {
54 bail!(Error::IncorrectNumberOfChallenges);
55 }
56
57 let query = backend.multilinear_query(sumcheck_challenges)?;
58
59 let reduced_multilinears = multilinears
60 .par_iter()
61 .map(|multilinear| {
62 backend
63 .evaluate_partial_high(multilinear, query.to_ref())
64 .expect("0 <= sumcheck_challenges.len() < n_vars")
65 .into()
66 })
67 .collect();
68
69 Ok(reduced_multilinears)
70}
71
/// Concrete prover type returned by `univariatizing_reduction_prover`.
///
/// It is a [`RegularSumcheckProver`] over [`MLEDirectAdapter`] multilinears whose
/// compositions are `IndexComposition<BivariateProduct, 2>`.
pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
    'a,
    FDomain,
    P,
    IndexComposition<BivariateProduct, 2>,
    MLEDirectAdapter<P>,
    Backend,
>;
80
81pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>(
90 mut reduced_multilinears: Vec<MLEDirectAdapter<P>>,
91 univariatized_multilinear_evals: &[F],
92 univariate_challenge: F,
93 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
94 backend: &'a Backend,
95) -> Result<Prover<'a, FDomain, P, Backend>, Error>
96where
97 F: TowerField,
98 FDomain: TowerField,
99 P: PackedFieldIndexable<Scalar = F>
100 + PackedExtension<F, PackedSubfield = P>
101 + PackedExtension<FDomain>,
102 Backend: ComputationBackend,
103{
104 let skip_rounds = equal_n_vars_check(&reduced_multilinears)?;
105
106 if univariatized_multilinear_evals.len() != reduced_multilinears.len() {
107 bail!(VerificationError::NumberOfFinalEvaluations);
108 }
109
110 let evaluation_domain = EvaluationDomainFactory::<FDomain>::create(
111 &IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default(),
112 1 << skip_rounds,
113 )?;
114
115 reduced_multilinears.push(
116 lagrange_evals_multilinear_extension(&evaluation_domain, univariate_challenge)?.into(),
117 );
118
119 let composite_sum_claims =
120 univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
121
122 let prover = RegularSumcheckProver::new(
123 EvaluationOrder::LowToHigh,
124 reduced_multilinears,
125 composite_sum_claims,
126 evaluation_domain_factory,
127 immediate_switchover_heuristic,
128 backend,
129 )?;
130
131 Ok(prover)
132}
133
/// Scratch buffers and accumulators owned by one fold state of the parallel
/// fold in `zerocheck_univariate_evals`.
#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
    /// Destination of `MultilinearPoly::subcube_evals` for the subcube being processed.
    evals: Vec<P>,
    /// Destination of the scalar transpose of `evals` viewed over `FBase`; only
    /// written when batching is active (`log_batch != 0`).
    interleaved_evals: Vec<PackedSubfield<P, FBase>>,
    /// One buffer per multilinear, filled by `ntt_extrapolate`.
    extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
    /// Output buffer for `CompositionPoly::batch_evaluate`.
    composition_evals: Vec<PackedSubfield<P, FBase>>,
    /// One accumulator vector per composition, each holding
    /// `extrapolated_scalars_count(degree, skip_rounds)` scalars.
    round_evals: Vec<Vec<P::Scalar>>,
}
147
148impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
149 fn new(
150 n_multilinears: usize,
151 skip_rounds: usize,
152 log_batch: usize,
153 log_embedding_degree: usize,
154 composition_degrees: impl Iterator<Item = usize> + Clone,
155 ) -> Self {
156 let subcube_vars = skip_rounds + log_batch;
157 let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
158 let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
159 composition_max_degree,
160 skip_rounds,
161 log_batch,
162 );
163
164 let evals =
165 zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));
166 let interleaved_evals =
167 zeroed_vec(1 << subcube_vars.saturating_sub(<PackedSubfield<P, FBase>>::LOG_WIDTH));
168
169 let extrapolated_evals = (0..n_multilinears)
170 .map(|_| zeroed_vec(extrapolated_packed_pbase_len))
171 .collect();
172
173 let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);
174
175 let round_evals = composition_degrees
176 .map(|composition_degree| {
177 zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
178 })
179 .collect();
180
181 Self {
182 evals,
183 interleaved_evals,
184 extrapolated_evals,
185 composition_evals,
186 round_evals,
187 }
188 }
189}
190
/// Result of `zerocheck_univariate_evals`.
///
/// Besides the public round evaluations it keeps the state that `fold` needs.
#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
    F: Field,
    P: PackedField<Scalar = F>,
    Backend: ComputationBackend,
{
    /// One vector of round evaluations per composition.
    pub round_evals: Vec<Vec<F>>,
    /// Number of skipped rounds the evaluations were computed for.
    skip_rounds: usize,
    /// Number of variables left after the skipped ones (`n_vars - skip_rounds`).
    remaining_rounds: usize,
    /// Size of the evaluation domain the round evaluations were extrapolated to.
    max_domain_size: usize,
    /// Tensor expansion of the zerocheck challenges produced by the backend.
    partial_eq_ind_evals: Backend::Vec<P>,
}
204
/// Result of `ZerocheckUnivariateEvalsOutput::fold`.
pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
    F: Field,
    P: PackedField<Scalar = F>,
    Backend: ComputationBackend,
{
    /// Number of skipped rounds, carried over unchanged.
    pub skip_rounds: usize,
    /// Lagrange evaluations at the fold challenge over the domain of size `1 << skip_rounds`.
    pub subcube_lagrange_coeffs: Vec<F>,
    /// Per composition: inner product of its round evaluations with the Lagrange
    /// coefficients of the max domain, skipping the first `1 << skip_rounds` coefficients.
    pub claimed_prime_sums: Vec<F>,
    /// The stored tensor expansion after `fold_partial_eq_ind` has been applied to it.
    pub partial_eq_ind_evals: Backend::Vec<P>,
}
216
217impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
218where
219 F: Field,
220 P: PackedFieldIndexable<Scalar = F>,
221 Backend: ComputationBackend,
222{
223 #[instrument(
225 skip_all,
226 name = "ZerocheckUnivariateEvalsOutput::fold",
227 level = "debug"
228 )]
229 pub fn fold<FDomain>(
230 self,
231 challenge: F,
232 ) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
233 where
234 FDomain: TowerField,
235 F: ExtensionField<FDomain>,
236 {
237 let Self {
238 round_evals,
239 skip_rounds,
240 remaining_rounds,
241 max_domain_size,
242 mut partial_eq_ind_evals,
243 } = self;
244
245 let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default();
246 let max_domain =
247 EvaluationDomainFactory::<FDomain>::create(&domain_factory, max_domain_size)?;
248
249 let subcube_evaluation_domain =
251 EvaluationDomainFactory::<FDomain>::create(&domain_factory, 1 << skip_rounds)?;
252
253 let subcube_lagrange_coeffs = subcube_evaluation_domain.lagrange_evals(challenge);
254
255 fold_partial_eq_ind::<P, Backend>(
257 EvaluationOrder::LowToHigh,
258 remaining_rounds,
259 &mut partial_eq_ind_evals,
260 );
261
262 let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);
264
265 let claimed_prime_sums = round_evals
266 .into_iter()
267 .map(|evals| {
268 inner_product_unchecked::<F, F>(
269 evals,
270 round_evals_lagrange_coeffs[1 << skip_rounds..]
271 .iter()
272 .copied(),
273 )
274 })
275 .collect();
276
277 Ok(ZerocheckUnivariateFoldResult {
278 skip_rounds,
279 subcube_lagrange_coeffs,
280 claimed_prime_sums,
281 partial_eq_ind_evals,
282 })
283 }
284}
285
/// Computes the univariate round evaluations of a zerocheck in which the first
/// `skip_rounds` variables are skipped.
///
/// The multilinears are processed in subcubes of `skip_rounds + log_batch`
/// variables. For every subcube the evaluations are NTT-extrapolated to further
/// cosets, each composition is evaluated on the extrapolated values, and the
/// results are accumulated weighted by the tensor expansion of
/// `zerocheck_challenges`. The accumulated values are finally extrapolated to
/// `max_domain_size` by `extrapolate_round_evals`.
///
/// # Errors
///
/// * Propagates the error of `equal_n_vars_check` on `multilinears`.
/// * `Error::TooManySkippedRounds` if `skip_rounds` exceeds the variable count.
/// * `Error::IncorrectZerocheckChallengesLength` if the number of challenges is
///   not `n_vars - skip_rounds`.
/// * Propagates the error of `small_field_embedding_degree_check`.
/// * `Error::LagrangeDomainTooSmall` if `max_domain_size` is below
///   `domain_size(max composition degree, skip_rounds)`.
/// * `MathError::DomainSizeTooLarge` if the required NTT domain needs more bits
///   than `FDomain::N_BITS`.
/// * Failures of the backend, of subcube evaluation, of the NTTs and of
///   composition evaluation are propagated.
#[instrument(skip_all, level = "debug")]
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
    multilinears: &[M],
    compositions: &[Composition],
    zerocheck_challenges: &[F],
    skip_rounds: usize,
    max_domain_size: usize,
    backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
    FDomain: TowerField,
    FBase: ExtensionField<FDomain>,
    F: TowerField,
    P: PackedFieldIndexable<Scalar = F>
        + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
        + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
    Composition: CompositionPoly<PackedSubfield<P, FBase>>,
    M: MultilinearPoly<P> + Send + Sync,
    Backend: ComputationBackend,
{
    let n_vars = equal_n_vars_check(multilinears)?;
    let n_multilinears = multilinears.len();

    if skip_rounds > n_vars {
        bail!(Error::TooManySkippedRounds);
    }

    // One zerocheck challenge is expected per non-skipped variable.
    let remaining_rounds = n_vars - skip_rounds;
    if zerocheck_challenges.len() != remaining_rounds {
        bail!(Error::IncorrectZerocheckChallengesLength);
    }

    small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

    let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
    let composition_degrees = compositions.iter().map(|composition| composition.degree());
    let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

    if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
        bail!(Error::LagrangeDomainTooSmall);
    }

    // Extension degree of FBase over FDomain; added to the NTT batching factor below.
    let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;
    let pdomain_log_width = <P as PackedExtension<FDomain>>::PackedSubfield::LOG_WIDTH;

    // The NTT domain has to cover `max_domain_size` and is additionally kept at
    // least one bit larger than the FDomain packing width; it must fit into FDomain.
    let min_domain_bits = log2_ceil_usize(max_domain_size).max(pdomain_log_width + 1);
    if min_domain_bits > FDomain::N_BITS {
        bail!(MathError::DomainSizeTooLarge);
    }

    let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
        .expect("FDomain cardinality checked before")
        .precompute_twiddles();

    // Subcubes of `skip_rounds` variables are batched together so that one unit
    // of work spans at most MAX_SUBCUBE_VARS variables (or all of them if fewer).
    const MAX_SUBCUBE_VARS: usize = 12;
    let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);

    // Tensor expansion of the challenges for the non-skipped variables; each of
    // its scalars weights one `skip_rounds`-variable subcube below.
    let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?;
    let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals);

    // Each composition is only evaluated on the packed prefix that its own
    // degree requires.
    let pbase_prefix_lens = composition_degrees
        .clone()
        .map(|composition_degree| {
            extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
                composition_degree,
                skip_rounds,
                log_batch,
            )
        })
        .collect::<Vec<_>>();

    let subcube_vars = log_batch + skip_rounds;
    let log_subcube_count = n_vars - subcube_vars;

    // Parallel fold over batched subcubes: every fold state owns its scratch
    // buffers and per-composition accumulators, which are summed afterwards.
    let staggered_round_evals = (0..1 << log_subcube_count)
        .into_par_iter()
        .try_fold(
            || {
                ParFoldStates::<FBase, P>::new(
                    n_multilinears,
                    skip_rounds,
                    log_batch,
                    log_embedding_degree,
                    composition_degrees.clone(),
                )
            },
            |mut par_fold_states, subcube_index| -> Result<_, Error> {
                let ParFoldStates {
                    evals,
                    interleaved_evals,
                    extrapolated_evals,
                    composition_evals,
                    round_evals,
                    ..
                } = &mut par_fold_states;

                // Extrapolate every multilinear on the current batched subcube.
                for (multilinear, extrapolated_evals) in
                    izip!(multilinears, extrapolated_evals.iter_mut())
                {
                    // Sample the evaluations of the subcube from the multilinear.
                    multilinear.subcube_evals(
                        subcube_vars,
                        subcube_index,
                        log_embedding_degree,
                        evals.as_mut_slice(),
                    )?;

                    // Reinterpret the sampled evaluations as packed FBase elements.
                    let evals_base =
                        <P as PackedExtension<FBase>>::cast_bases_mut(evals.as_mut_slice());

                    // With batching, the scalars form a row-major matrix of
                    // `1 << log_batch` rows by `1 << skip_rounds` columns; it is
                    // transposed into `interleaved_evals` before the batched NTT.
                    let interleaved_evals_ref = if log_batch == 0 {
                        // No batching: the slice is used as is.
                        evals_base
                    } else {
                        let evals_base_scalars =
                            &<PackedSubfield<P, FBase>>::unpack_scalars(evals_base)
                                [..1 << subcube_vars];
                        let interleaved_evals_scalars =
                            &mut <PackedSubfield<P, FBase>>::unpack_scalars_mut(
                                interleaved_evals.as_mut_slice(),
                            )[..1 << subcube_vars];

                        transpose(
                            evals_base_scalars,
                            interleaved_evals_scalars,
                            1 << skip_rounds,
                            1 << log_batch,
                        );

                        interleaved_evals.as_mut_slice()
                    };

                    // The NTT runs over FDomain, so both buffers are recast from
                    // packed FBase to packed FDomain; the extension degree of
                    // FBase over FDomain is folded into the batching factor.
                    let interleaved_evals_bases =
                        recast_packed_mut::<P, FBase, FDomain>(interleaved_evals_ref);
                    let extrapolated_evals_bases =
                        recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

                    // All multilinears are extrapolated for the maximal composition
                    // degree; lower-degree compositions only read a prefix.
                    ntt_extrapolate(
                        &fdomain_ntt,
                        skip_rounds,
                        log_batch + log_extension_degree_base_domain,
                        composition_max_degree,
                        interleaved_evals_bases,
                        extrapolated_evals_bases,
                    )?
                }

                // Weights of the `1 << log_batch` subcubes contained in this batch.
                let partial_eq_ind_evals_scalars_subslice =
                    &partial_eq_ind_evals_scalars[subcube_index << log_batch..][..1 << log_batch];

                // Evaluate the compositions and accumulate into the round evaluations.
                for (composition, round_evals, &pbase_prefix_len) in
                    izip!(compositions, round_evals, &pbase_prefix_lens)
                {
                    let extrapolated_evals_iter = extrapolated_evals
                        .iter()
                        .map(|evals| &evals[..pbase_prefix_len]);

                    // Collect the per-multilinear prefixes into a stack-allocated
                    // query and evaluate the composition on all of them at once.
                    stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
                        composition
                            .batch_evaluate(batch_query, &mut composition_evals[..pbase_prefix_len])
                    })?;

                    let composition_evals_scalars = <PackedSubfield<P, FBase>>::unpack_scalars_mut(
                        composition_evals.as_mut_slice(),
                    );

                    // Walk the extrapolation cosets: `1 << skip_rounds` round
                    // evaluations per coset. NOTE(review): the `max` in the scalar
                    // chunk size presumably covers a subcube smaller than one
                    // packed element - confirm.
                    for (round_evals_coset, composition_evals_scalars_coset) in izip!(
                        round_evals.chunks_exact_mut(1 << skip_rounds),
                        composition_evals_scalars.chunks_exact(
                            1 << subcube_vars.max(log_embedding_degree + P::LOG_WIDTH)
                        )
                    ) {
                        for (round_eval, composition_evals) in izip!(
                            round_evals_coset,
                            composition_evals_scalars_coset.chunks_exact(1 << log_batch),
                        ) {
                            // One weight per batched subcube: inner product of the
                            // weights with the composition values at this point.
                            *round_eval += inner_product_unchecked(
                                partial_eq_ind_evals_scalars_subslice.iter().copied(),
                                composition_evals.iter().copied(),
                            );
                        }
                    }

                }

                Ok(par_fold_states)
            },
        )
        // Only the accumulators are needed past this point; scratch buffers are dropped.
        .map(|states| -> Result<_, Error> { Ok(states?.round_evals) })
        // Element-wise sum of the per-state accumulators.
        .try_reduce(
            || {
                composition_degrees
                    .clone()
                    .map(|composition_degree| {
                        zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
                    })
                    .collect()
            },
            |lhs, rhs| -> Result<_, Error> {
                let round_evals_sum = izip!(lhs, rhs)
                    .map(|(mut lhs_vals, rhs_vals)| {
                        for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
                            *lhs_val += rhs_val;
                        }
                        lhs_vals
                    })
                    .collect();

                Ok(round_evals_sum)
            },
        )?;

    // Bring the round evaluations of every composition to the common
    // `max_domain_size`.
    let round_evals = extrapolate_round_evals(staggered_round_evals, skip_rounds, max_domain_size)?;

    Ok(ZerocheckUnivariateEvalsOutput {
        round_evals,
        skip_rounds,
        remaining_rounds,
        max_domain_size,
        partial_eq_ind_evals,
    })
}
580
581#[instrument(skip_all, level = "debug")]
587fn extrapolate_round_evals<F: TowerField>(
588 mut round_evals: Vec<Vec<F>>,
589 skip_rounds: usize,
590 max_domain_size: usize,
591) -> Result<Vec<Vec<F>>, Error> {
592 let ntt = SingleThreadedNTT::with_canonical_field(log2_ceil_usize(max_domain_size))?;
595
596 let mut odd_interpolates = HashMap::new();
598
599 for round_evals in &mut round_evals {
600 round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));
602
603 let n = round_evals.len();
604
605 let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
607 let ell = n.trailing_zeros() as usize;
608 assert!(ell >= skip_rounds);
609
610 OddInterpolate::new(n >> ell, ell, ntt.twiddles())
611 .expect("domain large enough by construction")
612 });
613
614 odd_interpolate.inverse_transform(&ntt, round_evals)?;
616
617 let next_log_n = log2_ceil_usize(max_domain_size);
619 round_evals.resize(1 << next_log_n, F::ZERO);
620
621 ntt.forward_transform(round_evals, 0, 0, next_log_n)?;
622
623 debug_assert!(round_evals[..1 << skip_rounds]
625 .iter()
626 .all(|&coeff| coeff == F::ZERO));
627
628 round_evals.resize(max_domain_size, F::ZERO);
630 round_evals.drain(..1 << skip_rounds);
631 }
632
633 Ok(round_evals)
634}
635
636fn ntt_extrapolate<NTT, P>(
637 ntt: &NTT,
638 skip_rounds: usize,
639 log_batch: usize,
640 composition_max_degree: usize,
641 interleaved_evals: &mut [P],
642 extrapolated_evals: &mut [P],
643) -> Result<(), Error>
644where
645 P: PackedFieldIndexable<Scalar: BinaryField>,
646 NTT: AdditiveNTT<P::Scalar>,
647{
648 let subcube_vars = skip_rounds + log_batch;
649 debug_assert_eq!(1 << subcube_vars.saturating_sub(P::LOG_WIDTH), interleaved_evals.len());
650 debug_assert_eq!(
651 extrapolated_evals_packed_len::<P>(composition_max_degree, skip_rounds, log_batch),
652 extrapolated_evals.len()
653 );
654 debug_assert!(
655 NTT::log_domain_size(ntt)
656 >= log2_ceil_usize(domain_size(composition_max_degree, skip_rounds))
657 );
658
659 ntt.inverse_transform(interleaved_evals, 0, log_batch, skip_rounds)?;
661
662 for (i, extrapolated_chunk) in extrapolated_evals
664 .chunks_exact_mut(interleaved_evals.len())
665 .enumerate()
666 {
667 extrapolated_chunk.copy_from_slice(interleaved_evals);
668 ntt.forward_transform(extrapolated_chunk, (i + 1) as u32, log_batch, skip_rounds)?;
669 }
670
671 Ok(())
672}
673
674const fn extrapolated_evals_packed_len<P: PackedField>(
675 composition_degree: usize,
676 skip_rounds: usize,
677 log_batch: usize,
678) -> usize {
679 composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
680}
681
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        underlier::UnderlierType,
        BinaryField128b, BinaryField16b, BinaryField1b, BinaryField8b, ExtensionField, Field,
        PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
    };
    use binius_hal::make_portable_backend;
    use binius_math::{
        CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly,
    };
    use binius_ntt::SingleThreadedNTT;
    use rand::{prelude::StdRng, SeedableRng};

    use crate::{
        composition::{IndexComposition, ProductComposition},
        polynomial::CompositionScalarAdapter,
        protocols::{
            sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
            test_utils::generate_zero_product_multilinears,
        },
        transparent::eq_ind::EqIndPartialEval,
    };

    /// Checks `ntt_extrapolate` against direct extrapolation with the evaluation
    /// domain, over a grid of `skip_rounds`, `log_batch` and composition degrees.
    #[test]
    fn ntt_extrapolate_correctness() {
        type P = PackedBinaryField4x32b;
        type FDomain = BinaryField16b;
        // log2 of the extension degree of the 32-bit scalars of `P` over the
        // 16-bit `FDomain`.
        let log_extension_degree_p_domain = 1;

        let mut rng = StdRng::seed_from_u64(0);
        let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
        let domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();
        let max_domain = domain_factory.create(1 << 10).unwrap();

        for skip_rounds in 0..5usize {
            let domain = domain_factory.create(1 << skip_rounds).unwrap();
            for log_batch in 0..3usize {
                for composition_degree in 0..5usize {
                    let subcube_vars = skip_rounds + log_batch;
                    let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
                    let interleaved_evals = (0..interleaved_len)
                        .map(|_| P::random(&mut rng))
                        .collect::<Vec<_>>();

                    let extrapolated_scalars_cnt =
                        composition_degree.saturating_sub(1) << skip_rounds;
                    let extrapolated_ntts = composition_degree.saturating_sub(1);
                    let extrapolated_len = extrapolated_ntts * interleaved_len;
                    let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

                    // `ntt_extrapolate` transforms its input in place, so it gets a
                    // copy and the original values stay available for the reference.
                    let mut interleaved_evals_scratch = interleaved_evals.clone();

                    let interleaved_evals_domain =
                        P::cast_bases_mut(&mut interleaved_evals_scratch);
                    let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

                    super::ntt_extrapolate(
                        &ntt,
                        skip_rounds,
                        log_batch + log_extension_degree_p_domain,
                        composition_degree,
                        interleaved_evals_domain,
                        extrapolated_evals_domain,
                    )
                    .unwrap();

                    let interleaved_scalars =
                        &P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
                    let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
                        [..extrapolated_scalars_cnt << log_batch];

                    for batch_idx in 0..1 << log_batch {
                        // Gather the strided values that belong to this batch element.
                        let values = (0..1 << skip_rounds)
                            .map(|i| interleaved_scalars[(i << log_batch) + batch_idx])
                            .collect::<Vec<_>>();

                        // Compare on the domain points following the first
                        // `1 << skip_rounds` ones (at most `1 << skip_rounds` of them).
                        for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
                            [..extrapolated_scalars_cnt]
                            .iter()
                            .take(1 << skip_rounds)
                            .enumerate()
                        {
                            let extrapolated = domain.extrapolate(&values, point.into()).unwrap();
                            let expected = extrapolated_scalars[(i << log_batch) + batch_idx];
                            assert_eq!(extrapolated, expected);
                        }
                    }
                }
            }
        }
    }

    /// Runs the invariants helper with a 128-bit underlier.
    #[test]
    fn zerocheck_univariate_evals_invariants_basic() {
        zerocheck_univariate_evals_invariants_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            BinaryField8b,
            BinaryField16b,
        >()
    }

    /// Runs the invariants helper with a 512-bit underlier.
    #[test]
    fn zerocheck_univariate_evals_with_nontrivial_packing() {
        zerocheck_univariate_evals_invariants_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            BinaryField8b,
            BinaryField16b,
        >()
    }

    /// Compares the output of `zerocheck_univariate_evals` with a direct
    /// computation: for each round evaluation point, every multilinear is
    /// extrapolated on each subcube, the compositions are evaluated on the
    /// extrapolated values, and the results are summed weighted by the equality
    /// indicator.
    fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
    where
        U: UnderlierType
            + PackScalar<F>
            + PackScalar<FBase>
            + PackScalar<FDomain>
            + PackScalar<BinaryField1b>,
        F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
        FBase: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField + From<u8>,
        PackedType<U, FBase>: PackedFieldIndexable,
        PackedType<U, FDomain>: PackedFieldIndexable,
        PackedType<U, F>: PackedFieldIndexable,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 7;
        let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

        // Nine multilinears in total, in groups of 2, 3 and 4.
        let mut multilinears = generate_zero_product_multilinears::<
            PackedType<U, BinaryField1b>,
            PackedType<U, F>,
        >(&mut rng, n_vars, 2);
        multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
        multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

        // One product composition per group, indexing into the nine multilinears.
        let compositions = [
            Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
            Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
            Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
        ];

        let backend = make_portable_backend();
        let zerocheck_challenges = (0..n_vars)
            .map(|_| <F as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        for skip_rounds in 0usize..=5 {
            // A domain size for degree 5 is requested, above the maximal
            // composition degree of 4 used here.
            let max_domain_size = domain_size(5, skip_rounds);
            let output =
                zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
                    &multilinears,
                    &compositions,
                    &zerocheck_challenges[skip_rounds..],
                    skip_rounds,
                    max_domain_size,
                    &backend,
                )
                .unwrap();

            let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
                .multilinear_extension::<F, _>(&backend)
                .unwrap();

            // Every composition must yield the same number of round evaluations.
            let round_evals_len = 4usize << skip_rounds;
            assert!(output
                .round_evals
                .iter()
                .all(|round_evals| round_evals.len() == round_evals_len));

            let compositions = compositions
                .iter()
                .cloned()
                .map(CompositionScalarAdapter::new)
                .collect::<Vec<_>>();

            let mut query = [FBase::ZERO; 9];
            let mut evals = vec![
                PackedType::<U, F>::zero();
                1 << skip_rounds.saturating_sub(
                    log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
                )
            ];
            let domain = DefaultEvaluationDomainFactory::<FDomain>::default()
                .create(1 << skip_rounds)
                .unwrap();
            for round_evals_index in 0..round_evals_len {
                // The round evaluations start after the first `1 << skip_rounds`
                // domain points.
                let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
                let mut composition_sums = vec![F::ZERO; compositions.len()];
                for subcube_index in 0..1 << (n_vars - skip_rounds) {
                    // Extrapolate every multilinear on this subcube to `x`.
                    for (query, multilinear) in query.iter_mut().zip(&multilinears) {
                        multilinear
                            .subcube_evals(
                                skip_rounds,
                                subcube_index,
                                log_embedding_degree,
                                &mut evals,
                            )
                            .unwrap();
                        let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
                            PackedExtension::<FBase>::cast_bases(&evals),
                        )[..1 << skip_rounds];
                        let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
                        *query = extrapolated;
                    }

                    // Weight the composition values by the equality indicator at
                    // this subcube.
                    let eq_ind_factor = zerocheck_eq_ind
                        .evaluate_on_hypercube(subcube_index)
                        .unwrap();
                    for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
                        *sum += eq_ind_factor * composition.evaluate(&query).unwrap();
                    }
                }

                let univariate_skip_composition_sums = output
                    .round_evals
                    .iter()
                    .map(|round_evals| round_evals[round_evals_index])
                    .collect::<Vec<_>>();
                assert_eq!(univariate_skip_composition_sums, composition_sums);
            }
        }
    }
}