binius_core/protocols/sumcheck/prove/univariate.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{collections::HashMap, iter::repeat_n};

use binius_field::{
	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField,
	packed::{get_packed_slice, get_packed_slice_checked},
	recast_packed_mut,
	util::inner_product_unchecked,
};
use binius_hal::ComputationBackend;
use binius_math::{
	BinarySubspace, CompositionPoly, Error as MathError, EvaluationDomain, MLEDirectAdapter,
	MultilinearPoly, RowsBatchRef,
};
use binius_maybe_rayon::prelude::*;
use binius_ntt::{
	AdditiveNTT, NTTShape, OddInterpolate, SingleThreadedNTT, twiddle::TwiddleAccess,
};
use binius_utils::{bail, checked_arithmetics::log2_ceil_usize};
use bytemuck::zeroed_vec;
use itertools::izip;
use stackalloc::stackalloc_with_iter;
use tracing::instrument;

use crate::{
	composition::{BivariateProduct, IndexComposition},
	protocols::sumcheck::{
		Error,
		common::{equal_n_vars_check, small_field_embedding_degree_check},
		prove::{
			RegularSumcheckProver,
			logging::{ExpandQueryData, UnivariateSkipCalculateCoeffsData},
		},
		zerocheck::{domain_size, extrapolated_scalars_count},
	},
};

pub type Prover<'a, FDomain, P, Backend> = RegularSumcheckProver<
	'a,
	FDomain,
	P,
	IndexComposition<BivariateProduct, 2>,
	MLEDirectAdapter<P>,
	Backend,
>;

#[derive(Debug)]
struct ParFoldStates<FBase: Field, P: PackedExtension<FBase>> {
	/// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals).
	/// Scratch space.
	evals: Vec<P>,
	/// `evals` extrapolated beyond first 2^skip_rounds domain points, per multilinear.
	extrapolated_evals: Vec<Vec<PackedSubfield<P, FBase>>>,
	/// Evals of a single composition over extrapolated multilinears. Scratch space.
	composition_evals: Vec<PackedSubfield<P, FBase>>,
	/// Packed round evals accumulators, per composition.
	packed_round_evals: Vec<Vec<P>>,
}

impl<FBase: Field, P: PackedExtension<FBase>> ParFoldStates<FBase, P> {
	fn new(
		n_multilinears: usize,
		skip_rounds: usize,
		log_batch: usize,
		log_embedding_degree: usize,
		composition_degrees: impl Iterator<Item = usize> + Clone,
	) -> Self {
		let subcube_vars = skip_rounds + log_batch;
		let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);
		let extrapolated_packed_pbase_len = extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
			composition_max_degree,
			skip_rounds,
			log_batch,
		);

		let evals =
			zeroed_vec(1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree));

		let extrapolated_evals = (0..n_multilinears)
			.map(|_| zeroed_vec(extrapolated_packed_pbase_len))
			.collect();

		let composition_evals = zeroed_vec(extrapolated_packed_pbase_len);

		let packed_round_evals = composition_degrees
			.map(|composition_degree| {
				zeroed_vec(extrapolated_evals_packed_len::<P>(composition_degree, skip_rounds, 0))
			})
			.collect();

		Self {
			evals,
			extrapolated_evals,
			composition_evals,
			packed_round_evals,
		}
	}
}

#[derive(Debug)]
pub struct ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	pub round_evals: Vec<Vec<F>>,
	skip_rounds: usize,
	remaining_rounds: usize,
	max_domain_size: usize,
	partial_eq_ind_evals: Backend::Vec<P>,
}

pub struct ZerocheckUnivariateFoldResult<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	pub remaining_rounds: usize,
	pub subcube_lagrange_coeffs: Vec<F>,
	pub claimed_sums: Vec<F>,
	pub partial_eq_ind_evals: Backend::Vec<P>,
}

impl<F, P, Backend> ZerocheckUnivariateEvalsOutput<F, P, Backend>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Backend: ComputationBackend,
{
	// The univariate round can be folded once the challenge has been sampled.
	#[instrument(
		skip_all,
		name = "ZerocheckUnivariateEvalsOutput::fold",
		level = "debug"
	)]
	pub fn fold<FDomain>(
		self,
		challenge: F,
	) -> Result<ZerocheckUnivariateFoldResult<F, P, Backend>, Error>
	where
		FDomain: TowerField,
		F: ExtensionField<FDomain>,
	{
		let Self {
			round_evals,
			skip_rounds,
			remaining_rounds,
			max_domain_size,
			partial_eq_ind_evals,
		} = self;

		// REVIEW: consider using novel basis for the univariate round representation
		//         (instead of Lagrange)
		let max_dim = log2_ceil_usize(max_domain_size);
		let subspace =
			BinarySubspace::<FDomain::Canonical>::with_dim(max_dim)?.isomorphic::<FDomain>();
		let max_domain = EvaluationDomain::from_points(
			subspace.iter().take(max_domain_size).collect::<Vec<_>>(),
			false,
		)?;

		// Lagrange extrapolation over skipped subcube
		let subcube_lagrange_coeffs = EvaluationDomain::from_points(
			subspace.reduce_dim(skip_rounds)?.iter().collect::<Vec<_>>(),
			false,
		)?
		.lagrange_evals(challenge);

		// Lagrange extrapolation for the entire univariate domain
		let round_evals_lagrange_coeffs = max_domain.lagrange_evals(challenge);

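		// For intuition, the identity computed below (a sketch, not additional logic): each
		// claimed sum is the round polynomial evaluated at the sampled challenge z, i.e.
		// r(z) = sum_{i >= 2^skip_rounds} e_i * L_i(z); the first 2^skip_rounds Lagrange
		// coefficients are dropped because an honest prover's evals there are zero.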
		let claimed_sums = round_evals
			.into_iter()
			.map(|evals| {
				inner_product_unchecked::<F, F>(
					evals,
					round_evals_lagrange_coeffs[1 << skip_rounds..]
						.iter()
						.copied(),
				)
			})
			.collect();

		Ok(ZerocheckUnivariateFoldResult {
			remaining_rounds,
			subcube_lagrange_coeffs,
			claimed_sums,
			partial_eq_ind_evals,
		})
	}
}
/// Compute univariate skip round evaluations for zerocheck.
///
/// When all witness multilinear hypercube evaluations can be embedded into a small field
/// `PBase::Scalar` that is significantly smaller than `F`, we naturally want to refrain from
/// folding for `skip_rounds` (denoted below as $k$) rounds to reap the benefits of faster small
/// field multiplications. Naive extensions to the sumcheck protocol which compute multivariate
/// round polynomials do not work though, given that for a composition of degree $d$ one would need
/// $(d+1)^k-2^k$ evaluations (assuming [Gruen24] section 3.2 optimizations), which usually grows
/// faster than $2^k$ and thus will typically require more work than large field sumcheck. We adopt
/// a univariatizing approach instead, where we define "oblong" multivariates:
/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum_u M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot
/// L_u(\hat{u}_1)$$ with $u$ ranging over the first $2^k$ univariate domain points,
/// $\mathbb{M}: u \rightarrow (u_1, \ldots, u_k)$ being some map from the univariate domain to
/// the $\mathcal{B}_k$ hypercube and $L_u(\hat{u})$ being Lagrange polynomials.
///
/// The main idea of the univariatizing approach is that $\hat{M}$ are of degree $2^k-1$ in
/// $\hat{u}_1$ and multilinear in other variables, thus evaluating a composition of degree $d$ over
/// $\hat{M}$ yields a total degree of $d(2^k-1)$ in the first round (again, assuming the [Gruen24]
/// section 3.2 trick to avoid multiplication by the equality indicator), which is comparable to
/// what a regular non-skipping zerocheck prover would do. The only issue is that we normally don't
/// have an oracle for $\hat{M}$, which necessitates an extra sumcheck reduction to multilinear
/// claims
/// (see [univariatizing_reduction_claim](`super::super::zerocheck::univariatizing_reduction_claim`)).
///
/// One special trick of the univariate round is that the round polynomial is represented in
/// Lagrange form:
///  1. An honest prover evaluates to zero on the $2^k$ domain points mapping to $\mathcal{B}_k$,
///     reducing proof size
///  2. Avoiding monomial conversion saves prover time by skipping $O(N^3)$ inverse Vandermonde
///     precomp
///  3. Evaluation in the verifier can be made linear time when barycentric weights are precomputed
///
/// This implementation defines $\mathbb{M}$ to be the basis-induced mapping of the binary field
/// `FDomain`; the main reason for that is to be able to use the additive NTT from [LCH14] for
/// extrapolation. The choice of domain field impacts performance, thus generally the smallest field
/// with cardinality not less than the degree of the round polynomial should be used.
///
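/// For intuition, a worked size example (illustrative numbers only): with $k = 3$ skipped
/// rounds and a composition of degree $d = 2$, each $\hat{M}$ has degree $2^k - 1 = 7$ in
/// $\hat{u}_1$, and the composite is a univariate of degree $d(2^k - 1) = 14$, uniquely
/// determined by its values on a $d \cdot 2^k = 16$ point domain; since an honest prover's
/// composite vanishes on the first $2^k = 8$ points, only $(d - 1) \cdot 2^k = 8$ evaluations
/// per composition need to be computed.
///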
/// [LCH14]: <https://arxiv.org/abs/1404.3458>
/// [Gruen24]: <https://eprint.iacr.org/2024/108>
pub fn zerocheck_univariate_evals<F, FDomain, FBase, P, Composition, M, Backend>(
	multilinears: &[M],
	compositions: &[Composition],
	zerocheck_challenges: &[F],
	skip_rounds: usize,
	max_domain_size: usize,
	backend: &Backend,
) -> Result<ZerocheckUnivariateEvalsOutput<F, P, Backend>, Error>
where
	FDomain: TowerField,
	FBase: ExtensionField<FDomain>,
	F: TowerField,
	P: PackedField<Scalar = F> + PackedExtension<FBase> + PackedExtension<FDomain>,
	Composition: CompositionPoly<PackedSubfield<P, FBase>>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	let n_vars = equal_n_vars_check(multilinears)?;
	let n_multilinears = multilinears.len();

	if skip_rounds > n_vars {
		bail!(Error::TooManySkippedRounds);
	}

	let remaining_rounds = n_vars - skip_rounds;
	if zerocheck_challenges.len() != remaining_rounds {
		bail!(Error::IncorrectZerocheckChallengesLength);
	}

	small_field_embedding_degree_check::<_, FBase, P, _>(multilinears)?;

	let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;
	let composition_degrees = compositions.iter().map(|composition| composition.degree());
	let composition_max_degree = composition_degrees.clone().max().unwrap_or(0);

	if max_domain_size < domain_size(composition_max_degree, skip_rounds) {
		bail!(Error::LagrangeDomainTooSmall);
	}

	// Batching factors for strided NTTs.
	let log_extension_degree_base_domain = <FBase as ExtensionField<FDomain>>::LOG_DEGREE;

	// Check that domain field contains the required NTT cosets.
	let min_domain_bits = log2_ceil_usize(max_domain_size);
	if min_domain_bits > FDomain::N_BITS {
		bail!(MathError::DomainSizeTooLarge);
	}

	// Only a domain-sized NTT is needed.
	let fdomain_ntt = SingleThreadedNTT::<FDomain>::with_canonical_field(min_domain_bits)
		.expect("FDomain cardinality checked before")
		.precompute_twiddles();

	// Smaller subcubes are batched together to reduce interpolation/evaluation overhead.
	// REVIEW: make this a heuristic dependent on base field size and/or number of multilinears
	//         to guarantee L1 cache (or accelerator scratchpad) non-eviction.
	const MAX_SUBCUBE_VARS: usize = 12;
	let log_batch = MAX_SUBCUBE_VARS.min(n_vars).saturating_sub(skip_rounds);
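	// For intuition (illustrative numbers only): with n_vars = 20 and skip_rounds = 5, each
	// parallel work item spans subcube_vars = 12 variables, of which log_batch = 7 enumerate
	// adjacent skipped subcubes that share a single strided NTT invocation.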

	// Expand the multilinear query in all but the first `skip_rounds` variables,
	// where each tensor expansion element serves as a constant factor of the whole
	// univariatized subcube.
	// NB: expansion of the first `skip_rounds` variables is applied to the round evals sum
	let dimensions_data = ExpandQueryData::new(zerocheck_challenges);
	let expand_span = tracing::debug_span!(
		"[task] Expand Query",
		phase = "zerocheck",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();
	let partial_eq_ind_evals: <Backend as ComputationBackend>::Vec<P> =
		backend.tensor_product_full_query(zerocheck_challenges)?;
	drop(expand_span);

	// Evaluate each composition on a minimal packed prefix corresponding to the degree
	let pbase_prefix_lens = composition_degrees
		.clone()
		.map(|composition_degree| {
			extrapolated_evals_packed_len::<PackedSubfield<P, FBase>>(
				composition_degree,
				skip_rounds,
				log_batch,
			)
		})
		.collect::<Vec<_>>();
	let dimensions_data =
		UnivariateSkipCalculateCoeffsData::new(n_vars, skip_rounds, n_multilinears, log_batch);
	let coeffs_span = tracing::debug_span!(
		"[task] Univariate Skip Calculate coeffs",
		phase = "zerocheck",
		perfetto_category = "task.main",
		?dimensions_data,
	)
	.entered();

	let subcube_vars = log_batch + skip_rounds;
	let log_subcube_count = n_vars - subcube_vars;

	let p_coset_round_evals_len = 1 << skip_rounds.saturating_sub(P::LOG_WIDTH);
	let pbase_coset_composition_evals_len =
		1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree);

	// NB: we avoid evaluation on the first 2^skip_rounds points because an honest
	// prover would always evaluate to zero there; we also factor out the first
	// skip_rounds terms of the equality indicator and apply them pointwise to
	// the final round evaluations, which equates to lowering the composition_degree
	// by one (this is an extension of the Gruen section 3.2 trick)
	let staggered_round_evals = (0..1 << log_subcube_count)
		.into_par_iter()
		.try_fold(
			|| {
				ParFoldStates::<FBase, P>::new(
					n_multilinears,
					skip_rounds,
					log_batch,
					log_embedding_degree,
					composition_degrees.clone(),
				)
			},
			|mut par_fold_states, subcube_index| -> Result<_, Error> {
				let ParFoldStates {
					evals,
					extrapolated_evals,
					composition_evals,
					packed_round_evals,
				} = &mut par_fold_states;

				// Interpolate multilinear evals for each multilinear
				for (multilinear, extrapolated_evals) in
					izip!(multilinears, extrapolated_evals.iter_mut())
				{
					// Sample evals subcube from a multilinear poly
					multilinear.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals,
					)?;

					// Extrapolate evals using a conservative upper bound of the composition
					// degree. We use Additive NTT to extrapolate evals beyond the first
					// 2^skip_rounds, exploiting the fact that extension field NTT is a strided
					// base field NTT.
					let evals_base = <P as PackedExtension<FBase>>::cast_bases_mut(evals);
					let evals_domain = recast_packed_mut::<P, FBase, FDomain>(evals_base);
					let extrapolated_evals_domain =
						recast_packed_mut::<P, FBase, FDomain>(extrapolated_evals);

					ntt_extrapolate(
						&fdomain_ntt,
						skip_rounds,
						log_extension_degree_base_domain,
						log_batch,
						evals_domain,
						extrapolated_evals_domain,
					)?
				}

				// Evaluate the compositions and accumulate round results
				for (composition, packed_round_evals, &pbase_prefix_len) in
					izip!(compositions, packed_round_evals, &pbase_prefix_lens)
				{
					let extrapolated_evals_iter = extrapolated_evals
						.iter()
						.map(|evals| &evals[..pbase_prefix_len]);

					stackalloc_with_iter(n_multilinears, extrapolated_evals_iter, |batch_query| {
						let batch_query = RowsBatchRef::new(batch_query, pbase_prefix_len);

						// Evaluate the small field composition
						composition.batch_evaluate(
							&batch_query,
							&mut composition_evals[..pbase_prefix_len],
						)
					})?;

					// Accumulate round evals and multiply by the constant part of the
					// zerocheck equality indicator
					for (packed_round_evals_coset, composition_evals_coset) in izip!(
						packed_round_evals.chunks_exact_mut(p_coset_round_evals_len),
						composition_evals.chunks_exact(pbase_coset_composition_evals_len)
					) {
						// At this point, the composition evals are laid out as a 3D array,
						// with dimensions being (ordered by increasing stride):
						//  1) 2^skip_rounds           - a low indexed subcube being "skipped"
						//  2) 2^log_batch             - batch of adjacent subcubes
						//  3) composition_degree - 1  - cosets of the subcube evaluation domain
						// NB: each complete span of dim 1 gets multiplied by a constant from
						// the equality indicator expansion, and dims 1+2 are padded up to the
						// nearest packed field due to ntt_extrapolate implementation details
						// (not performing sub-packed-field NTTs). This helper method handles
						// multiplication of each dim 1 + 2 submatrix by the corresponding
						// equality indicator subslice.
						spread_product::<_, FBase>(
							packed_round_evals_coset,
							composition_evals_coset,
							&partial_eq_ind_evals,
							subcube_index,
							skip_rounds,
							log_batch,
						);
					}
				}

				Ok(par_fold_states)
			},
		)
		.map(|states| -> Result<_, Error> {
			let scalar_round_evals = izip!(composition_degrees.clone(), states?.packed_round_evals)
				.map(|(composition_degree, packed_round_evals)| {
					let mut composition_round_evals = Vec::with_capacity(
						extrapolated_scalars_count(composition_degree, skip_rounds),
					);

					for packed_round_evals_coset in
						packed_round_evals.chunks_exact(p_coset_round_evals_len)
					{
						let coset_scalars = packed_round_evals_coset
							.iter()
							.flat_map(|packed| packed.iter())
							.take(1 << skip_rounds);

						composition_round_evals.extend(coset_scalars);
					}

					composition_round_evals
				})
				.collect::<Vec<_>>();

			Ok(scalar_round_evals)
		})
		.try_reduce(
			|| {
				composition_degrees
					.clone()
					.map(|composition_degree| {
						zeroed_vec(extrapolated_scalars_count(composition_degree, skip_rounds))
					})
					.collect()
			},
			|lhs, rhs| -> Result<_, Error> {
				let round_evals_sum = izip!(lhs, rhs)
					.map(|(mut lhs_vals, rhs_vals)| {
						debug_assert_eq!(lhs_vals.len(), rhs_vals.len());
						for (lhs_val, rhs_val) in izip!(&mut lhs_vals, rhs_vals) {
							*lhs_val += rhs_val;
						}
						lhs_vals
					})
					.collect();

				Ok(round_evals_sum)
			},
		)?;

	// So far the evals of each composition are "staggered", in the sense that they are evaluated
	// on the smallest domain which guarantees uniqueness of the round polynomial. We extrapolate
	// them to max_domain_size to aid in the Gruen section 3.2 optimization below and in batch
	// mixing.
	let round_evals =
		extrapolate_round_evals(&fdomain_ntt, staggered_round_evals, skip_rounds, max_domain_size)?;
	drop(coeffs_span);

	Ok(ZerocheckUnivariateEvalsOutput {
		round_evals,
		skip_rounds,
		remaining_rounds,
		max_domain_size,
		partial_eq_ind_evals,
	})
}

// A helper to perform spread multiplication of small field composition evals by appropriate
// equality indicator scalars. See `zerocheck_univariate_evals` impl for intuition.
fn spread_product<P, FBase>(
	accum: &mut [P],
	small: &[PackedSubfield<P, FBase>],
	large: &[P],
	subcube_index: usize,
	log_n: usize,
	log_batch: usize,
) where
	P: PackedExtension<FBase>,
	FBase: Field,
{
	let log_embedding_degree = <P::Scalar as ExtensionField<FBase>>::LOG_DEGREE;
	let pbase_log_width = P::LOG_WIDTH + log_embedding_degree;

	debug_assert_eq!(accum.len(), 1 << log_n.saturating_sub(P::LOG_WIDTH));
	debug_assert_eq!(small.len(), 1 << (log_n + log_batch).saturating_sub(pbase_log_width));

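	// Worked shape example (illustrative numbers only): with log_n = 4, log_batch = 1,
	// P::LOG_WIDTH = 2 and log_embedding_degree = 2, `accum` holds 1 << (4 - 2) = 4 packed
	// extension elements (one coset of 2^log_n scalars), while `small` holds
	// 1 << (4 + 1 - 4) = 2 packed base elements, i.e. all 2^(log_n + log_batch) = 32
	// small-field evals of the batch.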
	if log_n >= P::LOG_WIDTH {
		// Use spread multiplication on fast path.
		let mask = (1 << log_embedding_degree) - 1;
		for batch_idx in 0..1 << log_batch {
			let mult = get_packed_slice(large, subcube_index << log_batch | batch_idx);
			let spread_large = P::cast_base(P::broadcast(mult));

			for (block_idx, dest) in accum.iter_mut().enumerate() {
				let block_offset = block_idx | batch_idx << (log_n - P::LOG_WIDTH);
				let spread_small = small[block_offset >> log_embedding_degree]
					.spread(P::LOG_WIDTH, block_offset & mask);
				*dest += P::cast_ext(spread_large * spread_small);
			}
		}
	} else {
		// Multiple skipped subcube evaluations fit into a single packed field. This never
		// occurs with large traces under frontloaded univariate skip batching, making this
		// a non-critical slow path.
		for (outer_idx, dest) in accum.iter_mut().enumerate() {
			*dest = P::from_fn(|inner_idx| {
				if inner_idx >= 1 << log_n {
					return P::Scalar::ZERO;
				}
				(0..1 << log_batch)
					.map(|batch_idx| {
						let large = get_packed_slice(large, subcube_index << log_batch | batch_idx);
						let small = get_packed_slice_checked(
							small,
							batch_idx << log_n | outer_idx << P::LOG_WIDTH | inner_idx,
						)
						.unwrap_or_default();
						large * small
					})
					.sum()
			})
		}
	}
}

// Extrapolate round evaluations to the full domain.
// NB: this method relies on the fact that `round_evals` have specific lengths
// (namely `d * 2^n`, where `n` is not less than the number of skipped rounds and thus d
// is not larger than the composition degree), which enables additive-NTT based subquadratic
// techniques.
#[instrument(skip_all, level = "debug")]
fn extrapolate_round_evals<F, FDomain, TA>(
	ntt: &SingleThreadedNTT<FDomain, TA>,
	mut round_evals: Vec<Vec<F>>,
	skip_rounds: usize,
	max_domain_size: usize,
) -> Result<Vec<Vec<F>>, Error>
where
	F: BinaryField + ExtensionField<FDomain>,
	FDomain: BinaryField,
	TA: TwiddleAccess<FDomain>,
{
	// Instantiate a large enough NTT over F to be able to forward transform to full domain size.
	// TODO: We can't currently use the `ntt` directly because it is over FDomain. It'd be nice to
	// have helpers that apply a subfield NTT to an extension field vector, without the
	// PackedExtension relation that `forward_transform_ext` and `inverse_transform_ext` require.
	let subspace_upcast = BinarySubspace::new_unchecked(
		ntt.subspace(ntt.log_domain_size())
			.basis()
			.iter()
			.copied()
			.map(F::from)
			.collect(),
	);
	let ntt = SingleThreadedNTT::with_subspace(&subspace_upcast)
		.expect("ntt provided is valid; subspace is equivalent but upcast to F");

	// Cache OddInterpolate instances, which, albeit small in practice, take cubic time to create.
	let mut odd_interpolates = HashMap::new();

	for round_evals in &mut round_evals {
		// Re-add zero evaluations at the beginning.
		round_evals.splice(0..0, repeat_n(F::ZERO, 1 << skip_rounds));

		let n = round_evals.len();

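		// For intuition (illustrative numbers only): a degree-3 composition with skip_rounds = 4
		// arrives with 2 * 2^4 = 32 staggered evals; after re-adding the zeros, n = 3 * 2^4 = 48,
		// so ell = 4 and the odd interpolation strip width is n >> ell = 3.
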
		// Get OddInterpolate instance of required size.
		let odd_interpolate = odd_interpolates.entry(n).or_insert_with(|| {
			let ell = n.trailing_zeros() as usize;
			assert!(ell >= skip_rounds);

			let coset_bits = ntt.log_domain_size() - ell;
			OddInterpolate::new(&ntt, n >> ell, ell, coset_bits)
				.expect("domain large enough by construction")
		});

		// Obtain novel polynomial basis representation of round evaluations.
		odd_interpolate.inverse_transform(round_evals)?;

		// Use forward NTT to extrapolate novel representation to the max domain size.
		let next_log_n = ntt.log_domain_size();
		round_evals.resize(1 << next_log_n, F::ZERO);

		let shape = NTTShape {
			log_y: next_log_n,
			..Default::default()
		};
		ntt.forward_transform(round_evals, shape, 0, 0, 0)?;

		// Sanity check: first 1 << skip_rounds evals are still zeros.
		debug_assert!(
			round_evals[..1 << skip_rounds]
				.iter()
				.all(|&coeff| coeff == F::ZERO)
		);

		// Trim the result.
		round_evals.resize(max_domain_size, F::ZERO);
		round_evals.drain(..1 << skip_rounds);
	}

	Ok(round_evals)
}

fn ntt_extrapolate<NTT, P>(
	ntt: &NTT,
	skip_rounds: usize,
	log_stride_batch: usize,
	log_batch: usize,
	evals: &mut [P],
	extrapolated_evals: &mut [P],
) -> Result<(), Error>
where
	P: PackedField<Scalar: BinaryField>,
	NTT: AdditiveNTT<P::Scalar>,
{
	let shape = NTTShape {
		log_x: log_stride_batch,
		log_y: skip_rounds,
		log_z: log_batch,
	};

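	// For intuition, the coset scheme below (a sketch of the indexing, not additional logic):
	// the 2^log_domain_size point domain splits into 2^coset_bits cosets of 2^skip_rounds
	// points each; the inverse transform treats `evals` as values on coset 0, and each forward
	// transform evaluates the resulting novel-basis polynomial on cosets 1, 2, ....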
	let coset_bits = ntt.log_domain_size() - skip_rounds;

	// Inverse NTT: convert evals to novel basis representation
	ntt.inverse_transform(evals, shape, 0, coset_bits, 0)?;

	// Forward NTT: evaluate novel basis representation at consecutive cosets
	for (coset, extrapolated_chunk) in izip!(1.., extrapolated_evals.chunks_exact_mut(evals.len()))
	{
		// REVIEW: can avoid that copy (and extrapolated_evals scratchpad) when
		// composition_max_degree == 2
		extrapolated_chunk.copy_from_slice(evals);
		ntt.forward_transform(extrapolated_chunk, shape, coset, coset_bits, 0)?;
	}

	Ok(())
}

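// Number of packed base-field elements needed to hold the extrapolated evals of a single
// composition. Worked example (illustrative numbers only): with composition_degree = 3,
// skip_rounds = 4, log_batch = 2 and P::LOG_WIDTH = 3, this is (3 - 1) << (4 + 2 - 3) = 16
// packed elements, i.e. (d - 1) * 2^(skip_rounds + log_batch) = 128 extrapolated scalars.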
const fn extrapolated_evals_packed_len<P: PackedField>(
	composition_degree: usize,
	skip_rounds: usize,
	log_batch: usize,
) -> usize {
	composition_degree.saturating_sub(1) << (skip_rounds + log_batch).saturating_sub(P::LOG_WIDTH)
}

#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		BinaryField1b, BinaryField8b, BinaryField16b, BinaryField128b, ExtensionField, Field,
		PackedBinaryField4x32b, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::UnderlierType,
	};
	use binius_hal::make_portable_backend;
	use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearPoly};
	use binius_ntt::SingleThreadedNTT;
	use rand::{SeedableRng, prelude::StdRng};

	use crate::{
		composition::{IndexComposition, ProductComposition},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::prove::univariate::{domain_size, zerocheck_univariate_evals},
			test_utils::generate_zero_product_multilinears,
		},
		transparent::eq_ind::EqIndPartialEval,
	};

	#[test]
	fn ntt_extrapolate_correctness() {
		type P = PackedBinaryField4x32b;
		type FDomain = BinaryField16b;
		let log_extension_degree_p_domain = 1;

		let mut rng = StdRng::seed_from_u64(0);
		let ntt = SingleThreadedNTT::<FDomain>::new(10).unwrap();
		let subspace = BinarySubspace::<FDomain>::with_dim(10).unwrap();
		let max_domain =
			EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();

		for skip_rounds in 0..5usize {
			let subsubspace = subspace.reduce_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subsubspace.iter().collect::<Vec<_>>(), false)
					.unwrap();
			for log_batch in 0..3usize {
				for composition_degree in 0..5usize {
					let subcube_vars = skip_rounds + log_batch;
					let interleaved_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
					let interleaved_evals = (0..interleaved_len)
						.map(|_| P::random(&mut rng))
						.collect::<Vec<_>>();

					let extrapolated_scalars_cnt =
						composition_degree.saturating_sub(1) << skip_rounds;
					let extrapolated_ntts = composition_degree.saturating_sub(1);
					let extrapolated_len = extrapolated_ntts * interleaved_len;
					let mut extrapolated_evals = vec![P::zero(); extrapolated_len];

					let mut interleaved_evals_scratch = interleaved_evals.clone();

					let interleaved_evals_domain =
						P::cast_bases_mut(&mut interleaved_evals_scratch);
					let extrapolated_evals_domain = P::cast_bases_mut(&mut extrapolated_evals);

					super::ntt_extrapolate(
						&ntt,
						skip_rounds,
						log_extension_degree_p_domain,
						log_batch,
						interleaved_evals_domain,
						extrapolated_evals_domain,
					)
					.unwrap();

					let interleaved_scalars =
						&P::unpack_scalars(&interleaved_evals)[..1 << subcube_vars];
					let extrapolated_scalars = &P::unpack_scalars(&extrapolated_evals)
						[..extrapolated_scalars_cnt << log_batch];

					for batch_idx in 0..1 << log_batch {
						let values =
							&interleaved_scalars[batch_idx << skip_rounds..][..1 << skip_rounds];

						for (i, &point) in max_domain.finite_points()[1 << skip_rounds..]
							[..extrapolated_scalars_cnt]
							.iter()
							.take(1 << skip_rounds)
							.enumerate()
						{
							let extrapolated = domain.extrapolate(values, point.into()).unwrap();
							let expected = extrapolated_scalars[batch_idx << skip_rounds | i];
							assert_eq!(extrapolated, expected);
						}
					}
				}
			}
		}
	}

	#[test]
	fn zerocheck_univariate_evals_invariants_basic() {
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	#[test]
	fn zerocheck_univariate_evals_with_nontrivial_packing() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field will have
		// a non-trivial packing width of 4.
		zerocheck_univariate_evals_invariants_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
			BinaryField16b,
		>()
	}

	fn zerocheck_univariate_evals_invariants_helper<U, F, FDomain, FBase>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<BinaryField1b>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField + From<u8>,
		PackedType<U, FBase>: PackedFieldIndexable,
		PackedType<U, FDomain>: PackedFieldIndexable,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 7;
		let log_embedding_degree = <F as ExtensionField<FBase>>::LOG_DEGREE;

		let mut multilinears = generate_zero_product_multilinears::<
			PackedType<U, BinaryField1b>,
			PackedType<U, F>,
		>(&mut rng, n_vars, 2);
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
		multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

		let compositions = [
			Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap())
				as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
		];

		let backend = make_portable_backend();
		let zerocheck_challenges = (0..n_vars)
			.map(|_| <F as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		for skip_rounds in 0usize..=5 {
			let max_domain_size = domain_size(5, skip_rounds);
			let output =
				zerocheck_univariate_evals::<F, FDomain, FBase, PackedType<U, F>, _, _, _>(
					&multilinears,
					&compositions,
					&zerocheck_challenges[skip_rounds..],
					skip_rounds,
					max_domain_size,
					&backend,
				)
				.unwrap();

			let zerocheck_eq_ind = EqIndPartialEval::new(&zerocheck_challenges[skip_rounds..])
				.multilinear_extension::<F, _>(&backend)
				.unwrap();

			// naive computation of the univariate skip output
			let round_evals_len = 4usize << skip_rounds;
			assert!(
				output
					.round_evals
					.iter()
					.all(|round_evals| round_evals.len() == round_evals_len)
			);

			let compositions = compositions
				.iter()
				.cloned()
				.map(CompositionScalarAdapter::new)
				.collect::<Vec<_>>();

			let mut query = [FBase::ZERO; 9];
			let mut evals = vec![
				PackedType::<U, F>::zero();
				1 << skip_rounds.saturating_sub(
					log_embedding_degree + PackedType::<U, F>::LOG_WIDTH
				)
			];
			let subspace = BinarySubspace::<FDomain>::with_dim(skip_rounds).unwrap();
			let domain =
				EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false).unwrap();
			for round_evals_index in 0..round_evals_len {
				let x = FDomain::from(((1 << skip_rounds) + round_evals_index) as u8);
				let mut composition_sums = vec![F::ZERO; compositions.len()];
				for subcube_index in 0..1 << (n_vars - skip_rounds) {
					for (query, multilinear) in query.iter_mut().zip(&multilinears) {
						multilinear
							.subcube_evals(
								skip_rounds,
								subcube_index,
								log_embedding_degree,
								&mut evals,
							)
							.unwrap();
						let evals_scalars = &PackedType::<U, FBase>::unpack_scalars(
							PackedExtension::<FBase>::cast_bases(&evals),
						)[..1 << skip_rounds];
						let extrapolated = domain.extrapolate(evals_scalars, x.into()).unwrap();
						*query = extrapolated;
					}

					let eq_ind_factor = zerocheck_eq_ind
						.evaluate_on_hypercube(subcube_index)
						.unwrap();
					for (composition, sum) in compositions.iter().zip(composition_sums.iter_mut()) {
						*sum += eq_ind_factor * composition.evaluate(&query).unwrap();
					}
				}

				let univariate_skip_composition_sums = output
					.round_evals
					.iter()
					.map(|round_evals| round_evals[round_evals_index])
					.collect::<Vec<_>>();
				assert_eq!(univariate_skip_composition_sums, composition_sums);
			}
		}
	}
}
926}