1use std::{collections::HashMap, iter::repeat_n};
4
5use binius_field::{
6 recast_packed_mut, util::inner_product_unchecked, BinaryField, ExtensionField, Field,
7 PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
8};
9use binius_hal::{ComputationBackend, ComputationBackendExt};
10use binius_math::{
11 BinarySubspace, CompositionPoly, Error as MathError, EvaluationDomain, EvaluationOrder,
12 IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly, RowsBatchRef,
13};
14use binius_maybe_rayon::prelude::*;
15use binius_ntt::{AdditiveNTT, OddInterpolate, SingleThreadedNTT};
16use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
17use bytemuck::zeroed_vec;
18use itertools::izip;
19use stackalloc::stackalloc_with_iter;
20use tracing::instrument;
21use transpose::transpose;
22
23use crate::{
24 composition::{BivariateProduct, IndexComposition},
25 protocols::sumcheck::{
26 common::{
27 equal_n_vars_check, immediate_switchover_heuristic, small_field_embedding_degree_check,
28 },
29 prove::{common::fold_partial_eq_ind, RegularSumcheckProver},
30 univariate::{
31 lagrange_evals_multilinear_extension, univariatizing_reduction_composite_sum_claims,
32 },
33 univariate_zerocheck::{domain_size, extrapolated_scalars_count},
34 Error, VerificationError,
35 },
36};
37
38#[instrument(skip_all, level = "debug")]
40pub fn reduce_to_skipped_projection<F, P, M, Backend>(
41 multilinears: Vec<M>,
42 sumcheck_challenges: &[F],
43 backend: &'_ Backend,
44) -> Result<Vec<MLEDirectAdapter<P>>, Error>
45where
46 F: Field,
47 P: PackedFieldIndexable<Scalar = F>,
48 M: MultilinearPoly<P> + Send + Sync,
49 Backend: ComputationBackend,
50{
51 let n_vars = equal_n_vars_check(&multilinears)?;
52
53 if sumcheck_challenges.len() > n_vars {
54 bail!(Error::IncorrectNumberOfChallenges);
55 }
56
57 let query = backend.multilinear_query(sumcheck_challenges)?;
58
59 let reduced_multilinears = multilinears
60 .par_iter()
61 .map(|multilinear| {
62 backend
63 .evaluate_partial_high(multilinear, query.to_ref())
64 .expect("0 <= sumcheck_challenges.len() < n_vars")
65 .into()
66 })
67 .collect();
68
69 Ok(reduced_multilinears)
70}
71
/// Sumcheck prover type produced by [`univariatizing_reduction_prover`].
///
/// This is a [`RegularSumcheckProver`] with the following parameters:
/// * witness columns are [`MLEDirectAdapter`]s;
/// * each composition is a [`BivariateProduct`] applied through an [`IndexComposition`] that
///   selects 2 of the columns.
pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
    'a,
    FDomain,
    P,
    IndexComposition<BivariateProduct, 2>,
    MLEDirectAdapter<P>,
    Backend,
