binius_core/protocols/sumcheck/prove/univariate.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{collections::HashMap, iter::repeat_n};

use binius_field::{
	recast_packed_mut, util::inner_product_unchecked, BinaryField, ExtensionField, Field,
	PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield, TowerField,
};
use binius_hal::{ComputationBackend, ComputationBackendExt};
use binius_math::{
	BinarySubspace, CompositionPoly, Error as MathError, EvaluationDomain, EvaluationOrder,
	IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly, RowsBatchRef,
};
use binius_maybe_rayon::prelude::*;
use binius_ntt::{AdditiveNTT, OddInterpolate, SingleThreadedNTT};
use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
use bytemuck::zeroed_vec;
use itertools::izip;
use stackalloc::stackalloc_with_iter;
use tracing::instrument;
use transpose::transpose;

use crate::{
	composition::{BivariateProduct, IndexComposition},
	protocols::sumcheck::{
		common::{
			equal_n_vars_check, immediate_switchover_heuristic, small_field_embedding_degree_check,
		},
		prove::{common::fold_partial_eq_ind, RegularSumcheckProver},
		univariate::{
			lagrange_evals_multilinear_extension, univariatizing_reduction_composite_sum_claims,
		},
		univariate_zerocheck::{domain_size, extrapolated_scalars_count},
		Error, VerificationError,
	},
};
/// Helper method to reduce the witness to skipped variables via a partial high projection.
#[instrument(skip_all, level = "debug")]
pub fn reduce_to_skipped_projection<F, P, M, Backend>(
	multilinears: Vec<M>,
	sumcheck_challenges: &[F],
	backend: &'_ Backend,
) -> Result<Vec<MLEDirectAdapter<P>>, Error>
where
	F: Field,
	P: PackedFieldIndexable<Scalar = F>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	let n_vars = equal_n_vars_check(&multilinears)?;

	if sumcheck_challenges.len() > n_vars {
		bail!(Error::IncorrectNumberOfChallenges);
	}

	let query = backend.multilinear_query(sumcheck_challenges)?;

	let reduced_multilinears = multilinears
		.par_iter()
		.map(|multilinear| {
			backend
				.evaluate_partial_high(multilinear, query.to_ref())
				.expect("0 <= sumcheck_challenges.len() < n_vars")
				.into()
		})
		.collect();

	Ok(reduced_multilinears)
}

pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
	'a,
	FDomain,
	P,
	IndexComposition<BivariateProduct, 2>,
	MLEDirectAdapter<P>,
	Backend,
>;

/// Create the sumcheck prover for the univariatizing reduction of multilinears
/// (see [verifier side](crate::protocols::sumcheck::univariate::univariatizing_reduction_claim)).
///
/// This method takes multilinears projected to the first `skip_rounds` variables, constructs a
/// multilinear extension of Lagrange evaluations at `univariate_challenge`, and creates a regular
/// sumcheck prover, placing the Lagrange evaluation in the last witness column.
///
/// Note that `univariatized_multilinear_evals` come from a previous sumcheck with a univariate
/// first round.
pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>(
	mut reduced_multilinears: Vec<MLEDirectAdapter<P>>,
	univariatized_multilinear_evals: &[F],
	univariate_challenge: F,
	backend: &'a Backend,
) -> Result<Prover<'a, FDomain, P, Backend>, Error>
where
	F: TowerField,
	FDomain: TowerField,
	P: PackedFieldIndexable<Scalar = F>
		+ PackedExtension<F, PackedSubfield = P>
		+ PackedExtension<FDomain>,
	Backend: ComputationBackend,
{
	let skip_rounds = equal_n_vars_check(&reduced_multilinears)?;

	if univariatized_multilinear_evals.len() != reduced_multilinears.len() {
		bail!(VerificationError::NumberOfFinalEvaluations);
	}

	let subspace =
		BinarySubspace::<FDomain::Canonical>::with_dim(skip_rounds)?.isomorphic::<FDomain>();
	let ntt_domain = EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;

	reduced_multilinears
		.push(lagrange_evals_multilinear_extension(&ntt_domain, univariate_challenge)?.into());

	let composite_sum_claims =
		univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);

	let prover = RegularSumcheckProver::new(
		EvaluationOrder::LowToHigh,
		reduced_multilinears,
		composite_sum_claims,
		IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default(),
		immediate_switchover_heuristic,
		backend,
	)?;

	Ok(prover)
}
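
// A hedged usage sketch (not part of the original source; `sumcheck_challenges`,
// `univariatized_multilinear_evals` and `univariate_challenge` are assumed to come from the
// preceding sumcheck with a univariate first round):
//
//     let reduced = reduce_to_skipped_projection(multilinears, &sumcheck_challenges, &backend)?;
//     let prover = univariatizing_reduction_prover::<_, FDomain, _, _>(
//         reduced,
//         &univariatized_multilinear_evals,
//         univariate_challenge,
//         &backend,
//     )?;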

#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
	/// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals). Scratch space.
	evals: Vec<P>,
	/// `evals` cast to base field and transposed to 2^skip_rounds * 2^log_batch row-major form. Scratch space.
	interleaved_evals: Vec<PackedSubfield<P, FBase>>,
	/// `interleaved_evals` extrapolated beyond the first 2^skip_rounds domain points, per multilinear.
	extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
	/// Evals of a single composition over extrapolated multilinears. Scratch space.
	composition_evals: Vec<PackedSubfield<P, FBase>>,
	/// Round evals accumulators, per composition.
	round_evals: Vec<Vec<P::Scalar>>,
}

impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
	fn new(
		n_multilinears: usize,
		skip_rounds: usize,
		log_batch: usize,
		log_embedding_degree: usize,
		composition_degrees: impl Iterator<Item = usize> + Clone,
	) -> Self {
		let subcube_vars = skip_rounds + log_batch;
		let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
		let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
			composition_max_degree,
			skip_rounds,
			log_batch,
		);

		let evals =
			zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));
		let interleaved_evals =
			zeroed_vec(1 << subcube_vars.saturating_sub(<PackedSubfield<P, FBase>>::LOG_WIDTH));

		let extrapolated_evals = (0..n_multilinears)
			.map(|_| zeroed_vec(extrapolated_packed_pbase_len))
			.collect();

		let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);

		let round_evals = composition_degrees
			.map(|composition_degree| {
				zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
			})
			.collect();

		Self {
			evals,
			interleaved_evals,
			extrapolated_evals,
			composition_evals,
			round_evals,
		}
	}
}
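
// Scratch sizing, as a worked illustration (hypothetical parameters): with P::LOG_WIDTH = 2 and
// log_embedding_degree = 2 (so the base packed width is 2^4 = 16), skip_rounds = 4 and
// log_batch = 3 give subcube_vars = 7; `evals` then holds 2^(7 - 4) = 8 packed P elements and
// `interleaved_evals` holds 2^(7 - 4) = 8 packed base-field elements, each spanning 2^7 scalars.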
#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	pub round_evals: Vec<Vec<F>>,
	skip_rounds: usize,
	remaining_rounds: usize,
	max_domain_size: usize,
	partial_eq_ind_evals: Backend::Vec<P>,
}

pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	pub skip_rounds: usize,
	pub subcube_lagrange_coeffs: Vec<F>,
	pub claimed_sums: Vec<F>,
	pub partial_eq_ind_evals: Backend::Vec<P>,
}

impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedFieldIndexable<Scalar = F>,
	Backend: ComputationBackend,
{
	// The univariate round can be folded once its challenge has been sampled.
	#[instrument(
		skip_all,
		name = "ZerocheckUnivariateEvalsOutput::fold",
		level = "debug"
	)]
	pub fn fold<FDomain>(
		self,
		challenge: F,
	) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
	where
		FDomain: TowerField,
		F: ExtensionField<FDomain>,
	{
		let Self {
			round_evals,
			skip_rounds,
			remaining_rounds,
			max_domain_size,
			mut partial_eq_ind_evals,
		} = self;

		// REVIEW: consider using novel basis for the univariate round representation
		//         (instead of Lagrange)
		let max_dim = log2_ceil_usize(max_domain_size);
		let subspace =
			BinarySubspace::<FDomain::Canonical>::with_dim(max_dim)?.isomorphic::<FDomain>();
		let max_domain = EvaluationDomain::from_points(
			subspace.iter().take(max_domain_size).collect::<Vec<_>>(),
			false,
		)?;

		// Lagrange extrapolation over the skipped subcube
		let subcube_lagrange_coeffs = EvaluationDomain::from_points(
			subspace.reduce_dim(skip_rounds)?.iter().collect::<Vec<_>>(),
			false,
		)?
		.lagrange_evals(challenge);

		// The zerocheck tensor expansion for the reduced zerocheck should be one variable shorter
		fold_partial_eq_ind::<P, Backend>(
			EvaluationOrder::LowToHigh,
			remaining_rounds,
			&mut partial_eq_ind_evals,
		);

		// Lagrange extrapolation for the entire univariate domain
		let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);

		let claimed_sums = round_evals
			.into_iter()
			.map(|evals| {
				inner_product_unchecked::<F, F>(
					evals,
					round_evals_lagrange_coeffs[1 << skip_rounds..]
						.iter()
						.copied(),
				)
			})
			.collect();

		Ok(ZerocheckUnivariateFoldResult {
			skip_rounds,
			subcube_lagrange_coeffs,
			claimed_sums,
			partial_eq_ind_evals,
		})
	}
}
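
// A hedged usage sketch (names as in this module; the challenge is assumed to be sampled by the
// verifier after receiving the Lagrange-form round evaluations):
//
//     let evals_output = zerocheck_univariate_evals::<F, FDomain, FBase, P, _, _, _>(
//         &multilinears, &compositions, &zerocheck_challenges, skip_rounds, max_domain_size,
//         &backend)?;
//     // ...commit `evals_output.round_evals` to the transcript, sample `univariate_challenge`...
//     let fold_result = evals_output.fold::<FDomain>(univariate_challenge)?;
//     // `fold_result.claimed_sums` seed the remaining multilinear rounds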

/// Compute univariate skip round evaluations for zerocheck.
///
/// When all witness multilinear hypercube evaluations can be embedded into a small field
/// `PBase::Scalar` that is significantly smaller than `F`, we naturally want to refrain from
/// folding for `skip_rounds` (denoted below as $k$) to reap the benefits of faster small field
/// multiplications. Naive extensions to the sumcheck protocol which compute multivariate round
/// polynomials do not work though, given that for a composition of degree $d$ one would need
/// $(d+1)^k-2^k$ evaluations (assuming [Gruen24] section 3.2 optimizations), which usually grows
/// faster than $2^k$ and thus will typically require more work than large field sumcheck. We
/// adopt a univariatizing approach instead, where we define "oblong" multivariates:
/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum_{u \in \mathcal{B}_k} M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot L_u(\hat{u}_1)$$
/// with $\mathbb{M}: \hat{u}_1 \rightarrow (u_1, \ldots, u_k)$ being some map from the univariate
/// domain to the $\mathcal{B}_k$ hypercube and $L_u(\hat{u})$ being Lagrange polynomials.
///
/// The main idea of the univariatizing approach is that $\hat{M}$ are of degree $2^k-1$ in
/// $\hat{u}_1$ and multilinear in other variables, thus evaluating a composition of degree $d$
/// over $\hat{M}$ yields a total degree of $d(2^k-1)$ in the first round (again, assuming
/// [Gruen24] section 3.2 trick to avoid multiplication by the equality indicator), which is
/// comparable to what a regular non-skipping zerocheck prover would do. The only issue is that we
/// normally don't have an oracle for $\hat{M}$, which necessitates an extra sumcheck reduction to
/// multilinear claims
/// (see [univariatizing_reduction_claim](`super::super::univariate::univariatizing_reduction_claim`)).
///
/// One special trick of the univariate round is that the round polynomial is represented in
/// Lagrange form:
///  1. The honest prover evaluates to zero on the $2^k$ domain points mapping to
///     $\mathcal{B}_k$, reducing proof size
///  2. Avoiding monomial conversion saves prover time by skipping the $O(N^3)$ inverse
///     Vandermonde precomputation
///  3. Evaluation in the verifier can be made linear time when barycentric weights are
///     precomputed
///
/// This implementation defines $\mathbb{M}$ to be the basis-induced mapping of the binary field
/// `FDomain`; the main reason for that is to be able to use the additive NTT from [LCH14] for
/// extrapolation. The choice of domain field impacts performance, thus generally the smallest
/// field with cardinality not less than the degree of the round polynomial should be used.
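///
/// As a worked instance of these counts (pure arithmetic on the formulas above, with
/// hypothetical parameters $d = 4$ and $k = 3$): the naive multivariate approach would need
/// $(d+1)^k - 2^k = 125 - 8 = 117$ evaluations, whereas the univariatized first round polynomial
/// has degree $d(2^k - 1) = 28$, and the honest prover computes only $(d-1) \cdot 2^k = 24$
/// evaluations (on a domain of size $d \cdot 2^k = 32$) beyond the $2^k = 8$ known zeros.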
///
/// [LCH14]: <https://arxiv.org/abs/1404.3458>
/// [Gruen24]: <https://eprint.iacr.org/2024/108>
#[instrument(skip_all, level = "debug")]
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
	multilinears: &[M],
	compositions: &[Composition],
	zerocheck_challenges: &[F],
	skip_rounds: usize,
	max_domain_size: usize,
	backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
	FDomain: TowerField,
	FBase: ExtensionField<FDomain>,
	F: TowerField,
	P: PackedFieldIndexable<Scalar = F>
		+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
		+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
	Composition: CompositionPoly<PackedSubfield<P, FBase>>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	let n_vars = equal_n_vars_check(multilinears)?;
	let n_multilinears = multilinears.len();

	if skip_rounds > n_vars {
		bail!(Error::TooManySkippedRounds);
	}

	let remaining_rounds = n_vars - skip_rounds;
	if zerocheck_challenges.len() != remaining_rounds {
		bail!(Error::IncorrectZerocheckChallengesLength);
	}

	small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

	let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
	let composition_degrees = compositions.iter().map(|composition| composition.degree());
	let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

	if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
		bail!(Error::LagrangeDomainTooSmall);
	}

	// Batching factors for strided NTTs.
	let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;
	let pdomain_log_width = <P as PackedExtension<FDomain>>::PackedSubfield::LOG_WIDTH;

	// The lower bound is due to a SingleThreadedNTT implementation quirk, which requires
	// at least two packed field elements to be able to use PackedField::interleave.
	let min_domain_bits = log2_ceil_usize(max_domain_size).max(pdomain_log_width + 1);
	if min_domain_bits > FDomain::N_BITS {
		bail!(MathError::DomainSizeTooLarge);
	}

	// Only an NTT of the required domain size is needed.
	let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
		.expect("FDomain cardinality checked before")
		.precompute_twiddles();

	// Smaller subcubes are batched together to reduce interpolation/evaluation overhead.
	const MAX_SUBCUBE_VARS: usize = 12;
	let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);
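	// E.g. (illustrative numbers only): with n_vars = 20 and skip_rounds = 5, each parallel
	// task handles a subcube of 2^12 scalars, batching 2^7 skipped-variable subcubes of
	// 2^5 scalars each, i.e. log_batch = 7.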

	// Expand the multilinear query in all but the first `skip_rounds` variables,
	// where each tensor expansion element serves as a constant factor of the whole
	// univariatized subcube.
	// NB: expansion of the first `skip_rounds` variables is applied to the round evals sum
	let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?;
	let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals);

	// Evaluate each composition on a minimal packed prefix corresponding to its degree
	let pbase_prefix_lens = composition_degrees
		.clone()
		.map(|composition_degree| {
			extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
				composition_degree,
				skip_rounds,
				log_batch,
			)
		})
		.collect::<Vec<_>>();

	let subcube_vars = log_batch + skip_rounds;
	let log_subcube_count = n_vars - subcube_vars;

	// NB: we avoid evaluation on the first 2^skip_rounds points because the honest
	// prover would always evaluate to zero there; we also factor out the first
	// skip_rounds terms of the equality indicator and apply them pointwise to
	// the final round evaluations, which equates to lowering the composition_degree
	// by one (this is an extension of the Gruen section 3.2 trick)
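	// (Worked count, as an illustration of the above: for skip_rounds = 2, a degree-d
	// composition is evaluated on (d - 1) * 2^2 domain points per subcube, the first
	// 2^2 = 4 points being skipped as known zeros of the honest prover.)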
	let staggered_round_evals = (0..1 << log_subcube_count)
		.into_par_iter()
		.try_fold(
			|| {
				ParFoldStates::<FBase, P>::new(
					n_multilinears,
					skip_rounds,
					log_batch,
					log_embedding_degree,
					composition_degrees.clone(),
				)
			},
			|mut par_fold_states, subcube_index| -> Result<_, Error> {
				let ParFoldStates {
					evals,
					interleaved_evals,
					extrapolated_evals,
					composition_evals,
					round_evals,
					..
				} = &mut par_fold_states;

				// Interpolate multilinear evals for each multilinear
				for (multilinear, extrapolated_evals) in
					izip!(multilinears, extrapolated_evals.iter_mut())
				{
					// Sample evals subcube from a multilinear poly
					multilinear.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals.as_mut_slice(),
					)?;

					// Use the PackedExtension bound to cast evals to base field.
					let evals_base =
						<P as PackedExtension<FBase>>::cast_bases_mut(evals.as_mut_slice());

					// The evals subcube can be seen as 2^log_batch subcubes of skip_rounds
					// variables laid out sequentially in memory, i.e. as a row-major
					// 2^log_batch * 2^skip_rounds scalar matrix. We need to transpose it to
					// 2^skip_rounds * 2^log_batch shape in order for the strided NTT to work.
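					// E.g. (illustrative, skip_rounds = 1, log_batch = 1): scalars
					// [a0, a1, b0, b1] of subcubes a and b become [a0, b0, a1, b1], so the
					// strided NTT processes same-index scalars of all batched subcubes together.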
					let interleaved_evals_ref = if log_batch == 0 {
						// In case of no batching, pass the slice as is.
						evals_base
					} else {
						let evals_base_scalars =
							&<PackedSubfield<P, FBase>>::unpack_scalars(evals_base)
								[..1 << subcube_vars];
						let interleaved_evals_scalars =
							&mut <PackedSubfield<P, FBase>>::unpack_scalars_mut(
								interleaved_evals.as_mut_slice(),
							)[..1 << subcube_vars];

						transpose(
							evals_base_scalars,
							interleaved_evals_scalars,
							1 << skip_rounds,
							1 << log_batch,
						);

						interleaved_evals.as_mut_slice()
					};

					// Extrapolate evals using a conservative upper bound of the composition
					// degree. When evals are correctly strided, we can use the additive NTT to
					// extrapolate them beyond the first 2^skip_rounds. We use the fact that an
					// NTT over the extension is just a strided NTT over the base field.
					let interleaved_evals_bases =
						recast_packed_mut::<P, FBase, FDomain>(interleaved_evals_ref);
					let extrapolated_evals_bases =
						recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

					ntt_extrapolate(
						&fdomain_ntt,
						skip_rounds,
						log_batch + log_extension_degree_base_domain,
						composition_max_degree,
						interleaved_evals_bases,
						extrapolated_evals_bases,
					)?
				}

				// Obtain 1 << log_batch partial equality indicator constant factors for each
				// of the subcubes of size 1 << skip_rounds.
				let partial_eq_ind_evals_scalars_subslice =
					&partial_eq_ind_evals_scalars[subcube_index << log_batch..][..1 << log_batch];

				// Evaluate the compositions and accumulate round results
				for (composition, round_evals, &pbase_prefix_len) in
					izip!(compositions, round_evals, &pbase_prefix_lens)
				{
					let extrapolated_evals_iter = extrapolated_evals
						.iter()
						.map(|evals| &evals[..pbase_prefix_len]);

					stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
						let batch_query = RowsBatchRef::new(batch_query, pbase_prefix_len);

						// Evaluate the small field composition
						composition.batch_evaluate(
							&batch_query,
							&mut composition_evals[..pbase_prefix_len],
						)
					})?;

					// Accumulate round evals and multiply by the constant part of the
					// zerocheck equality indicator
					let composition_evals_scalars = <PackedSubfield<P, FBase>>::unpack_scalars_mut(
						composition_evals.as_mut_slice(),
					);

					for (round_evals_coset, composition_evals_scalars_coset) in izip!(
						round_evals.chunks_exact_mut(1 << skip_rounds),
						composition_evals_scalars.chunks_exact(
							1 << subcube_vars.max(log_embedding_degree + P::LOG_WIDTH)
						)
					) {
						for (round_eval, composition_evals) in izip!(
							round_evals_coset,
							composition_evals_scalars_coset.chunks_exact(1 << log_batch),
						) {
							// Inner product is with the high n_vars - skip_rounds projection
							// of the zerocheck equality indicator (one factor per subcube).
							*round_eval += inner_product_unchecked(
								partial_eq_ind_evals_scalars_subslice.iter().copied(),
								composition_evals.iter().copied(),
							);
						}
					}

					// REVIEW: only the slow path is left; the fast path is to be reintroduced
					//         in follow-up PRs targeted at dropping PackedFieldIndexable
				}

				Ok(par_fold_states)
			},
		)
		.map(|states| -> Result<_, Error> { Ok(states?.round_evals) })
		.try_reduce(
			|| {
				composition_degrees
					.clone()
					.map(|composition_degree| {
						zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
					})
					.collect()
			},
			|lhs, rhs| -> Result<_, Error> {
				let round_evals_sum = izip!(lhs, rhs)
					.map(|(mut lhs_vals, rhs_vals)| {
						for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
							*lhs_val += rhs_val;
						}
						lhs_vals
					})
					.collect();

				Ok(round_evals_sum)
			},
		)?;

	// So far the evals of each composition are "staggered" in the sense that they are evaluated
	// on the smallest domain which guarantees uniqueness of the round polynomial. We extrapolate
	// them to max_domain_size to aid in the Gruen section 3.2 optimization below and batch
	// mixing.
	let round_evals = extrapolate_round_evals(staggered_round_evals, skip_rounds, max_domain_size)?;

	Ok(ZerocheckUnivariateEvalsOutput {
		round_evals,
		skip_rounds,
		remaining_rounds,
		max_domain_size,
		partial_eq_ind_evals,
	})
}

// Extrapolate round evaluations to the full domain.
// NB: this method relies on the fact that `round_evals` have specific lengths
// (namely `d * 2^n`, where `n` is not less than the number of skipped rounds and thus `d`
// is not larger than the composition degree), which enables additive-NTT based subquadratic
// techniques.
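// E.g. (arithmetic illustration): a degree-3 composition with skip_rounds = 2 arrives here
// with (3 - 1) * 2^2 = 8 staggered evals; re-adding the 2^2 leading zeros gives length
// 12 = 3 * 2^2, an odd multiple (d = 3) of a power of two, which `OddInterpolate` handles.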
#[instrument(skip_all, level = "debug")]
fn extrapolate_round_evals<F: TowerField>(
	mut round_evals: Vec<Vec<F>>,
	skip_rounds: usize,
	max_domain_size: usize,
) -> Result<Vec<Vec<F>>, Error> {
	// Instantiate a large enough NTT over F to be able to forward transform to the full
	// domain size.
	// REVIEW: should be possible to use an existing FDomain NTT with striding.
	let ntt = SingleThreadedNTT::with_canonical_field(log2_ceil_usize(max_domain_size))?;

	// Cache OddInterpolate instances, which, albeit small in practice, take cubic time to create.
	let mut odd_interpolates = HashMap::new();

	for round_evals in &mut round_evals {
		// Re-add zero evaluations at the beginning.
		round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));

		let n = round_evals.len();

		// Get an OddInterpolate instance of the required size.
		let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
			let ell = n.trailing_zeros() as usize;
			assert!(ell >= skip_rounds);

			OddInterpolate::new(n >> ell, ell, ntt.twiddles())
				.expect("domain large enough by construction")
		});

		// Obtain the novel polynomial basis representation of round evaluations.
		odd_interpolate.inverse_transform(&ntt, round_evals)?;

		// Use the forward NTT to extrapolate the novel representation to the max domain size.
		let next_log_n = log2_ceil_usize(max_domain_size);
		round_evals.resize(1 << next_log_n, F::ZERO);

		ntt.forward_transform(round_evals, 0, 0, next_log_n)?;

		// Sanity check: the first 1 << skip_rounds evals are still zeros.
		debug_assert!(round_evals[..1 << skip_rounds]
			.iter()
			.all(|&coeff| coeff == F::ZERO));

		// Trim the result.
		round_evals.resize(max_domain_size, F::ZERO);
		round_evals.drain(..1 << skip_rounds);
	}

	Ok(round_evals)
}

fn ntt_extrapolate<NTT, P>(
	ntt: &NTT,
	skip_rounds: usize,
	log_batch: usize,
	composition_max_degree: usize,
	interleaved_evals: &mut [P],
	extrapolated_evals: &mut [P],
) -> Result<(), Error>
where
	P: PackedFieldIndexable<Scalar: BinaryField>,
	NTT: AdditiveNTT<P::Scalar>,
{
	let subcube_vars = skip_rounds + log_batch;
	debug_assert_eq!(1 << subcube_vars.saturating_sub(P::LOG_WIDTH), interleaved_evals.len());
	debug_assert_eq!(
		extrapolated_evals_packed_len::<P>(composition_max_degree, skip_rounds, log_batch),
		extrapolated_evals.len()
	);
	debug_assert!(
		NTT::log_domain_size(ntt)
			>= log2_ceil_usize(domain_size(composition_max_degree, skip_rounds))
	);

	// Inverse NTT: convert evals to the novel basis representation
	ntt.inverse_transform(interleaved_evals, 0, log_batch, skip_rounds)?;

	// Forward NTT: evaluate the novel basis representation at consecutive cosets
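	// Chunk i then holds evaluations on coset i + 1, i.e. on domain points
	// [(i + 1) * 2^skip_rounds, (i + 2) * 2^skip_rounds); this is the layout that
	// `ntt_extrapolate_correctness` below checks against `max_domain`.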
	for (i, extrapolated_chunk) in extrapolated_evals
		.chunks_exact_mut(interleaved_evals.len())
		.enumerate()
	{
		extrapolated_chunk.copy_from_slice(interleaved_evals);
		ntt.forward_transform(extrapolated_chunk, (i + 1) as u32, log_batch, skip_rounds)?;
	}

	Ok(())
}

const fn extrapolated_evals_packed_len<P: PackedField>(
	composition_degree: usize,
	skip_rounds: usize,
	log_batch: usize,
) -> usize {
	composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
}

#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::UnderlierType,
		BinaryField128b, BinaryField16b, BinaryField1b, BinaryField8b, ExtensionField, Field,
		PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
	};
	use binius_hal::make_portable_backend;
	use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearPoly};
	use binius_ntt::SingleThreadedNTT;
	use rand::{prelude::StdRng, SeedableRng};

	use crate::{
		composition::{IndexComposition, ProductComposition},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
			test_utils::generate_zero_product_multilinears,
		},
		transparent::eq_ind::EqIndPartialEval,
	};

	#[test]
	fn ntt_extrapolate_correctness() {
		type P = PackedBinaryField4x32b;
		type FDomain = BinaryField16b;
		let log_extension_degree_p_domain = 1;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
		let subspace = BinarySubspace::<FDomain>::with_dim(10).unwrap();
		let max_domain =
			EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();

		for skip_rounds in 0..5usize {
			let subsubspace = subspace.reduce_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subsubspace.iter().collect::<Vec<_>>(), false)
					.unwrap();
			for log_batch in 0..3usize {
				for composition_degree in 0..5usize {
					let subcube_vars = skip_rounds + log_batch;
					let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
					let interleaved_evals = (0..interleaved_len)
						.map(|_| P::random(&mut rng))
						.collect::<Vec<_>>();

					let extrapolated_scalars_cnt =
						composition_degree.saturating_sub(1) << skip_rounds;
					let extrapolated_ntts = composition_degree.saturating_sub(1);
					let extrapolated_len = extrapolated_ntts * interleaved_len;
					let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

					let mut interleaved_evals_scratch = interleaved_evals.clone();

					let interleaved_evals_domain =
						P::cast_bases_mut(&mut interleaved_evals_scratch);
					let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

					super::ntt_extrapolate(
						&ntt,
						skip_rounds,
						log_batch + log_extension_degree_p_domain,
						composition_degree,
						interleaved_evals_domain,
						extrapolated_evals_domain,
					)
					.unwrap();

					let interleaved_scalars =
						&P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
					let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
						[..extrapolated_scalars_cnt << log_batch];

					for batch_idx in 0..1 << log_batch {
						let values = (0..1 << skip_rounds)
							.map(|i| interleaved_scalars[(i << log_batch) + batch_idx])
							.collect::<Vec<_>>();

						for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
							[..extrapolated_scalars_cnt]
							.iter()
							.take(1 << skip_rounds)
							.enumerate()
						{
							let extrapolated = domain.extrapolate(&values, point.into()).unwrap();
							let expected = extrapolated_scalars[(i << log_batch) + batch_idx];
							assert_eq!(extrapolated, expected);
						}
					}
				}
			}
		}
	}
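
	// A minimal added sanity check (not in the original test set) for
	// `extrapolated_evals_packed_len`: a degree-d composition needs
	// (d - 1) * 2^(skip_rounds + log_batch) extrapolated scalars, expressed here in packed
	// elements of P (LOG_WIDTH = 2 for PackedBinaryField4x32b).
	#[test]
	fn extrapolated_evals_packed_len_matches_scalar_count() {
		type P = PackedBinaryField4x32b;
		// (3 - 1) << (4 + 1 - 2) = 16 packed elements, i.e. 2 * 2^5 = 64 scalars
		assert_eq!(super::extrapolated_evals_packed_len::<P>(3, 4, 1), 16);
		// Degree 0 and 1 compositions require no extrapolated evals at all
		assert_eq!(super::extrapolated_evals_packed_len::<P>(0, 4, 1), 0);
		assert_eq!(super::extrapolated_evals_packed_len::<P>(1, 4, 1), 0);
	}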

	#[test]
	fn zerocheck_univariate_evals_invariants_basic() {
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	#[test]
	fn zerocheck_univariate_evals_with_nontrivial_packing() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field will
		// have a non-trivial packing width of 4.
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<BinaryField1b>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField + From<u8>,
		PackedType<U, FBase>: PackedFieldIndexable,
		PackedType<U, FDomain>: PackedFieldIndexable,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 7;
		let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

		let mut multilinears = generate_zero_product_multilinears::<
			PackedType<U, BinaryField1b>,
			PackedType<U, F>,
		>(&mut rng, n_vars, 2);
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

		let compositions = [
			Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
		];

		let backend = make_portable_backend();
		let zerocheck_challenges = (0..n_vars)
			.map(|_| <F as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		for skip_rounds in 0usize..=5 {
			let max_domain_size = domain_size(5, skip_rounds);
			let output =
				zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
					&multilinears,
					&compositions,
					&zerocheck_challenges[skip_rounds..],
					skip_rounds,
					max_domain_size,
					&backend,
				)
				.unwrap();

			let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
				.multilinear_extension::<F, _>(&backend)
				.unwrap();

			// Naive computation of the univariate skip output
			let round_evals_len = 4usize << skip_rounds;
			assert!(output
				.round_evals
				.iter()
				.all(|round_evals| round_evals.len() == round_evals_len));

			let compositions = compositions
				.iter()
				.cloned()
				.map(CompositionScalarAdapter::new)
				.collect::<Vec<_>>();

			let mut query = [FBase::ZERO; 9];
			let mut evals = vec![
				PackedType::<U, F>::zero();
				1 << skip_rounds.saturating_sub(
					log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
				)
			];
			let subspace = BinarySubspace::<FDomain>::with_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();
			for round_evals_index in 0..round_evals_len {
				let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
				let mut composition_sums = vec![F::ZERO; compositions.len()];
				for subcube_index in 0..1 << (n_vars - skip_rounds) {
					for (query, multilinear) in query.iter_mut().zip(&multilinears) {
						multilinear
							.subcube_evals(
								skip_rounds,
								subcube_index,
								log_embedding_degree,
								&mut evals,
							)
							.unwrap();
						let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
							PackedExtension::<FBase>::cast_bases(&evals),
						)[..1 << skip_rounds];
						let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
						*query = extrapolated;
					}

					let eq_ind_factor = zerocheck_eq_ind
						.evaluate_on_hypercube(subcube_index)
						.unwrap();
					for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
						*sum += eq_ind_factor * composition.evaluate(&query).unwrap();
					}
				}

				let univariate_skip_composition_sums = output
					.round_evals
					.iter()
					.map(|round_evals| round_evals[round_evals_index])
					.collect::<Vec<_>>();
				assert_eq!(univariate_skip_composition_sums, composition_sums);
			}
		}
	}
}
928}