// binius_core/protocols/sumcheck/prove/univariate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{collections::HashMap, iter::repeat_n};
4
5use binius_field::{
6	recast_packed_mut, util::inner_product_unchecked, BinaryField, ExtensionField, Field,
7	PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
8};
9use binius_hal::{ComputationBackend, ComputationBackendExt};
10use binius_math::{
11	CompositionPoly, Error as MathError, EvaluationDomainFactory, EvaluationOrder,
12	IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly,
13};
14use binius_maybe_rayon::prelude::*;
15use binius_ntt::{AdditiveNTT, OddInterpolate, SingleThreadedNTT};
16use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
17use bytemuck::zeroed_vec;
18use itertools::izip;
19use stackalloc::stackalloc_with_iter;
20use tracing::instrument;
21use transpose::transpose;
22
23use crate::{
24	composition::{BivariateProduct, IndexComposition},
25	protocols::sumcheck::{
26		common::{
27			equal_n_vars_check, immediate_switchover_heuristic, small_field_embedding_degree_check,
28		},
29		prove::{common::fold_partial_eq_ind, RegularSumcheckProver},
30		univariate::{
31			lagrange_evals_multilinear_extension, univariatizing_reduction_composite_sum_claims,
32		},
33		univariate_zerocheck::{domain_size, extrapolated_scalars_count},
34		Error, VerificationError,
35	},
36};
37
38/// Helper method to reduce the witness to skipped variables via a partial high projection.
39#[instrument(skip_all, level = "debug")]
40pub fn reduce_to_skipped_projection<F, P, M, Backend>(
41	multilinears: Vec<M>,
42	sumcheck_challenges: &[F],
43	backend: &'_ Backend,
44) -> Result<Vec<MLEDirectAdapter<P>>, Error>
45where
46	F: Field,
47	P: PackedFieldIndexable<Scalar = F>,
48	M: MultilinearPoly<P> + Send + Sync,
49	Backend: ComputationBackend,
50{
51	let n_vars = equal_n_vars_check(&multilinears)?;
52
53	if sumcheck_challenges.len() > n_vars {
54		bail!(Error::IncorrectNumberOfChallenges);
55	}
56
57	let query = backend.multilinear_query(sumcheck_challenges)?;
58
59	let reduced_multilinears = multilinears
60		.par_iter()
61		.map(|multilinear| {
62			backend
63				.evaluate_partial_high(multilinear, query.to_ref())
64				.expect("0 <= sumcheck_challenges.len() < n_vars")
65				.into()
66		})
67		.collect();
68
69	Ok(reduced_multilinears)
70}
71
/// Sumcheck prover type returned by [`univariatizing_reduction_prover`].
///
/// It is a [`RegularSumcheckProver`] over [`MLEDirectAdapter`] witness multilinears. Its
/// compositions are [`BivariateProduct`]s indexed into the witness columns through an
/// [`IndexComposition`] of arity 2.
pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
	'a,
	FDomain,
	P,
	IndexComposition<BivariateProduct, 2>,
	MLEDirectAdapter<P>,
	Backend,
