binius_core/protocols/sumcheck/prove/univariate.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{collections::HashMap, iter::repeat_n};

use binius_field::{
	packed::{get_packed_slice, get_packed_slice_checked},
	recast_packed_mut,
	util::inner_product_unchecked,
	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField,
};
use binius_hal::ComputationBackend;
use binius_math::{
	BinarySubspace, CompositionPoly, Error as MathError, EvaluationDomain, MLEDirectAdapter,
	MultilinearPoly, RowsBatchRef,
};
use binius_maybe_rayon::prelude::*;
use binius_ntt::{AdditiveNTT, NTTShape, OddInterpolate, SingleThreadedNTT};
use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
use bytemuck::zeroed_vec;
use itertools::izip;
use stackalloc::stackalloc_with_iter;
use tracing::instrument;

use crate::{
	composition::{BivariateProduct, IndexComposition},
	protocols::sumcheck::{
		common::{equal_n_vars_check, small_field_embedding_degree_check},
		prove::{
			logging::{ExpandQueryData, UnivariateSkipCalculateCoeffsData},
			RegularSumcheckProver,
		},
		zerocheck::{domain_size, extrapolated_scalars_count},
		Error,
	},
};

pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
	'a,
	FDomain,
	P,
	IndexComposition<BivariateProduct, 2>,
	MLEDirectAdapter<P>,
	Backend,
>;

#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
	/// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals).
	/// Scratch space.
	evals: Vec<P>,
	/// `evals` extrapolated beyond first 2^skip_rounds domain points, per multilinear.
	extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
	/// Evals of a single composition over extrapolated multilinears. Scratch space.
	composition_evals: Vec<PackedSubfield<P, FBase>>,
	/// Packed round evals accumulators, per composition.
	packed_round_evals: Vec<Vec<P>>,
}

impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
	fn new(
		n_multilinears: usize,
		skip_rounds: usize,
		log_batch: usize,
		log_embedding_degree: usize,
		composition_degrees: impl Iterator<Item = usize> + Clone,
	) -> Self {
		let subcube_vars = skip_rounds + log_batch;
		let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
		let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
			composition_max_degree,
			skip_rounds,
			log_batch,
		);

		let evals =
			zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));

		let extrapolated_evals = (0..n_multilinears)
			.map(|_| zeroed_vec(extrapolated_packed_pbase_len))
			.collect();

		let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);

		let packed_round_evals = composition_degrees
			.map(|composition_degree| {
				zeroed_vec(extrapolated_evals_packed_len::<P>(composition_degree, skip_rounds, 0))
			})
			.collect();

		Self {
			evals,
			extrapolated_evals,
			composition_evals,
			packed_round_evals,
		}
	}
}

#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	pub round_evals: Vec<Vec<F>>,
	skip_rounds: usize,
	remaining_rounds: usize,
	max_domain_size: usize,
	partial_eq_ind_evals: Backend::Vec<P>,
}

pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	pub remaining_rounds: usize,
	pub subcube_lagrange_coeffs: Vec<F>,
	pub claimed_sums: Vec<F>,
	pub partial_eq_ind_evals: Backend::Vec<P>,
}

impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	/// The univariate round can be folded once the challenge has been sampled.
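	///
	/// A minimal usage sketch (illustrative; `evals_output` is assumed to come from
	/// [`zerocheck_univariate_evals`], and `sample_challenge()` stands in for transcript
	/// challenge sampling, which happens outside this module):
	///
	/// ```ignore
	/// let challenge = sample_challenge();
	/// let ZerocheckUnivariateFoldResult { claimed_sums, .. } =
	///     evals_output.fold::<FDomain>(challenge)?;
	/// // `claimed_sums` seed the sumcheck claims for the remaining multilinear rounds.
	/// ```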
	#[instrument(
		skip_all,
		name = "ZerocheckUnivariateEvalsOutput::fold",
		level = "debug"
	)]
	pub fn fold<FDomain>(
		self,
		challenge: F,
	) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
	where
		FDomain: TowerField,
		F: ExtensionField<FDomain>,
	{
		let Self {
			round_evals,
			skip_rounds,
			remaining_rounds,
			max_domain_size,
			partial_eq_ind_evals,
		} = self;

		// REVIEW: consider using novel basis for the univariate round representation
		//         (instead of Lagrange)
		let max_dim = log2_ceil_usize(max_domain_size);
		let subspace =
			BinarySubspace::<FDomain::Canonical>::with_dim(max_dim)?.isomorphic::<FDomain>();
		let max_domain = EvaluationDomain::from_points(
			subspace.iter().take(max_domain_size).collect::<Vec<_>>(),
			false,
		)?;

		// Lagrange extrapolation over skipped subcube
		let subcube_lagrange_coeffs = EvaluationDomain::from_points(
			subspace.reduce_dim(skip_rounds)?.iter().collect::<Vec<_>>(),
			false,
		)?
		.lagrange_evals(challenge);

		// Lagrange extrapolation for the entire univariate domain
		let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);

		let claimed_sums = round_evals
			.into_iter()
			.map(|evals| {
				inner_product_unchecked::<F, F>(
					evals,
					round_evals_lagrange_coeffs[1 << skip_rounds..]
						.iter()
						.copied(),
				)
			})
			.collect();

		Ok(ZerocheckUnivariateFoldResult {
			remaining_rounds,
			subcube_lagrange_coeffs,
			claimed_sums,
			partial_eq_ind_evals,
		})
	}
}
/// Compute univariate skip round evaluations for zerocheck.
///
/// When all witness multilinear hypercube evaluations can be embedded into a small field
/// `PBase::Scalar` that is significantly smaller than `F`, we naturally want to refrain from
/// folding for `skip_rounds` (denoted below as $k$) to reap the benefits of faster small field
/// multiplications. Naive extensions to the sumcheck protocol which compute multivariate round
/// polynomials do not work, though, given that for a composition of degree $d$ one would need
/// $(d+1)^k-2^k$ evaluations (assuming [Gruen24] section 3.2 optimizations), which usually grows
/// faster than $2^k$ and thus will typically require more work than large field sumcheck. We adopt
/// a univariatizing approach instead, where we define "oblong" multivariates:
/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum_{u=0}^{2^k-1} M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot L_u(\hat{u}_1)$$
/// with $\mathbb{M}: \hat{u}_1 \rightarrow (u_1, \ldots, u_k)$ being some map from the univariate
/// domain to the $\mathcal{B}_k$ hypercube and $L_u(\hat{u})$ being Lagrange polynomials.
///
/// The main idea of the univariatizing approach is that $\hat{M}$ are of degree $2^k-1$ in
/// $\hat{u}_1$ and multilinear in other variables, thus evaluating a composition of degree $d$ over
/// $\hat{M}$ yields a total degree of $d(2^k-1)$ in the first round (again, assuming the [Gruen24]
/// section 3.2 trick to avoid multiplication by the equality indicator), which is comparable to
/// what a regular non-skipping zerocheck prover would do. The only issue is that we normally don't
/// have an oracle for $\hat{M}$, which necessitates an extra sumcheck reduction to multilinear
/// claims (see [univariatizing_reduction_claim](`super::super::zerocheck::univariatizing_reduction_claim`)).
///
/// One special trick of the univariate round is that the round polynomial is represented in
/// Lagrange form:
///  1. An honest prover evaluates to zero on the $2^k$ domain points mapping to $\mathcal{B}_k$,
///     reducing proof size
///  2. Avoiding monomial conversion saves prover time by skipping $O(N^3)$ inverse Vandermonde
///     precomputation
///  3. Evaluation in the verifier can be made linear time when barycentric weights are precomputed
///
/// This implementation defines $\mathbb{M}$ to be the basis-induced mapping of the binary field
/// `FDomain`; the main reason for that is to be able to use the additive NTT from [LCH14] for
/// extrapolation. The choice of domain field impacts performance, thus generally the smallest
/// field with cardinality not less than the degree of the round polynomial should be used.
///
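/// As a worked size example (illustrative, not normative): for a degree $d = 4$ composition with
/// $k = 3$ skipped rounds, the round polynomial has degree $d \cdot (2^k - 1) = 28$, the univariate
/// domain spans `domain_size(4, 3)` $= 4 \cdot 2^3 = 32$ points, and only $(d - 1) \cdot 2^k = 24$
/// evaluations are actually materialized, the first $2^k = 8$ being known zeros. An 8-bit `FDomain`
/// suffices here, since $2^8 \ge 32$.
///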
/// [LCH14]: <https://arxiv.org/abs/1404.3458>
/// [Gruen24]: <https://eprint.iacr.org/2024/108>
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
	multilinears: &[M],
	compositions: &[Composition],
	zerocheck_challenges: &[F],
	skip_rounds: usize,
	max_domain_size: usize,
	backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
	FDomain: TowerField,
	FBase: ExtensionField<FDomain>,
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FBase> + PackedExtension<FDomain>,
	Composition: CompositionPoly<PackedSubfield<P, FBase>>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	let n_vars = equal_n_vars_check(multilinears)?;
	let n_multilinears = multilinears.len();

	if skip_rounds > n_vars {
		bail!(Error::TooManySkippedRounds);
	}

	let remaining_rounds = n_vars - skip_rounds;
	if zerocheck_challenges.len() != remaining_rounds {
		bail!(Error::IncorrectZerocheckChallengesLength);
	}

	small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

	let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
	let composition_degrees = compositions.iter().map(|composition| composition.degree());
	let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

	if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
		bail!(Error::LagrangeDomainTooSmall);
	}

	// Batching factors for strided NTTs.
	let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;

	// Check that the domain field contains the required NTT cosets.
	let min_domain_bits = log2_ceil_usize(max_domain_size);
	if min_domain_bits > FDomain::N_BITS {
		bail!(MathError::DomainSizeTooLarge);
	}

	// Only a domain size NTT is needed.
	let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
		.expect("FDomain cardinality checked before")
		.precompute_twiddles();

	// Smaller subcubes are batched together to reduce interpolation/evaluation overhead.
	// REVIEW: make this a heuristic dependent on base field size and/or number of multilinears
	//         to guarantee L1 cache (or accelerator scratchpad) non-eviction.
	const MAX_SUBCUBE_VARS: usize = 12;
	let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);
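	// E.g. (illustrative): with n_vars = 20 and skip_rounds = 7, subcubes of 2^12 evals
	// are processed with log_batch = 5; with n_vars = 10 and skip_rounds = 7,
	// log_batch = 3 makes a single batch span the whole trace.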

	// Expand the multilinear query in all but the first `skip_rounds` variables,
	// where each tensor expansion element serves as a constant factor of the whole
	// univariatized subcube.
	// NB: expansion of the first `skip_rounds` variables is applied to the round evals sum
	let dimensions_data = ExpandQueryData::new(zerocheck_challenges);
	let expand_span = tracing::debug_span!(
		"[task] Expand Query",
		phase = "zerocheck",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();
	let partial_eq_ind_evals: <Backend as ComputationBackend>::Vec<P> =
		backend.tensor_product_full_query(zerocheck_challenges)?;
	drop(expand_span);

	// Evaluate each composition on a minimal packed prefix corresponding to the degree
	let pbase_prefix_lens = composition_degrees
		.clone()
		.map(|composition_degree| {
			extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
				composition_degree,
				skip_rounds,
				log_batch,
			)
		})
		.collect::<Vec<_>>();
	let dimensions_data =
		UnivariateSkipCalculateCoeffsData::new(n_vars, skip_rounds, n_multilinears, log_batch);
	let coeffs_span = tracing::debug_span!(
		"[task] Univariate Skip Calculate coeffs",
		phase = "zerocheck",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	let subcube_vars = log_batch + skip_rounds;
	let log_subcube_count = n_vars - subcube_vars;

	let p_coset_round_evals_len = 1 << skip_rounds.saturating_sub(P::LOG_WIDTH);
	let pbase_coset_composition_evals_len =
		1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree);
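	// E.g. (illustrative): skip_rounds = 5, log_batch = 7, P::LOG_WIDTH = 2 and
	// log_embedding_degree = 4 give cosets of 1 << (5 - 2) = 8 packed round evals
	// and 1 << (12 - 6) = 64 packed base field composition evals.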

	// NB: we avoid evaluation on the first 2^skip_rounds points because an honest
	// prover would always evaluate to zero there; we also factor out the first
	// skip_rounds terms of the equality indicator and apply them pointwise to
	// the final round evaluations, which equates to lowering the composition_degree
	// by one (this is an extension of the Gruen section 3.2 trick)
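	// Concretely, each composition's accumulated round evals then hold only
	// (d - 1) * 2^skip_rounds scalars (see `extrapolated_scalars_count`) instead of
	// the full d * 2^skip_rounds domain points.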
	let staggered_round_evals = (0..1 << log_subcube_count)
		.into_par_iter()
		.try_fold(
			|| {
				ParFoldStates::<FBase, P>::new(
					n_multilinears,
					skip_rounds,
					log_batch,
					log_embedding_degree,
					composition_degrees.clone(),
				)
			},
			|mut par_fold_states, subcube_index| -> Result<_, Error> {
				let ParFoldStates {
					evals,
					extrapolated_evals,
					composition_evals,
					packed_round_evals,
				} = &mut par_fold_states;

				// Interpolate multilinear evals for each multilinear
				for (multilinear, extrapolated_evals) in
					izip!(multilinears, extrapolated_evals.iter_mut())
				{
					// Sample evals subcube from a multilinear poly
					multilinear.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals,
					)?;

					// Extrapolate evals using a conservative upper bound of the composition
					// degree. We use Additive NTT to extrapolate evals beyond the first 2^skip_rounds,
					// exploiting the fact that extension field NTT is a strided base field NTT.
					let evals_base = <P as PackedExtension<FBase>>::cast_bases_mut(evals);
					let evals_domain = recast_packed_mut::<P, FBase, FDomain>(evals_base);
					let extrapolated_evals_domain =
						recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

					ntt_extrapolate(
						&fdomain_ntt,
						skip_rounds,
						log_extension_degree_base_domain,
						log_batch,
						evals_domain,
						extrapolated_evals_domain,
					)?
				}

				// Evaluate the compositions and accumulate round results
				for (composition, packed_round_evals, &pbase_prefix_len) in
					izip!(compositions, packed_round_evals, &pbase_prefix_lens)
				{
					let extrapolated_evals_iter = extrapolated_evals
						.iter()
						.map(|evals| &evals[..pbase_prefix_len]);

					stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
						let batch_query = RowsBatchRef::new(batch_query, pbase_prefix_len);

						// Evaluate the small field composition
						composition.batch_evaluate(
							&batch_query,
							&mut composition_evals[..pbase_prefix_len],
						)
					})?;

					// Accumulate round evals and multiply by the constant part of the
					// zerocheck equality indicator
					for (packed_round_evals_coset, composition_evals_coset) in izip!(
						packed_round_evals.chunks_exact_mut(p_coset_round_evals_len),
						composition_evals.chunks_exact(pbase_coset_composition_evals_len)
					) {
						// At this point, the composition evals are laid out as a 3D array,
						// with dimensions being (ordered by increasing stride):
						//  1) 2^skip_rounds           - a low indexed subcube being "skipped"
						//  2) 2^log_batch             - batch of adjacent subcubes
						//  3) composition_degree - 1  - cosets of the subcube evaluation domain
						// NB: each complete span of dim 1 gets multiplied by a constant from
						// the equality indicator expansion, and dims 1+2 are padded up to the
						// nearest packed field due to ntt_extrapolate implementation details
						// (not performing sub-packed-field NTTs). This helper method handles
						// multiplication of each dim 1 + 2 submatrix by the corresponding
						// equality indicator subslice.
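						// E.g. (illustrative): skip_rounds = 3, log_batch = 2 and a degree-4
						// composition give dims (8, 4, 3) in this layout.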
						spread_product::<_, FBase>(
							packed_round_evals_coset,
							composition_evals_coset,
							&partial_eq_ind_evals,
							subcube_index,
							skip_rounds,
							log_batch,
						);
					}
				}

				Ok(par_fold_states)
			},
		)
		.map(|states| -> Result<_, Error> {
			let scalar_round_evals = izip!(composition_degrees.clone(), states?.packed_round_evals)
				.map(|(composition_degree, packed_round_evals)| {
					let mut composition_round_evals = Vec::with_capacity(
						extrapolated_scalars_count(composition_degree, skip_rounds),
					);

					for packed_round_evals_coset in
						packed_round_evals.chunks_exact(p_coset_round_evals_len)
					{
						let coset_scalars = packed_round_evals_coset
							.iter()
							.flat_map(|packed| packed.iter())
							.take(1 << skip_rounds);

						composition_round_evals.extend(coset_scalars);
					}

					composition_round_evals
				})
				.collect::<Vec<_>>();

			Ok(scalar_round_evals)
		})
		.try_reduce(
			|| {
				composition_degrees
					.clone()
					.map(|composition_degree| {
						zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
					})
					.collect()
			},
			|lhs, rhs| -> Result<_, Error> {
				let round_evals_sum = izip!(lhs, rhs)
					.map(|(mut lhs_vals, rhs_vals)| {
						debug_assert_eq!(lhs_vals.len(), rhs_vals.len());
						for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
							*lhs_val += rhs_val;
						}
						lhs_vals
					})
					.collect();

				Ok(round_evals_sum)
			},
		)?;

	// So far the evals of each composition are "staggered" in the sense that they are evaluated
	// on the smallest domain which guarantees uniqueness of the round polynomial. We extrapolate
	// them to max_domain_size to aid in the Gruen section 3.2 optimization below and batch mixing.
	let round_evals = extrapolate_round_evals(staggered_round_evals, skip_rounds, max_domain_size)?;
	drop(coeffs_span);

	Ok(ZerocheckUnivariateEvalsOutput {
		round_evals,
		skip_rounds,
		remaining_rounds,
		max_domain_size,
		partial_eq_ind_evals,
	})
}

// A helper to perform spread multiplication of small field composition evals by appropriate
// equality indicator scalars. See `zerocheck_univariate_evals` impl for intuition.
fn spread_product<P, FBase>(
	accum: &mut [P],
	small: &[PackedSubfield<P, FBase>],
	large: &[P],
	subcube_index: usize,
	log_n: usize,
	log_batch: usize,
) where
	P: PackedExtension<FBase>,
	FBase: Field,
{
	let log_embedding_degree = <P::Scalar as ExtensionField<FBase>>::LOG_DEGREE;
	let pbase_log_width = P::LOG_WIDTH + log_embedding_degree;

	debug_assert_eq!(accum.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH));
	debug_assert_eq!(small.len(), 1 << (log_n + log_batch).saturating_sub(pbase_log_width));
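	// Illustrative shapes: log_n = 4, log_batch = 2, P::LOG_WIDTH = 2 and
	// log_embedding_degree = 2 give 4 packed accumulators (16 large field scalars)
	// multiplied against 4 packed base elements (64 small field scalars).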

	if log_n >= P::LOG_WIDTH {
		// Use spread multiplication on fast path.
		let mask = (1 << log_embedding_degree) - 1;
		for batch_idx in 0..1 << log_batch {
			let mult = get_packed_slice(large, subcube_index << log_batch | batch_idx);
			let spread_large = P::cast_base(P::broadcast(mult));

			for (block_idx, dest) in accum.iter_mut().enumerate() {
				let block_offset = block_idx | batch_idx << (log_n - P::LOG_WIDTH);
				let spread_small = small[block_offset >> log_embedding_degree]
					.spread(P::LOG_WIDTH, block_offset & mask);
				*dest += P::cast_ext(spread_large * spread_small);
			}
		}
	} else {
		// Multiple skipped subcube evaluations do fit into a single packed field.
		// This never occurs with large traces under frontloaded univariate skip batching,
		// making this a non-critical slow path.
		for (outer_idx, dest) in accum.iter_mut().enumerate() {
			*dest = P::from_fn(|inner_idx| {
				if inner_idx >= 1 << log_n {
					return P::Scalar::ZERO;
				}
				(0..1 << log_batch)
					.map(|batch_idx| {
						let large = get_packed_slice(large, subcube_index << log_batch | batch_idx);
						let small = get_packed_slice_checked(
							small,
							batch_idx << log_n | outer_idx << P::LOG_WIDTH | inner_idx,
						)
						.unwrap_or_default();
						large * small
					})
					.sum()
			})
		}
	}
}

// Extrapolate round evaluations to the full domain.
// NB: this method relies on the fact that `round_evals` have specific lengths (namely `d * 2^n`
// with `d` odd, where `n` is not less than the number of skipped rounds and thus `d` is not
// larger than the composition degree), which enables additive-NTT based subquadratic techniques.
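// A worked instance (illustrative): a degree-3 composition with skip_rounds = 2 yields evals of
// length 3 * 2^2 = 12 once the 2^2 zeros are re-added, so n = 12, ell = 2 >= skip_rounds, and
// `OddInterpolate::new(3, 2, ...)` is cached and reused across compositions of equal degree.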
#[instrument(skip_all, level = "debug")]
fn extrapolate_round_evals<F: TowerField>(
	mut round_evals: Vec<Vec<F>>,
	skip_rounds: usize,
	max_domain_size: usize,
) -> Result<Vec<Vec<F>>, Error> {
	// Instantiate a large enough NTT over F to be able to forward transform to full domain size.
	// REVIEW: should be possible to use an existing FDomain NTT with striding, possibly with larger domain.
	let ntt = SingleThreadedNTT::with_canonical_field(log2_ceil_usize(max_domain_size))?;

	// Cache OddInterpolate instances, which, albeit small in practice, take cubic time to create.
	let mut odd_interpolates = HashMap::new();

	for round_evals in &mut round_evals {
		// Re-add zero evaluations at the beginning.
		round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));

		let n = round_evals.len();

		// Get OddInterpolate instance of required size.
		let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
			let ell = n.trailing_zeros() as usize;
			assert!(ell >= skip_rounds);

			OddInterpolate::new(n >> ell, ell, ntt.twiddles())
				.expect("domain large enough by construction")
		});

		// Obtain novel polynomial basis representation of round evaluations.
		odd_interpolate.inverse_transform(&ntt, round_evals)?;

		// Use forward NTT to extrapolate novel representation to the max domain size.
		let next_log_n = log2_ceil_usize(max_domain_size);
		round_evals.resize(1 << next_log_n, F::ZERO);

		let shape = NTTShape {
			log_y: next_log_n,
			..Default::default()
		};
		ntt.forward_transform(round_evals, shape, 0, 0)?;

		// Sanity check: first 1 << skip_rounds evals are still zeros.
		debug_assert!(round_evals[..1 << skip_rounds]
			.iter()
			.all(|&coeff| coeff == F::ZERO));

		// Trim the result.
		round_evals.resize(max_domain_size, F::ZERO);
		round_evals.drain(..1 << skip_rounds);
	}

	Ok(round_evals)
}

fn ntt_extrapolate<NTT, P>(
	ntt: &NTT,
	skip_rounds: usize,
	log_stride_batch: usize,
	log_batch: usize,
	evals: &mut [P],
	extrapolated_evals: &mut [P],
) -> Result<(), Error>
where
	P: PackedField<Scalar: BinaryField>,
	NTT: AdditiveNTT<P::Scalar>,
{
	let shape = NTTShape {
		log_x: log_stride_batch,
		log_y: skip_rounds,
		log_z: log_batch,
	};
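	// Shape reading (per binius_ntt conventions): log_x strides base field elements
	// within an extension scalar, log_y is the 2^skip_rounds dimension being
	// transformed, and log_z batches adjacent subcubes.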

	// Inverse NTT: convert evals to novel basis representation
	ntt.inverse_transform(evals, shape, 0, 0)?;

	// Forward NTT: evaluate novel basis representation at consecutive cosets
	for (coset, extrapolated_chunk) in
		izip!(1u32.., extrapolated_evals.chunks_exact_mut(evals.len()))
	{
		// REVIEW: can avoid that copy (and extrapolated_evals scratchpad) when composition_max_degree == 2
		extrapolated_chunk.copy_from_slice(evals);
		ntt.forward_transform(extrapolated_chunk, shape, coset, 0)?;
	}

	Ok(())
}

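/// Scratch space size for extrapolated evals, in packed field elements.
///
/// A worked instance (illustrative): composition_degree = 3, skip_rounds = 4, log_batch = 0 and
/// P::LOG_WIDTH = 2 give (3 - 1) << 2 = 8 packed elements, i.e. 32 scalars covering the
/// (3 - 1) * 2^4 extrapolated points.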
const fn extrapolated_evals_packed_len<P: PackedField>(
	composition_degree: usize,
	skip_rounds: usize,
	log_batch: usize,
) -> usize {
	composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
}

#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::UnderlierType,
		BinaryField128b, BinaryField16b, BinaryField1b, BinaryField8b, ExtensionField, Field,
		PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
	};
	use binius_hal::make_portable_backend;
	use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearPoly};
	use binius_ntt::SingleThreadedNTT;
	use rand::{prelude::StdRng, SeedableRng};

	use crate::{
		composition::{IndexComposition, ProductComposition},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
			test_utils::generate_zero_product_multilinears,
		},
		transparent::eq_ind::EqIndPartialEval,
	};

	#[test]
	fn ntt_extrapolate_correctness() {
		type P = PackedBinaryField4x32b;
		type FDomain = BinaryField16b;
		let log_extension_degree_p_domain = 1;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
		let subspace = BinarySubspace::<FDomain>::with_dim(10).unwrap();
		let max_domain =
			EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();

		for skip_rounds in 0..5usize {
			let subsubspace = subspace.reduce_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subsubspace.iter().collect::<Vec<_>>(), false)
					.unwrap();
			for log_batch in 0..3usize {
				for composition_degree in 0..5usize {
					let subcube_vars = skip_rounds + log_batch;
					let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
					let interleaved_evals = (0..interleaved_len)
						.map(|_| P::random(&mut rng))
						.collect::<Vec<_>>();

					let extrapolated_scalars_cnt =
						composition_degree.saturating_sub(1) << skip_rounds;
					let extrapolated_ntts = composition_degree.saturating_sub(1);
					let extrapolated_len = extrapolated_ntts * interleaved_len;
					let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

					let mut interleaved_evals_scratch = interleaved_evals.clone();

					let interleaved_evals_domain =
						P::cast_bases_mut(&mut interleaved_evals_scratch);
					let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

					super::ntt_extrapolate(
						&ntt,
						skip_rounds,
						log_extension_degree_p_domain,
						log_batch,
						interleaved_evals_domain,
						extrapolated_evals_domain,
					)
					.unwrap();

					let interleaved_scalars =
						&P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
					let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
						[..extrapolated_scalars_cnt << log_batch];

					for batch_idx in 0..1 << log_batch {
						let values =
							&interleaved_scalars[batch_idx << skip_rounds..][..1 << skip_rounds];

						for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
							[..extrapolated_scalars_cnt]
							.iter()
							.take(1 << skip_rounds)
							.enumerate()
						{
							let extrapolated = domain.extrapolate(values, point.into()).unwrap();
							let expected = extrapolated_scalars[batch_idx << skip_rounds | i];
							assert_eq!(extrapolated, expected);
						}
					}
				}
			}
		}
	}

	#[test]
	fn zerocheck_univariate_evals_invariants_basic() {
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	#[test]
	fn zerocheck_univariate_evals_with_nontrivial_packing() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field
		// will have a non-trivial packing width of 4.
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<BinaryField1b>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField + From<u8>,
		PackedType<U, FBase>: PackedFieldIndexable,
		PackedType<U, FDomain>: PackedFieldIndexable,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 7;
		let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

		let mut multilinears = generate_zero_product_multilinears::<
			PackedType<U, BinaryField1b>,
			PackedType<U, F>,
		>(&mut rng, n_vars, 2);
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

		let compositions = [
			Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
		];

		let backend = make_portable_backend();
		let zerocheck_challenges = (0..n_vars)
			.map(|_| <F as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		for skip_rounds in 0usize..=5 {
			let max_domain_size = domain_size(5, skip_rounds);
			let output =
				zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
					&multilinears,
					&compositions,
					&zerocheck_challenges[skip_rounds..],
					skip_rounds,
					max_domain_size,
					&backend,
				)
				.unwrap();

			let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
				.multilinear_extension::<F, _>(&backend)
				.unwrap();

			// Naive computation of the univariate skip output
			let round_evals_len = 4usize << skip_rounds;
			assert!(output
				.round_evals
				.iter()
				.all(|round_evals| round_evals.len() == round_evals_len));

			let compositions = compositions
				.iter()
				.cloned()
				.map(CompositionScalarAdapter::new)
				.collect::<Vec<_>>();

			let mut query = [FBase::ZERO; 9];
			let mut evals = vec![
				PackedType::<U, F>::zero();
				1 << skip_rounds.saturating_sub(
					log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
				)
			];
			let subspace = BinarySubspace::<FDomain>::with_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();
			for round_evals_index in 0..round_evals_len {
				let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
				let mut composition_sums = vec![F::ZERO; compositions.len()];
				for subcube_index in 0..1 << (n_vars - skip_rounds) {
					for (query, multilinear) in query.iter_mut().zip(&multilinears) {
						multilinear
							.subcube_evals(
								skip_rounds,
								subcube_index,
								log_embedding_degree,
								&mut evals,
							)
							.unwrap();
						let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
							PackedExtension::<FBase>::cast_bases(&evals),
						)[..1 << skip_rounds];
						let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
						*query = extrapolated;
					}

					let eq_ind_factor = zerocheck_eq_ind
						.evaluate_on_hypercube(subcube_index)
						.unwrap();
					for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
						*sum += eq_ind_factor * composition.evaluate(&query).unwrap();
					}
				}

				let univariate_skip_composition_sums = output
					.round_evals
					.iter()
					.map(|round_evals| round_evals[round_evals_index])
					.collect::<Vec<_>>();
				assert_eq!(univariate_skip_composition_sums, composition_sums);
			}
		}
	}
}