>;
80
81pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>(
90 mut reduced_multilinears: Vec<MLEDirectAdapter<P>>,
91 univariatized_multilinear_evals: &[F],
92 univariate_challenge: F,
93 backend: &'a Backend,
94) -> Result<Prover<'a, FDomain, P, Backend>, Error>
95where
96 F: TowerField,
97 FDomain: TowerField,
98 P: PackedFieldIndexable<Scalar = F>
99 + PackedExtension<F, PackedSubfield = P>
100 + PackedExtension<FDomain>,
101 Backend: ComputationBackend,
102{
103 let skip_rounds = equal_n_vars_check(&reduced_multilinears)?;
104
105 if univariatized_multilinear_evals.len() != reduced_multilinears.len() {
106 bail!(VerificationError::NumberOfFinalEvaluations);
107 }
108
109 let subspace =
110 BinarySubspace::<FDomain::Canonical>::with_dim(skip_rounds)?.isomorphic::<FDomain>();
111 let ntt_domain = EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
112
113 reduced_multilinears
114 .push(lagrange_evals_multilinear_extension(&ntt_domain, univariate_challenge)?.into());
115
116 let composite_sum_claims =
117 univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
118
119 let prover = RegularSumcheckProver::new(
120 EvaluationOrder::LowToHigh,
121 reduced_multilinears,
122 composite_sum_claims,
123 IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default(),
124 immediate_switchover_heuristic,
125 backend,
126 )?;
127
128 Ok(prover)
129}
130
131#[derive(Debug)]
132struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
133 evals: Vec<P>,
135 interleaved_evals: Vec<PackedSubfield<P, FBase>>,
137 extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
139 composition_evals: Vec<PackedSubfield<P, FBase>>,
141 round_evals: Vec<Vec<P::Scalar>>,
143}
144
145impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
146 fn new(
147 n_multilinears: usize,
148 skip_rounds: usize,
149 log_batch: usize,
150 log_embedding_degree: usize,
151 composition_degrees: impl Iterator<Item = usize> + Clone,
152 ) -> Self {
153 let subcube_vars = skip_rounds + log_batch;
154 let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
155 let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
156 composition_max_degree,
157 skip_rounds,
158 log_batch,
159 );
160
161 let evals =
162 zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));
163 let interleaved_evals =
164 zeroed_vec(1 << subcube_vars.saturating_sub(<PackedSubfield<P, FBase>>::LOG_WIDTH));
165
166 let extrapolated_evals = (0..n_multilinears)
167 .map(|_| zeroed_vec(extrapolated_packed_pbase_len))
168 .collect();
169
170 let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);
171
172 let round_evals = composition_degrees
173 .map(|composition_degree| {
174 zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
175 })
176 .collect();
177
178 Self {
179 evals,
180 interleaved_evals,
181 extrapolated_evals,
182 composition_evals,
183 round_evals,
184 }
185 }
186}
187
/// Output of [`zerocheck_univariate_evals`].
///
/// It holds the per-composition univariate round evaluations, together with the state needed
/// to [`fold`](Self::fold) them with a univariate challenge.
#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
    F: Field,
    P: PackedField<Scalar = F>,
    Backend: ComputationBackend,
{
    /// One vector per composition, holding `max_domain_size - 2^skip_rounds` evaluations.
    ///
    /// The evaluations are at the domain points that follow the first `2^skip_rounds` points.
    /// The evaluations on those leading points are dropped rather than stored.
    pub round_evals: Vec<Vec<F>>,
    /// Number of variables handled by the univariate skip.
    skip_rounds: usize,
    /// Number of variables left after the skipped ones (`n_vars - skip_rounds`).
    remaining_rounds: usize,
    /// Number of points of the domain the round evaluations were extrapolated to.
    max_domain_size: usize,
    /// `backend.tensor_product_full_query` expansion of the zerocheck challenges.
    partial_eq_ind_evals: Backend::Vec<P>,
}
201
/// Result of [`ZerocheckUnivariateEvalsOutput::fold`].
pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
    F: Field,
    P: PackedField<Scalar = F>,
    Backend: ComputationBackend,
{
    /// Number of skipped variables, carried over from the evaluation output.
    pub skip_rounds: usize,
    /// Lagrange basis evaluations at the challenge over the `2^skip_rounds`-point domain
    /// (obtained with `reduce_dim(skip_rounds)`).
    pub subcube_lagrange_coeffs: Vec<F>,
    /// One value per composition: its round evaluations interpolated at the challenge.
    ///
    /// The evaluations on the first `2^skip_rounds` points contribute nothing to this value.
    pub claimed_sums: Vec<F>,
    /// Equality indicator evaluations after `fold_partial_eq_ind` over the remaining rounds.
    pub partial_eq_ind_evals: Backend::Vec<P>,
}
213
214impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
215where
216 F: Field,
217 P: PackedFieldIndexable<Scalar = F>,
218 Backend: ComputationBackend,
219{
220 #[instrument(
222 skip_all,
223 name = "ZerocheckUnivariateEvalsOutput::fold",
224 level = "debug"
225 )]
226 pub fn fold<FDomain>(
227 self,
228 challenge: F,
229 ) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
230 where
231 FDomain: TowerField,
232 F: ExtensionField<FDomain>,
233 {
234 let Self {
235 round_evals,
236 skip_rounds,
237 remaining_rounds,
238 max_domain_size,
239 mut partial_eq_ind_evals,
240 } = self;
241
242 let max_dim = log2_ceil_usize(max_domain_size);
245 let subspace =
246 BinarySubspace::<FDomain::Canonical>::with_dim(max_dim)?.isomorphic::<FDomain>();
247 let max_domain = EvaluationDomain::from_points(
248 subspace.iter().take(max_domain_size).collect::<Vec<_>>(),
249 false,
250 )?;
251
252 let subcube_lagrange_coeffs = EvaluationDomain::from_points(
254 subspace.reduce_dim(skip_rounds)?.iter().collect::<Vec<_>>(),
255 false,
256 )?
257 .lagrange_evals(challenge);
258
259 fold_partial_eq_ind::<P, Backend>(
261 EvaluationOrder::LowToHigh,
262 remaining_rounds,
263 &mut partial_eq_ind_evals,
264 );
265
266 let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);
268
269 let claimed_sums = round_evals
270 .into_iter()
271 .map(|evals| {
272 inner_product_unchecked::<F, F>(
273 evals,
274 round_evals_lagrange_coeffs[1 << skip_rounds..]
275 .iter()
276 .copied(),
277 )
278 })
279 .collect();
280
281 Ok(ZerocheckUnivariateFoldResult {
282 skip_rounds,
283 subcube_lagrange_coeffs,
284 claimed_sums,
285 partial_eq_ind_evals,
286 })
287 }
288}
289
/// Computes the univariate round evaluations of a zerocheck whose first `skip_rounds`
/// variables are handled by a univariate skip.
///
/// For each composition, the returned `round_evals` hold its round polynomial at the domain
/// points that follow the first `2^skip_rounds` points. There are
/// `max_domain_size - 2^skip_rounds` such points. The round polynomial is the sum, over the
/// remaining `n_vars - skip_rounds` hypercube variables, of the equality indicator at
/// `zerocheck_challenges` times the composition of the multilinears extrapolated along the
/// skipped variables.
///
/// The computation proceeds in four steps:
/// 1. The hypercube is split into `2^log_subcube_count` subcubes of
///    `skip_rounds + log_batch` variables. Subcubes are processed with a rayon `try_fold`, and
///    each fold accumulator owns a [`ParFoldStates`] scratch set.
/// 2. For each multilinear, the subcube evaluations are fetched and transposed, when
///    `log_batch > 0`, so that the batch index is the lowest part of the scalar index. They are
///    then passed to [`ntt_extrapolate`] over `FDomain`, which fills
///    `composition_max_degree - 1` (saturating) extrapolated cosets.
/// 3. Each composition is batch-evaluated on a prefix of the extrapolated values. For every
///    extrapolated point, the batch entries are summed with weights taken from the
///    tensor-expanded equality indicator.
/// 4. Accumulators are summed with `try_reduce`. The totals are passed to
///    `extrapolate_round_evals`, which extends them to `max_domain_size` points.
///
/// # Errors
///
/// * Any error from `equal_n_vars_check` or `small_field_embedding_degree_check`.
/// * [`Error::TooManySkippedRounds`] if `skip_rounds > n_vars`.
/// * [`Error::IncorrectZerocheckChallengesLength`] if
///   `zerocheck_challenges.len() != n_vars - skip_rounds`.
/// * [`Error::LagrangeDomainTooSmall`] if
///   `max_domain_size < domain_size(composition_max_degree, skip_rounds)`.
/// * [`MathError::DomainSizeTooLarge`] if the required NTT domain exceeds `FDomain::N_BITS`.
/// * Errors propagated from the backend, `subcube_evals`, `batch_evaluate` or the NTTs.
#[instrument(skip_all, level = "debug")]
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
    multilinears: &[M],
    compositions: &[Composition],
    zerocheck_challenges: &[F],
    skip_rounds: usize,
    max_domain_size: usize,
    backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
    FDomain: TowerField,
    FBase: ExtensionField<FDomain>,
    F: TowerField,
    P: PackedFieldIndexable<Scalar = F>
        + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
        + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
    Composition: CompositionPoly<PackedSubfield<P, FBase>>,
    M: MultilinearPoly<P> + Send + Sync,
    Backend: ComputationBackend,
{
    let n_vars = equal_n_vars_check(multilinears)?;
    let n_multilinears = multilinears.len();

    if skip_rounds > n_vars {
        bail!(Error::TooManySkippedRounds);
    }

    let remaining_rounds = n_vars - skip_rounds;
    if zerocheck_challenges.len() != remaining_rounds {
        bail!(Error::IncorrectZerocheckChallengesLength);
    }

    small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

    let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
    let composition_degrees = compositions.iter().map(|composition| composition.degree());
    let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

    if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
        bail!(Error::LagrangeDomainTooSmall);
    }

    // FBase values are fed to the FDomain NTT as 2^log_extension_degree_base_domain
    // interleaved FDomain values each (added to the NTT's `log_batch` below).
    let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;
    let pdomain_log_width = <P as PackedExtension<FDomain>>::PackedSubfield::LOG_WIDTH;

    // NOTE(review): the `pdomain_log_width + 1` floor presumably satisfies a minimum domain
    // size of the packed NTT; the requirement is not stated here — confirm against binius_ntt.
    let min_domain_bits = log2_ceil_usize(max_domain_size).max(pdomain_log_width + 1);
    if min_domain_bits > FDomain::N_BITS {
        bail!(MathError::DomainSizeTooLarge);
    }

    let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
        .expect("FDomain cardinality checked before")
        .precompute_twiddles();

    // Caps the subcube size (and so the per-accumulator scratch size). The value 12 is not
    // explained here — presumably a tuning choice.
    const MAX_SUBCUBE_VARS: usize = 12;
    let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);

    let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?;
    let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals);

    // Lower-degree compositions only need a prefix of the extrapolated buffers.
    let pbase_prefix_lens = composition_degrees
        .clone()
        .map(|composition_degree| {
            extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
                composition_degree,
                skip_rounds,
                log_batch,
            )
        })
        .collect::<Vec<_>>();

    let subcube_vars = log_batch + skip_rounds;
    let log_subcube_count = n_vars - subcube_vars;

    let staggered_round_evals = (0..1 << log_subcube_count)
        .into_par_iter()
        .try_fold(
            || {
                ParFoldStates::<FBase, P>::new(
                    n_multilinears,
                    skip_rounds,
                    log_batch,
                    log_embedding_degree,
                    composition_degrees.clone(),
                )
            },
            |mut par_fold_states, subcube_index| -> Result<_, Error> {
                let ParFoldStates {
                    evals,
                    interleaved_evals,
                    extrapolated_evals,
                    composition_evals,
                    round_evals,
                    ..
                } = &mut par_fold_states;

                // Extrapolate every multilinear over this subcube into its own buffer.
                for (multilinear, extrapolated_evals) in
                    izip!(multilinears, extrapolated_evals.iter_mut())
                {
                    multilinear.subcube_evals(
                        subcube_vars,
                        subcube_index,
                        log_embedding_degree,
                        evals.as_mut_slice(),
                    )?;

                    let evals_base =
                        <P as PackedExtension<FBase>>::cast_bases_mut(evals.as_mut_slice());

                    // With no batch variables the layout already matches the NTT input.
                    // Otherwise transpose the (2^log_batch rows) x (2^skip_rounds columns)
                    // row-major subcube, so that scalar `(i << log_batch) + batch` holds skipped
                    // point `i` of batch `batch`.
                    let interleaved_evals_ref = if log_batch == 0 {
                        evals_base
                    } else {
                        let evals_base_scalars =
                            &<PackedSubfield<P, FBase>>::unpack_scalars(evals_base)
                                [..1 << subcube_vars];
                        let interleaved_evals_scalars =
                            &mut <PackedSubfield<P, FBase>>::unpack_scalars_mut(
                                interleaved_evals.as_mut_slice(),
                            )[..1 << subcube_vars];

                        transpose(
                            evals_base_scalars,
                            interleaved_evals_scalars,
                            1 << skip_rounds,
                            1 << log_batch,
                        );

                        interleaved_evals.as_mut_slice()
                    };

                    // Reinterpret the FBase-packed buffers as FDomain-packed for the NTT.
                    let interleaved_evals_bases =
                        recast_packed_mut::<P, FBase, FDomain>(interleaved_evals_ref);
                    let extrapolated_evals_bases =
                        recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

                    ntt_extrapolate(
                        &fdomain_ntt,
                        skip_rounds,
                        log_batch + log_extension_degree_base_domain,
                        composition_max_degree,
                        interleaved_evals_bases,
                        extrapolated_evals_bases,
                    )?
                }

                // Equality indicator weights for the batch entries of this subcube.
                let partial_eq_ind_evals_scalars_subslice =
                    &partial_eq_ind_evals_scalars[subcube_index << log_batch..][..1 << log_batch];

                for (composition, round_evals, &pbase_prefix_len) in
                    izip!(compositions, round_evals, &pbase_prefix_lens)
                {
                    let extrapolated_evals_iter = extrapolated_evals
                        .iter()
                        .map(|evals| &evals[..pbase_prefix_len]);

                    stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
                        let batch_query = RowsBatchRef::new(batch_query, pbase_prefix_len);

                        composition.batch_evaluate(
                            &batch_query,
                            &mut composition_evals[..pbase_prefix_len],
                        )
                    })?;

                    let composition_evals_scalars = <PackedSubfield<P, FBase>>::unpack_scalars_mut(
                        composition_evals.as_mut_slice(),
                    );

                    // Each extrapolated coset spans max(2^subcube_vars, one FBase-packed element)
                    // scalars. `log_embedding_degree + P::LOG_WIDTH` is the log-width of the
                    // FBase-packed view (cf. the `evals` / `interleaved_evals` sizing in
                    // `ParFoldStates::new`). Within a coset, each run of 2^log_batch scalars
                    // holds all batch entries of one point.
                    for (round_evals_coset, composition_evals_scalars_coset) in izip!(
                        round_evals.chunks_exact_mut(1 << skip_rounds),
                        composition_evals_scalars.chunks_exact(
                            1 << subcube_vars.max(log_embedding_degree + P::LOG_WIDTH)
                        )
                    ) {
                        for (round_eval, composition_evals) in izip!(
                            round_evals_coset,
                            composition_evals_scalars_coset.chunks_exact(1 << log_batch),
                        ) {
                            *round_eval += inner_product_unchecked(
                                partial_eq_ind_evals_scalars_subslice.iter().copied(),
                                composition_evals.iter().copied(),
                            );
                        }
                    }
                }

                Ok(par_fold_states)
            },
        )
        .map(|states| -> Result<_, Error> { Ok(states?.round_evals) })
        .try_reduce(
            || {
                composition_degrees
                    .clone()
                    .map(|composition_degree| {
                        zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
                    })
                    .collect()
            },
            // Element-wise sum of two sets of per-composition accumulators.
            |lhs, rhs| -> Result<_, Error> {
                let round_evals_sum = izip!(lhs, rhs)
                    .map(|(mut lhs_vals, rhs_vals)| {
                        for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
                            *lhs_val += rhs_val;
                        }
                        lhs_vals
                    })
                    .collect();

                Ok(round_evals_sum)
            },
        )?;

    // Extend each composition's evaluations from its own degree-dependent length to the full
    // `max_domain_size` domain.
    let round_evals = extrapolate_round_evals(staggered_round_evals, skip_rounds, max_domain_size)?;

    Ok(ZerocheckUnivariateEvalsOutput {
        round_evals,
        skip_rounds,
        remaining_rounds,
        max_domain_size,
        partial_eq_ind_evals,
    })
}
588
589#[instrument(skip_all, level = "debug")]
595fn extrapolate_round_evals<F: TowerField>(
596 mut round_evals: Vec<Vec<F>>,
597 skip_rounds: usize,
598 max_domain_size: usize,
599) -> Result<Vec<Vec<F>>, Error> {
600 let ntt = SingleThreadedNTT::with_canonical_field(log2_ceil_usize(max_domain_size))?;
603
604 let mut odd_interpolates = HashMap::new();
606
607 for round_evals in &mut round_evals {
608 round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));
610
611 let n = round_evals.len();
612
613 let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
615 let ell = n.trailing_zeros() as usize;
616 assert!(ell >= skip_rounds);
617
618 OddInterpolate::new(n >> ell, ell, ntt.twiddles())
619 .expect("domain large enough by construction")
620 });
621
622 odd_interpolate.inverse_transform(&ntt, round_evals)?;
624
625 let next_log_n = log2_ceil_usize(max_domain_size);
627 round_evals.resize(1 << next_log_n, F::ZERO);
628
629 ntt.forward_transform(round_evals, 0, 0, next_log_n)?;
630
631 debug_assert!(round_evals[..1 << skip_rounds]
633 .iter()
634 .all(|&coeff| coeff == F::ZERO));
635
636 round_evals.resize(max_domain_size, F::ZERO);
638 round_evals.drain(..1 << skip_rounds);
639 }
640
641 Ok(round_evals)
642}
643
644fn ntt_extrapolate<NTT, P>(
645 ntt: &NTT,
646 skip_rounds: usize,
647 log_batch: usize,
648 composition_max_degree: usize,
649 interleaved_evals: &mut [P],
650 extrapolated_evals: &mut [P],
651) -> Result<(), Error>
652where
653 P: PackedFieldIndexable<Scalar: BinaryField>,
654 NTT: AdditiveNTT<P::Scalar>,
655{
656 let subcube_vars = skip_rounds + log_batch;
657 debug_assert_eq!(1 << subcube_vars.saturating_sub(P::LOG_WIDTH), interleaved_evals.len());
658 debug_assert_eq!(
659 extrapolated_evals_packed_len::<P>(composition_max_degree, skip_rounds, log_batch),
660 extrapolated_evals.len()
661 );
662 debug_assert!(
663 NTT::log_domain_size(ntt)
664 >= log2_ceil_usize(domain_size(composition_max_degree, skip_rounds))
665 );
666
667 ntt.inverse_transform(interleaved_evals, 0, log_batch, skip_rounds)?;
669
670 for (i, extrapolated_chunk) in extrapolated_evals
672 .chunks_exact_mut(interleaved_evals.len())
673 .enumerate()
674 {
675 extrapolated_chunk.copy_from_slice(interleaved_evals);
676 ntt.forward_transform(extrapolated_chunk, (i + 1) as u32, log_batch, skip_rounds)?;
677 }
678
679 Ok(())
680}
681
682const fn extrapolated_evals_packed_len<P: PackedField>(
683 composition_degree: usize,
684 skip_rounds: usize,
685 log_batch: usize,
686) -> usize {
687 composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
688}
689
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        underlier::UnderlierType,
        BinaryField128b, BinaryField16b, BinaryField1b, BinaryField8b, ExtensionField, Field,
        PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
    };
    use binius_hal::make_portable_backend;
    use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearPoly};
    use binius_ntt::SingleThreadedNTT;
    use rand::{prelude::StdRng, SeedableRng};

    use crate::{
        composition::{IndexComposition, ProductComposition},
        polynomial::CompositionScalarAdapter,
        protocols::{
            sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
            test_utils::generate_zero_product_multilinears,
        },
        transparent::eq_ind::EqIndPartialEval,
    };

    /// Compares `ntt_extrapolate` against direct Lagrange extrapolation on the skipped domain.
    ///
    /// Covers `skip_rounds` in 0..5, `log_batch` in 0..3 and composition degree in 0..5.
    #[test]
    fn ntt_extrapolate_correctness() {
        type P = PackedBinaryField4x32b;
        type FDomain = BinaryField16b;
        // log2 of the degree of 32-bit scalars over the 16-bit FDomain.
        let log_extension_degree_p_domain = 1;

        let mut rng = StdRng::seed_from_u64(0);
        let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
        let subspace = BinarySubspace::<FDomain>::with_dim(10).unwrap();
        let max_domain =
            EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();

        for skip_rounds in 0..5usize {
            let subsubspace = subspace.reduce_dim(skip_rounds).unwrap();
            let domain =
                EvaluationDomain::from_points(subsubspace.iter().collect::<Vec<_>>(), false)
                    .unwrap();
            for log_batch in 0..3usize {
                for composition_degree in 0..5usize {
                    let subcube_vars = skip_rounds + log_batch;
                    let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
                    let interleaved_evals = (0..interleaved_len)
                        .map(|_| P::random(&mut rng))
                        .collect::<Vec<_>>();

                    let extrapolated_scalars_cnt =
                        composition_degree.saturating_sub(1) << skip_rounds;
                    let extrapolated_ntts = composition_degree.saturating_sub(1);
                    let extrapolated_len = extrapolated_ntts * interleaved_len;
                    let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

                    // `ntt_extrapolate` overwrites its input, so give it a copy and keep the
                    // original evaluations for the reference computation.
                    let mut interleaved_evals_scratch = interleaved_evals.clone();

                    let interleaved_evals_domain =
                        P::cast_bases_mut(&mut interleaved_evals_scratch);
                    let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

                    super::ntt_extrapolate(
                        &ntt,
                        skip_rounds,
                        log_batch + log_extension_degree_p_domain,
                        composition_degree,
                        interleaved_evals_domain,
                        extrapolated_evals_domain,
                    )
                    .unwrap();

                    let interleaved_scalars =
                        &P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
                    let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
                        [..extrapolated_scalars_cnt << log_batch];

                    for batch_idx in 0..1 << log_batch {
                        // Gather batch `batch_idx` from the interleaved layout.
                        let values = (0..1 << skip_rounds)
                            .map(|i| interleaved_scalars[(i << log_batch) + batch_idx])
                            .collect::<Vec<_>>();

                        // Only the first extrapolated coset is checked (`take(1 << skip_rounds)`).
                        for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
                            [..extrapolated_scalars_cnt]
                            .iter()
                            .take(1 << skip_rounds)
                            .enumerate()
                        {
                            let extrapolated = domain.extrapolate(&values, point.into()).unwrap();
                            let expected = extrapolated_scalars[(i << log_batch) + batch_idx];
                            assert_eq!(extrapolated, expected);
                        }
                    }
                }
            }
        }
    }

    /// Invariant checks with the 128-bit optimal underlier.
    #[test]
    fn zerocheck_univariate_evals_invariants_basic() {
        zerocheck_univariate_evals_invariants_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            BinaryField8b,
            BinaryField16b,
        >()
    }

    /// Same invariant checks with a 512-bit underlier, i.e. wider packed fields.
    #[test]
    fn zerocheck_univariate_evals_with_nontrivial_packing() {
        zerocheck_univariate_evals_invariants_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            BinaryField8b,
            BinaryField16b,
        >()
    }

    /// Checks `zerocheck_univariate_evals` against a naive computation.
    ///
    /// For each output point, every multilinear is Lagrange-extrapolated per subcube. Each
    /// composition is evaluated on the extrapolated values, and the results are summed with
    /// equality indicator weights.
    fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
    where
        U: UnderlierType
            + PackScalar<F>
            + PackScalar<FBase>
            + PackScalar<FDomain>
            + PackScalar<BinaryField1b>,
        F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
        FBase: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField + From<u8>,
        PackedType<U, FBase>: PackedFieldIndexable,
        PackedType<U, FDomain>: PackedFieldIndexable,
        PackedType<U, F>: PackedFieldIndexable,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 7;
        let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

        // 9 multilinears in groups of 2, 3 and 4 (indices 0..2, 2..5, 5..9). Each group is
        // presumably generated so that its product vanishes on the hypercube — see
        // `generate_zero_product_multilinears`.
        let mut multilinears = generate_zero_product_multilinears::<
            PackedType<U, BinaryField1b>,
            PackedType<U, F>,
        >(&mut rng, n_vars, 2);
        multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
        multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

        // One product composition per group, of degrees 2, 3 and 4.
        let compositions = [
            Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
            Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
            Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
                as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
        ];

        let backend = make_portable_backend();
        let zerocheck_challenges = (0..n_vars)
            .map(|_| <F as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        for skip_rounds in 0usize..=5 {
            let max_domain_size = domain_size(5, skip_rounds);
            let output =
                zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
                    &multilinears,
                    &compositions,
                    &zerocheck_challenges[skip_rounds..],
                    skip_rounds,
                    max_domain_size,
                    &backend,
                )
                .unwrap();

            let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
                .multilinear_extension::<F, _>(&backend)
                .unwrap();

            // The output keeps `max_domain_size - 2^skip_rounds` points; the test expects this
            // to equal `4 << skip_rounds` for `max_domain_size = domain_size(5, skip_rounds)`.
            let round_evals_len = 4usize << skip_rounds;
            assert!(output
                .round_evals
                .iter()
                .all(|round_evals| round_evals.len() == round_evals_len));

            let compositions = compositions
                .iter()
                .cloned()
                .map(CompositionScalarAdapter::new)
                .collect::<Vec<_>>();

            let mut query = [FBase::ZERO; 9];
            let mut evals = vec![
                PackedType::<U, F>::zero();
                1 << skip_rounds.saturating_sub(
                    log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
                )
            ];
            let subspace = BinarySubspace::<FDomain>::with_dim(skip_rounds).unwrap();
            let domain =
                EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();
            for round_evals_index in 0..round_evals_len {
                // NOTE(review): assumes the domain point with index `k` is the field element
                // whose integer representation is `k` — confirm against `BinarySubspace`.
                let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
                let mut composition_sums = vec![F::ZERO; compositions.len()];
                for subcube_index in 0..1 << (n_vars - skip_rounds) {
                    for (query, multilinear) in query.iter_mut().zip(&multilinears) {
                        multilinear
                            .subcube_evals(
                                skip_rounds,
                                subcube_index,
                                log_embedding_degree,
                                &mut evals,
                            )
                            .unwrap();
                        let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
                            PackedExtension::<FBase>::cast_bases(&evals),
                        )[..1 << skip_rounds];
                        let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
                        *query = extrapolated;
                    }

                    let eq_ind_factor = zerocheck_eq_ind
                        .evaluate_on_hypercube(subcube_index)
                        .unwrap();
                    for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
                        *sum += eq_ind_factor * composition.evaluate(&query).unwrap();
                    }
                }

                let univariate_skip_composition_sums = output
                    .round_evals
                    .iter()
                    .map(|round_evals| round_evals[round_evals_index])
                    .collect::<Vec<_>>();
                assert_eq!(univariate_skip_composition_sums, composition_sums);
            }
        }
    }
}