>;
80
81/// Create the sumcheck prover for the univariatizing reduction of multilinears
82/// (see [verifier side](crate::protocols::sumcheck::univariate::univariatizing_reduction_claim))
83///
84/// This method takes multilinears projected to first `skip_rounds` variables, constructs a multilinear
85/// extension of Lagrange evaluations at `univariate_challenge`, and creates a regular sumcheck prover,
86/// placing Lagrange evaluation in the last witness column.
87///
88/// Note that `univariatized_multilinear_evals` come from a previous sumcheck with a univariate first round.
89pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>(
90	mut reduced_multilinears: Vec<MLEDirectAdapter<P>>,
91	univariatized_multilinear_evals: &[F],
92	univariate_challenge: F,
93	evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
94	backend: &'a Backend,
95) -> Result<Prover<'a, FDomain, P, Backend>, Error>
96where
97	F: TowerField,
98	FDomain: TowerField,
99	P: PackedFieldIndexable<Scalar = F>
100		+ PackedExtension<F, PackedSubfield = P>
101		+ PackedExtension<FDomain>,
102	Backend: ComputationBackend,
103{
104	let skip_rounds = equal_n_vars_check(&reduced_multilinears)?;
105
106	if univariatized_multilinear_evals.len() != reduced_multilinears.len() {
107		bail!(VerificationError::NumberOfFinalEvaluations);
108	}
109
110	let evaluation_domain = EvaluationDomainFactory::<FDomain>::create(
111		&IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default(),
112		1 << skip_rounds,
113	)?;
114
115	reduced_multilinears.push(
116		lagrange_evals_multilinear_extension(&evaluation_domain, univariate_challenge)?.into(),
117	);
118
119	let composite_sum_claims =
120		univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
121
122	let prover = RegularSumcheckProver::new(
123		EvaluationOrder::LowToHigh,
124		reduced_multilinears,
125		composite_sum_claims,
126		evaluation_domain_factory,
127		immediate_switchover_heuristic,
128		backend,
129	)?;
130
131	Ok(prover)
132}
133
/// Per-job state for the parallel fold in `zerocheck_univariate_evals`.
///
/// One instance is created by the `try_fold` identity closure and reused across the subcubes that
/// job processes. It holds scratch buffers plus one round evals accumulator per composition.
#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
	/// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals). Scratch space.
	evals: Vec<P>,
	/// `evals` cast to the base field and transposed into 2^skip_rounds * 2^log_batch row-major form. Scratch space.
	interleaved_evals: Vec<PackedSubfield<P, FBase>>,
	/// `interleaved_evals` extrapolated beyond the first 2^skip_rounds domain points, one vector per multilinear.
	extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
	/// Evals of a single composition over the extrapolated multilinears. Scratch space.
	composition_evals: Vec<PackedSubfield<P, FBase>>,
	/// Round evals accumulators, one vector per composition (sized from each composition's degree in `new`).
	round_evals: Vec<Vec<P::Scalar>>,
}
147
148impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
149	fn new(
150		n_multilinears: usize,
151		skip_rounds: usize,
152		log_batch: usize,
153		log_embedding_degree: usize,
154		composition_degrees: impl Iterator<Item = usize> + Clone,
155	) -> Self {
156		let subcube_vars = skip_rounds + log_batch;
157		let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
158		let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
159			composition_max_degree,
160			skip_rounds,
161			log_batch,
162		);
163
164		let evals =
165			zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));
166		let interleaved_evals =
167			zeroed_vec(1 << subcube_vars.saturating_sub(<PackedSubfield<P, FBase>>::LOG_WIDTH));
168
169		let extrapolated_evals = (0..n_multilinears)
170			.map(|_| zeroed_vec(extrapolated_packed_pbase_len))
171			.collect();
172
173		let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);
174
175		let round_evals = composition_degrees
176			.map(|composition_degree| {
177				zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
178			})
179			.collect();
180
181		Self {
182			evals,
183			interleaved_evals,
184			extrapolated_evals,
185			composition_evals,
186			round_evals,
187		}
188	}
189}
190
/// Output of [`zerocheck_univariate_evals`].
///
/// It holds the univariate skip round evaluations and the state needed to fold the univariate
/// round once its challenge is known (see [`ZerocheckUnivariateEvalsOutput::fold`]).
#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	/// Round polynomial evaluations, one vector per composition. Each vector covers the domain
	/// points that follow the first `2^skip_rounds`; those first points are omitted.
	pub round_evals: Vec<Vec<F>>,
	/// Number of variables handled by the univariate round.
	skip_rounds: usize,
	/// Number of variables left after the univariate round (`n_vars - skip_rounds`).
	remaining_rounds: usize,
	/// Size of the univariate domain that the round evals were extrapolated to.
	max_domain_size: usize,
	/// Tensor expansion of the zerocheck challenges over the remaining variables.
	partial_eq_ind_evals: Backend::Vec<P>,
}
204
205pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
206where
207	F: Field,
208	P: PackedField<Scalar = F>,
209	Backend: ComputationBackend,
210{
211	pub skip_rounds: usize,
212	pub subcube_lagrange_coeffs: Vec<F>,
213	pub claimed_prime_sums: Vec<F>,
214	pub partial_eq_ind_evals: Backend::Vec<P>,
215}
216
217impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
218where
219	F: Field,
220	P: PackedFieldIndexable<Scalar = F>,
221	Backend: ComputationBackend,
222{
223	// Univariate round can be folded once the challenge has been sampled.
224	#[instrument(
225		skip_all,
226		name = "ZerocheckUnivariateEvalsOutput::fold",
227		level = "debug"
228	)]
229	pub fn fold<FDomain>(
230		self,
231		challenge: F,
232	) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
233	where
234		FDomain: TowerField,
235		F: ExtensionField<FDomain>,
236	{
237		let Self {
238			round_evals,
239			skip_rounds,
240			remaining_rounds,
241			max_domain_size,
242			mut partial_eq_ind_evals,
243		} = self;
244
245		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default();
246		let max_domain =
247			EvaluationDomainFactory::<FDomain>::create(&domain_factory, max_domain_size)?;
248
249		// Lagrange extrapolation over skipped subcube
250		let subcube_evaluation_domain =
251			EvaluationDomainFactory::<FDomain>::create(&domain_factory, 1 << skip_rounds)?;
252
253		let subcube_lagrange_coeffs = subcube_evaluation_domain.lagrange_evals(challenge);
254
255		// Zerocheck tensor expansion for the reduced zerocheck should be one variable less
256		fold_partial_eq_ind::<P, Backend>(
257			EvaluationOrder::LowToHigh,
258			remaining_rounds,
259			&mut partial_eq_ind_evals,
260		);
261
262		// Lagrange extrapolation for the entire univariate domain
263		let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);
264
265		let claimed_prime_sums = round_evals
266			.into_iter()
267			.map(|evals| {
268				inner_product_unchecked::<F, F>(
269					evals,
270					round_evals_lagrange_coeffs[1 << skip_rounds..]
271						.iter()
272						.copied(),
273				)
274			})
275			.collect();
276
277		Ok(ZerocheckUnivariateFoldResult {
278			skip_rounds,
279			subcube_lagrange_coeffs,
280			claimed_prime_sums,
281			partial_eq_ind_evals,
282		})
283	}
284}
285
/// Compute univariate skip round evaluations for zerocheck.
///
/// When all witness multilinear hypercube evaluations can be embedded into a small field
/// `FBase` that is significantly smaller than `F`, we naturally want to refrain from
/// folding for `skip_rounds` (denoted below as $k$) to reap the benefits of faster small field
/// multiplications. Naive extensions to the sumcheck protocol that compute multivariate round
/// polynomials do not work, though. For a composition of degree $d$ one would need
/// $(d+1)^k-2^k$ evaluations (assuming [Gruen24] section 3.2 optimizations), which usually grows faster
/// than $2^k$ and thus will typically require more work than a large field sumcheck. We adopt a
/// univariatizing approach instead, where we define "oblong" multivariates:
/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum_{u \in \mathcal{B}_k} M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot L_u(\hat{u}_1)$$
/// with $\mathbb{M}: \hat{u}_1 \rightarrow (u_1, \ldots, u_k)$ being some map from the univariate domain to
/// the $\mathcal{B}_k$ hypercube and $L_u(\hat{u})$ being Lagrange polynomials.
///
/// The main idea of the univariatizing approach is that $\hat{M}$ are of degree $2^k-1$ in
/// $\hat{u}_1$ and multilinear in the other variables. Evaluating a composition of degree $d$ over
/// $\hat{M}$ therefore yields a total degree of $d(2^k-1)$ in the first round (again, assuming the
/// [Gruen24] section 3.2 trick to avoid multiplication by the equality indicator), which is comparable to
/// what a regular non-skipping zerocheck prover would do. The only issue is that we normally don't
/// have an oracle for $\hat{M}$, which necessitates an extra sumcheck reduction to multilinear claims
/// (see [univariatizing_reduction_claim](`super::super::univariate::univariatizing_reduction_claim`)).
///
/// One special trick of the univariate round is that the round polynomial is represented in Lagrange form:
///  1. An honest prover's round polynomial evaluates to zero on the $2^k$ domain points mapping to $\mathcal{B}_k$, which reduces proof size
///  2. Avoiding monomial conversion saves prover time by skipping the $O(N^3)$ inverse Vandermonde precomputation
///  3. Evaluation in the verifier can be made linear time when barycentric weights are precomputed
///
/// This implementation defines $\mathbb{M}$ to be the basis-induced mapping of the binary field `FDomain`.
/// The main reason is to be able to use the additive NTT from [LCH14] for extrapolation. The choice
/// of domain field impacts performance, so the smallest field with cardinality not less than
/// the degree of the round polynomial should generally be used.
///
/// # Arguments
///
/// * `multilinears` - witness multilinears, all with the same number of variables `n_vars`.
/// * `compositions` - compositions evaluated over `PackedSubfield<P, FBase>`.
/// * `zerocheck_challenges` - must have exactly `n_vars - skip_rounds` elements.
/// * `skip_rounds` - number of variables handled by the univariate round; at most `n_vars`.
/// * `max_domain_size` - size of the univariate domain the round evals are extrapolated to; must be
///   at least `domain_size(max composition degree, skip_rounds)`.
///
/// # Returns
///
/// A [`ZerocheckUnivariateEvalsOutput`] whose `round_evals` hold, per composition, evaluations on the
/// domain points `2^skip_rounds..max_domain_size`. The first `2^skip_rounds` points are dropped.
///
/// # Errors
///
/// * [`Error::TooManySkippedRounds`] if `skip_rounds > n_vars`;
/// * [`Error::IncorrectZerocheckChallengesLength`] on a challenge count mismatch;
/// * [`Error::LagrangeDomainTooSmall`] if `max_domain_size` is below the required domain size;
/// * [`MathError::DomainSizeTooLarge`] if the required NTT domain does not fit in `FDomain`;
/// * errors from the variable-count and embedding-degree checks, subcube sampling, composition
///   evaluation, the backend, and the NTTs.
///
/// [LCH14]: <https://arxiv.org/abs/1404.3458>
/// [Gruen24]: <https://eprint.iacr.org/2024/108>
#[instrument(skip_all, level = "debug")]
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
	multilinears: &[M],
	compositions: &[Composition],
	zerocheck_challenges: &[F],
	skip_rounds: usize,
	max_domain_size: usize,
	backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
	FDomain: TowerField,
	FBase: ExtensionField<FDomain>,
	F: TowerField,
	P: PackedFieldIndexable<Scalar = F>
		+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
		+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
	Composition: CompositionPoly<PackedSubfield<P, FBase>>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	let n_vars = equal_n_vars_check(multilinears)?;
	let n_multilinears = multilinears.len();

	if skip_rounds > n_vars {
		bail!(Error::TooManySkippedRounds);
	}

	let remaining_rounds = n_vars - skip_rounds;
	if zerocheck_challenges.len() != remaining_rounds {
		bail!(Error::IncorrectZerocheckChallengesLength);
	}

	small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

	let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
	let composition_degrees = compositions.iter().map(|composition| composition.degree());
	let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

	if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
		bail!(Error::LagrangeDomainTooSmall);
	}

	// Batching factors for strided NTTs.
	let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;
	let pdomain_log_width = <P as PackedExtension<FDomain>>::PackedSubfield::LOG_WIDTH;

	// The lower bound is due to a SingleThreadedNTT implementation quirk, which requires
	// at least two packed field elements to be able to use PackedField::interleave.
	let min_domain_bits = log2_ceil_usize(max_domain_size).max(pdomain_log_width + 1);
	if min_domain_bits > FDomain::N_BITS {
		bail!(MathError::DomainSizeTooLarge);
	}

	// Only a domain size NTT is needed.
	let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
		.expect("FDomain cardinality checked before")
		.precompute_twiddles();

	// Smaller subcubes are batched together to reduce interpolation/evaluation overhead.
	const MAX_SUBCUBE_VARS: usize = 12;
	let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);

	// Expand the multilinear query in all but the first `skip_rounds` variables,
	// where each tensor expansion element serves as a constant factor of the whole
	// univariatized subcube.
	// NB: expansion of the first `skip_rounds` variables is applied to the round evals sum
	let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?;
	let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals);

	// Evaluate each composition on a minimal packed prefix corresponding to the degree
	let pbase_prefix_lens = composition_degrees
		.clone()
		.map(|composition_degree| {
			extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
				composition_degree,
				skip_rounds,
				log_batch,
			)
		})
		.collect::<Vec<_>>();

	// Each parallel work item covers 2^subcube_vars hypercube points (2^log_batch
	// subcubes of skip_rounds variables); there are 2^log_subcube_count such items.
	let subcube_vars = log_batch + skip_rounds;
	let log_subcube_count = n_vars - subcube_vars;

	// NB: we avoid evaluation on the first 2^skip_rounds points because an honest
	// prover would always evaluate to zero there; we also factor out the first
	// skip_rounds terms of the equality indicator and apply them pointwise to
	// the final round evaluations, which equates to lowering the composition_degree
	// by one (this is an extension of the Gruen section 3.2 trick)
	let staggered_round_evals = (0..1 << log_subcube_count)
		.into_par_iter()
		.try_fold(
			// One set of scratch buffers / accumulators per rayon fold job.
			|| {
				ParFoldStates::<FBase, P>::new(
					n_multilinears,
					skip_rounds,
					log_batch,
					log_embedding_degree,
					composition_degrees.clone(),
				)
			},
			|mut par_fold_states, subcube_index| -> Result<_, Error> {
				let ParFoldStates {
					evals,
					interleaved_evals,
					extrapolated_evals,
					composition_evals,
					round_evals,
					..
				} = &mut par_fold_states;

				// Interpolate multilinear evals for each multilinear
				for (multilinear, extrapolated_evals) in
					izip!(multilinears, extrapolated_evals.iter_mut())
				{
					// Sample evals subcube from a multilinear poly
					multilinear.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals.as_mut_slice(),
					)?;

					// Use the PackedExtension bound to cast evals to base field.
					let evals_base =
						<P as PackedExtension<FBase>>::cast_bases_mut(evals.as_mut_slice());

					// The evals subcube can be seen as 2^log_batch subcubes of skip_rounds
					// variables, laid out sequentially in memory; this can be seen as a
					// row major 2^log_batch * 2^skip_rounds scalar matrix. We need to transpose
					// it to 2^skip_rounds * 2^log_batch shape in order for the strided NTT to work.
					let interleaved_evals_ref = if log_batch == 0 {
						// In case of no batching, pass the slice as is.
						evals_base
					} else {
						let evals_base_scalars =
							&<PackedSubfield<P, FBase>>::unpack_scalars(evals_base)
								[..1 << subcube_vars];
						let interleaved_evals_scalars =
							&mut <PackedSubfield<P, FBase>>::unpack_scalars_mut(
								interleaved_evals.as_mut_slice(),
							)[..1 << subcube_vars];

						transpose(
							evals_base_scalars,
							interleaved_evals_scalars,
							1 << skip_rounds,
							1 << log_batch,
						);

						interleaved_evals.as_mut_slice()
					};

					// Extrapolate evals using a conservative upper bound of the composition
					// degree. When evals are correctly strided, we can use additive NTT to
					// extrapolate them beyond the first 2^skip_rounds. We use the fact that an NTT
					// over the extension is just a strided NTT over the base field.
					let interleaved_evals_bases =
						recast_packed_mut::<P, FBase, FDomain>(interleaved_evals_ref);
					let extrapolated_evals_bases =
						recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

					// The batch is widened by the FBase/FDomain extension degree, since
					// each FBase scalar is now viewed as several FDomain scalars.
					ntt_extrapolate(
						&fdomain_ntt,
						skip_rounds,
						log_batch + log_extension_degree_base_domain,
						composition_max_degree,
						interleaved_evals_bases,
						extrapolated_evals_bases,
					)?
				}

				// Obtain 1 << log_batch partial equality indicator constant factors for each
				// of the subcubes of size 1 << skip_rounds.
				let partial_eq_ind_evals_scalars_subslice =
					&partial_eq_ind_evals_scalars[subcube_index << log_batch..][..1 << log_batch];

				// Evaluate the compositions and accumulate round results
				for (composition, round_evals, &pbase_prefix_len) in
					izip!(compositions, round_evals, &pbase_prefix_lens)
				{
					// Only the prefix needed for this composition's degree is evaluated.
					let extrapolated_evals_iter = extrapolated_evals
						.iter()
						.map(|evals| &evals[..pbase_prefix_len]);

					stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
						// Evaluate the small field composition
						composition
							.batch_evaluate(batch_query, &mut composition_evals[..pbase_prefix_len])
					})?;

					// Accumulate round evals and multiply by the constant part of the
					// zerocheck equality indicator
					let composition_evals_scalars = <PackedSubfield<P, FBase>>::unpack_scalars_mut(
						composition_evals.as_mut_slice(),
					);

					// Each chunk of 2^max(subcube_vars, log_embedding_degree + P::LOG_WIDTH)
					// scalars holds one extrapolated coset, i.e. 2^skip_rounds rows of
					// 2^log_batch interleaved values.
					// NOTE(review): `composition_evals` beyond `pbase_prefix_len` may hold stale
					// data from other compositions; correctness relies on `round_evals` (sized by
					// `extrapolated_scalars_count` for this degree) limiting izip! to the written
					// prefix. Confirm those two size formulas agree.
					for (round_evals_coset, composition_evals_scalars_coset) in izip!(
						round_evals.chunks_exact_mut(1 << skip_rounds),
						composition_evals_scalars.chunks_exact(
							1 << subcube_vars.max(log_embedding_degree + P::LOG_WIDTH)
						)
					) {
						for (round_eval, composition_evals) in izip!(
							round_evals_coset,
							composition_evals_scalars_coset.chunks_exact(1 << log_batch),
						) {
							// Inner product is with the high n_vars - skip_rounds projection
							// of the zerocheck equality indicator (one factor per subcube).
							*round_eval += inner_product_unchecked(
								partial_eq_ind_evals_scalars_subslice.iter().copied(),
								composition_evals.iter().copied(),
							);
						}
					}

					// REVIEW: only slow path is left, fast path is to be reintroduced in the followup PRs
					//         targeted on dropping PackedFieldIndexable
				}

				Ok(par_fold_states)
			},
		)
		.map(|states| -> Result<_, Error> { Ok(states?.round_evals) })
		// Sum the per-job accumulators element-wise, per composition.
		.try_reduce(
			|| {
				composition_degrees
					.clone()
					.map(|composition_degree| {
						zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
					})
					.collect()
			},
			|lhs, rhs| -> Result<_, Error> {
				let round_evals_sum = izip!(lhs, rhs)
					.map(|(mut lhs_vals, rhs_vals)| {
						for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
							*lhs_val += rhs_val;
						}
						lhs_vals
					})
					.collect();

				Ok(round_evals_sum)
			},
		)?;

	// So far evals of each composition are "staggered" in a sense that they are evaluated on the smallest
	// domain which guarantees uniqueness of the round polynomial. We extrapolate them to max_domain_size to
	// aid in Gruen section 3.2 optimization below and batch mixing.
	let round_evals = extrapolate_round_evals(staggered_round_evals, skip_rounds, max_domain_size)?;

	Ok(ZerocheckUnivariateEvalsOutput {
		round_evals,
		skip_rounds,
		remaining_rounds,
		max_domain_size,
		partial_eq_ind_evals,
	})
}
580
581// Extrapolate round evaluations to the full domain.
582// NB: this method relies on the fact that `round_evals` have specific lengths
583// (namely `d * 2^n`, where `n` is not less than the number of skipped rounds and thus d
584// is not larger than the composition degree), which enables additive-NTT based subquadratic
585// techniques.
586#[instrument(skip_all, level = "debug")]
587fn extrapolate_round_evals<F: TowerField>(
588	mut round_evals: Vec<Vec<F>>,
589	skip_rounds: usize,
590	max_domain_size: usize,
591) -> Result<Vec<Vec<F>>, Error> {
592	// Instantiate a large enough NTT over F to be able to forward transform to full domain size.
593	// REVIEW: should be possible to use an existing FDomain NTT with striding.
594	let ntt = SingleThreadedNTT::with_canonical_field(log2_ceil_usize(max_domain_size))?;
595
596	// Cache OddInterpolate instances, which, albeit small in practice, take cubic time to create.
597	let mut odd_interpolates = HashMap::new();
598
599	for round_evals in &mut round_evals {
600		// Re-add zero evaluations at the beginning.
601		round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));
602
603		let n = round_evals.len();
604
605		// Get OddInterpolate instance of required size.
606		let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
607			let ell = n.trailing_zeros() as usize;
608			assert!(ell >= skip_rounds);
609
610			OddInterpolate::new(n >> ell, ell, ntt.twiddles())
611				.expect("domain large enough by construction")
612		});
613
614		// Obtain novel polynomial basis representation of round evaluations.
615		odd_interpolate.inverse_transform(&ntt, round_evals)?;
616
617		// Use forward NTT to extrapolate novel representation to the max domain size.
618		let next_log_n = log2_ceil_usize(max_domain_size);
619		round_evals.resize(1 << next_log_n, F::ZERO);
620
621		ntt.forward_transform(round_evals, 0, 0, next_log_n)?;
622
623		// Sanity check: first 1 << skip_rounds evals are still zeros.
624		debug_assert!(round_evals[..1 << skip_rounds]
625			.iter()
626			.all(|&coeff| coeff == F::ZERO));
627
628		// Trim the result.
629		round_evals.resize(max_domain_size, F::ZERO);
630		round_evals.drain(..1 << skip_rounds);
631	}
632
633	Ok(round_evals)
634}
635
636fn ntt_extrapolate<NTT, P>(
637	ntt: &NTT,
638	skip_rounds: usize,
639	log_batch: usize,
640	composition_max_degree: usize,
641	interleaved_evals: &mut [P],
642	extrapolated_evals: &mut [P],
643) -> Result<(), Error>
644where
645	P: PackedFieldIndexable<Scalar: BinaryField>,
646	NTT: AdditiveNTT<P::Scalar>,
647{
648	let subcube_vars = skip_rounds + log_batch;
649	debug_assert_eq!(1 << subcube_vars.saturating_sub(P::LOG_WIDTH), interleaved_evals.len());
650	debug_assert_eq!(
651		extrapolated_evals_packed_len::<P>(composition_max_degree, skip_rounds, log_batch),
652		extrapolated_evals.len()
653	);
654	debug_assert!(
655		NTT::log_domain_size(ntt)
656			>= log2_ceil_usize(domain_size(composition_max_degree, skip_rounds))
657	);
658
659	// Inverse NTT: convert evals to novel basis representation
660	ntt.inverse_transform(interleaved_evals, 0, log_batch, skip_rounds)?;
661
662	// Forward NTT: evaluate novel basis representation at consecutive cosets
663	for (i, extrapolated_chunk) in extrapolated_evals
664		.chunks_exact_mut(interleaved_evals.len())
665		.enumerate()
666	{
667		extrapolated_chunk.copy_from_slice(interleaved_evals);
668		ntt.forward_transform(extrapolated_chunk, (i + 1) as u32, log_batch, skip_rounds)?;
669	}
670
671	Ok(())
672}
673
674const fn extrapolated_evals_packed_len<P: PackedField>(
675	composition_degree: usize,
676	skip_rounds: usize,
677	log_batch: usize,
678) -> usize {
679	composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
680}
681
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::UnderlierType,
		BinaryField128b, BinaryField16b, BinaryField1b, BinaryField8b, ExtensionField, Field,
		PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
	};
	use binius_hal::make_portable_backend;
	use binius_math::{
		CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly,
	};
	use binius_ntt::SingleThreadedNTT;
	use rand::{prelude::StdRng, SeedableRng};

	use crate::{
		composition::{IndexComposition, ProductComposition},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
			test_utils::generate_zero_product_multilinears,
		},
		transparent::eq_ind::EqIndPartialEval,
	};

	/// Checks `ntt_extrapolate` against direct Lagrange extrapolation over the subcube domain.
	///
	/// The check covers several values of `skip_rounds`, batch size and composition degree. For
	/// each batch column, only the first extrapolated coset (`2^skip_rounds` points) is compared.
	#[test]
	fn ntt_extrapolate_correctness() {
		type P = PackedBinaryField4x32b;
		type FDomain = BinaryField16b;
		// The 32-bit scalars of `P` are a degree-2 extension of the 16-bit `FDomain`.
		let log_extension_degree_p_domain = 1;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
		let domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();
		let max_domain = domain_factory.create(1 << 10).unwrap();

		for skip_rounds in 0..5usize {
			let domain = domain_factory.create(1 << skip_rounds).unwrap();
			for log_batch in 0..3usize {
				for composition_degree in 0..5usize {
					let subcube_vars = skip_rounds + log_batch;
					let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
					let interleaved_evals = (0..interleaved_len)
						.map(|_| P::random(&mut rng))
						.collect::<Vec<_>>();

					let extrapolated_scalars_cnt =
						composition_degree.saturating_sub(1) << skip_rounds;
					let extrapolated_ntts = composition_degree.saturating_sub(1);
					let extrapolated_len = extrapolated_ntts * interleaved_len;
					let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

					// `ntt_extrapolate` overwrites its input, so keep the original evals intact.
					let mut interleaved_evals_scratch = interleaved_evals.clone();

					let interleaved_evals_domain =
						P::cast_bases_mut(&mut interleaved_evals_scratch);
					let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

					super::ntt_extrapolate(
						&ntt,
						skip_rounds,
						log_batch + log_extension_degree_p_domain,
						composition_degree,
						interleaved_evals_domain,
						extrapolated_evals_domain,
					)
					.unwrap();

					let interleaved_scalars =
						&P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
					let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
						[..extrapolated_scalars_cnt << log_batch];

					for batch_idx in 0..1 << log_batch {
						// De-interleave one batch column of 2^skip_rounds values.
						let values = (0..1 << skip_rounds)
							.map(|i| interleaved_scalars[(i << log_batch) + batch_idx])
							.collect::<Vec<_>>();

						for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
							[..extrapolated_scalars_cnt]
							.iter()
							.take(1 << skip_rounds)
							.enumerate()
						{
							let extrapolated = domain.extrapolate(&values, point.into()).unwrap();
							let expected = extrapolated_scalars[(i << log_batch) + batch_idx];
							assert_eq!(extrapolated, expected);
						}
					}
				}
			}
		}
	}

	/// Runs the invariants helper with a 128-bit underlier.
	#[test]
	fn zerocheck_univariate_evals_invariants_basic() {
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	#[test]
	fn zerocheck_univariate_evals_with_nontrivial_packing() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field will have a
		// non-trivial packing width of 4.
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	/// Compares `zerocheck_univariate_evals` against a naive computation for `skip_rounds` in `0..=5`.
	///
	/// The witness is 9 multilinears over 7 variables, in groups of 2, 3 and 4, built by
	/// `generate_zero_product_multilinears`. There is one product composition per group. The
	/// naive computation extrapolates every subcube to each domain point directly, then sums the
	/// compositions weighted by the equality indicator over the remaining variables.
	fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<BinaryField1b>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField + From<u8>,
		PackedType<U, FBase>: PackedFieldIndexable,
		PackedType<U, FDomain>: PackedFieldIndexable,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 7;
		let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

		let mut multilinears = generate_zero_product_multilinears::<
			PackedType<U, BinaryField1b>,
			PackedType<U, F>,
		>(&mut rng, n_vars, 2);
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

		let compositions = [
			Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
		];

		let backend = make_portable_backend();
		let zerocheck_challenges = (0..n_vars)
			.map(|_| <F as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		for skip_rounds in 0usize..=5 {
			let max_domain_size = domain_size(5, skip_rounds);
			let output =
				zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
					&multilinears,
					&compositions,
					&zerocheck_challenges[skip_rounds..],
					skip_rounds,
					max_domain_size,
					&backend,
				)
				.unwrap();

			let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
				.multilinear_extension::<F, _>(&backend)
				.unwrap();

			// naive computation of the univariate skip output
			// Expected length of every per-composition round evals vector for this max_domain_size.
			let round_evals_len = 4usize << skip_rounds;
			assert!(output
				.round_evals
				.iter()
				.all(|round_evals| round_evals.len() == round_evals_len));

			let compositions = compositions
				.iter()
				.cloned()
				.map(CompositionScalarAdapter::new)
				.collect::<Vec<_>>();

			let mut query = [FBase::ZERO; 9];
			let mut evals = vec![
				PackedType::<U, F>::zero();
				1 << skip_rounds.saturating_sub(
					log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
				)
			];
			let domain = DefaultEvaluationDomainFactory::<FDomain>::default()
				.create(1 << skip_rounds)
				.unwrap();
			for round_evals_index in 0..round_evals_len {
				// Round evals start after the first 2^skip_rounds domain points.
				// NOTE(review): assumes `FDomain::from(u8)` matches the domain's point ordering.
				let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
				let mut composition_sums = vec![F::ZERO; compositions.len()];
				for subcube_index in 0..1 << (n_vars - skip_rounds) {
					for (query, multilinear) in query.iter_mut().zip(&multilinears) {
						multilinear
							.subcube_evals(
								skip_rounds,
								subcube_index,
								log_embedding_degree,
								&mut evals,
							)
							.unwrap();
						let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
							PackedExtension::<FBase>::cast_bases(&evals),
						)[..1 << skip_rounds];
						let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
						*query = extrapolated;
					}

					let eq_ind_factor = zerocheck_eq_ind
						.evaluate_on_hypercube(subcube_index)
						.unwrap();
					for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
						*sum += eq_ind_factor * composition.evaluate(&query).unwrap();
					}
				}

				let univariate_skip_composition_sums = output
					.round_evals
					.iter()
					.map(|round_evals| round_evals[round_evals_index])
					.collect::<Vec<_>>();
				assert_eq!(univariate_skip_composition_sums, composition_sums);
			}
		}
	}
}