binius_core/protocols/sumcheck/prove/zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{marker::PhantomData, ops::Range, sync::Arc};
4
5use binius_field::{
6	util::{eq, powers},
7	ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield,
8	TowerField,
9};
10use binius_hal::{ComputationBackend, SumcheckEvaluator};
11use binius_math::{
12	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
13	MLEDirectAdapter, MultilinearPoly, MultilinearQuery,
14};
15use binius_maybe_rayon::prelude::*;
16use binius_utils::bail;
17use bytemuck::zeroed_vec;
18use getset::Getters;
19use itertools::izip;
20use stackalloc::stackalloc_with_default;
21use tracing::instrument;
22
23use crate::{
24	polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
25	protocols::sumcheck::{
26		common::{determine_switchovers, equal_n_vars_check, get_nontrivial_evaluation_points},
27		prove::{
28			common::fold_partial_eq_ind,
29			univariate::{
30				zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
31				ZerocheckUnivariateFoldResult,
32			},
33			ProverState, SumcheckInterpolator, SumcheckProver, UnivariateZerocheckProver,
34		},
35		univariate::LagrangeRoundEvals,
36		univariate_zerocheck::domain_size,
37		Error, RoundCoeffs,
38	},
39	witness::MultilinearWitness,
40};
41
42pub fn validate_witness<'a, F, P, M, Composition>(
43	multilinears: &[M],
44	zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
45) -> Result<(), Error>
46where
47	F: Field,
48	P: PackedField<Scalar = F>,
49	M: MultilinearPoly<P> + Send + Sync,
50	Composition: CompositionPoly<P> + 'a,
51{
52	let n_vars = multilinears
53		.first()
54		.map(|multilinear| multilinear.n_vars())
55		.unwrap_or_default();
56	for multilinear in multilinears {
57		if multilinear.n_vars() != n_vars {
58			bail!(Error::NumberOfVariablesMismatch);
59		}
60	}
61
62	let multilinears = multilinears.iter().collect::<Vec<_>>();
63
64	for (name, composition) in zero_claims {
65		let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
66		(0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
67			if witness.evaluate_on_hypercube(j)? != F::ZERO {
68				return Err(Error::ZerocheckNaiveValidationFailure {
69					composition_name: name.to_string(),
70					vertex_index: j,
71				});
72			}
73			Ok(())
74		})?;
75	}
76	Ok(())
77}
78
79/// A prover that is capable of performing univariate skip.
80///
81/// By recasting `skip_rounds` first variables in a multilinear sumcheck into a univariate domain,
82/// it becomes possible to compute all of these rounds in small fields, unlocking significant
83/// performance gains. See [`zerocheck_univariate_evals`] rustdoc for a more detailed explanation.
84///
85/// This struct is an entrypoint to proving all zerochecks instances, univariatized and regular.
86/// "Regular" multilinear case is covered by calling [`Self::into_regular_zerocheck`] right away,
87/// producing a [`ZerocheckProver`]. Univariatized case is handled by using methods from a
88/// [`UnivariateZerocheckProver`] trait, where folding results in a reduced multilinear zerocheck
89/// prover for the remaining rounds.
90#[derive(Debug, Getters)]
91pub struct UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
92where
93	FDomain: Field,
94	FBase: Field,
95	P: PackedField,
96	Backend: ComputationBackend,
97{
98	n_vars: usize,
99	#[getset(get = "pub")]
100	multilinears: Vec<M>,
101	switchover_rounds: Vec<usize>,
102	compositions: Vec<(String, CompositionBase, Composition)>,
103	zerocheck_challenges: Vec<P::Scalar>,
104	domains: Vec<InterpolationDomain<FDomain>>,
105	backend: &'a Backend,
106	univariate_evals_output: Option<ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>>,
107	_p_base_marker: PhantomData<FBase>,
108	_m_marker: PhantomData<&'m ()>,
109}
110
111impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
112	UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
113where
114	F: Field,
115	FDomain: Field,
116	FBase: ExtensionField<FDomain>,
117	P: PackedFieldIndexable<Scalar = F>
118		+ PackedExtension<F, PackedSubfield = P>
119		+ PackedExtension<FBase>
120		+ PackedExtension<FDomain>,
121	CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
122	Composition: CompositionPoly<P>,
123	M: MultilinearPoly<P> + Send + Sync + 'm,
124	Backend: ComputationBackend,
125{
126	pub fn new(
127		multilinears: Vec<M>,
128		zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
129		zerocheck_challenges: &[F],
130		evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
131		switchover_fn: impl Fn(usize) -> usize,
132		backend: &'a Backend,
133	) -> Result<Self, Error> {
134		let n_vars = equal_n_vars_check(&multilinears)?;
135
136		let compositions = zero_claims.into_iter().collect::<Vec<_>>();
137		for (_, composition_base, composition) in &compositions {
138			if composition_base.n_vars() != multilinears.len()
139				|| composition.n_vars() != multilinears.len()
140				|| composition_base.degree() != composition.degree()
141			{
142				bail!(Error::InvalidComposition {
143					actual: composition.n_vars(),
144					expected: multilinears.len(),
145				});
146			}
147		}
148		#[cfg(feature = "debug_validate_sumcheck")]
149		{
150			let compositions = compositions
151				.iter()
152				.map(|(name, _, a)| (name.clone(), a))
153				.collect::<Vec<_>>();
154			validate_witness(&multilinears, &compositions)?;
155		}
156
157		let switchover_rounds = determine_switchovers(&multilinears, switchover_fn);
158		let zerocheck_challenges = zerocheck_challenges.to_vec();
159
160		let domains = compositions
161			.iter()
162			.map(|(_, _, composition)| {
163				let degree = composition.degree();
164				let domain =
165					evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
166				Ok(domain.into())
167			})
168			.collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
169			.map_err(Error::MathError)?;
170
171		Ok(Self {
172			n_vars,
173			multilinears,
174			switchover_rounds,
175			compositions,
176			zerocheck_challenges,
177			domains,
178			backend,
179			univariate_evals_output: None,
180			_p_base_marker: PhantomData,
181			_m_marker: PhantomData,
182		})
183	}
184
185	#[instrument(skip_all, level = "debug")]
186	#[allow(clippy::type_complexity)]
187	pub fn into_regular_zerocheck(
188		self,
189	) -> Result<
190		ZerocheckProver<'a, FDomain, P, Composition, MultilinearWitness<'m, P>, Backend>,
191		Error,
192	> {
193		if self.univariate_evals_output.is_some() {
194			bail!(Error::ExpectedFold);
195		}
196
197		// Type erase the multilinears
198		// REVIEW: this may result in "double boxing" if M is already a trait object;
199		//         consider implementing MultilinearPoly on an Either, or
200		//         supporting two different SumcheckProver<F> types in batch_prove
201		let multilinears = self
202			.multilinears
203			.into_iter()
204			.map(|multilinear| Arc::new(multilinear) as MultilinearWitness<'_, P>)
205			.collect::<Vec<_>>();
206
207		#[cfg(feature = "debug_validate_sumcheck")]
208		{
209			let compositions = self
210				.compositions
211				.iter()
212				.map(|(name, _, a)| (name.clone(), a))
213				.collect::<Vec<_>>();
214			validate_witness(&multilinears, &compositions)?;
215		}
216
217		let compositions = self
218			.compositions
219			.into_iter()
220			.map(|(_, _, composition)| composition)
221			.collect::<Vec<_>>();
222
223		// Evaluate zerocheck partial indicator in variables 1..n_vars
224		let start = self.n_vars.min(1);
225		let partial_eq_ind_evals = self
226			.backend
227			.tensor_product_full_query(&self.zerocheck_challenges[start..])?;
228		let claimed_sums = vec![F::ZERO; compositions.len()];
229
230		// This is a regular multilinear zerocheck constructor, split over two creation stages.
231		ZerocheckProver::new(
232			EvaluationOrder::LowToHigh,
233			multilinears,
234			&self.switchover_rounds,
235			compositions,
236			partial_eq_ind_evals,
237			self.zerocheck_challenges,
238			claimed_sums,
239			self.domains,
240			RegularFirstRound::SkipCube,
241			self.backend,
242		)
243	}
244}
245
246impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
247	UnivariateZerocheckProver<'a, F>
248	for UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
249where
250	F: TowerField,
251	FDomain: TowerField,
252	FBase: ExtensionField<FDomain>,
253	P: PackedFieldIndexable<Scalar = F>
254		+ PackedExtension<F, PackedSubfield = P>
255		+ PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
256		+ PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
257	CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
258	Composition: CompositionPoly<P> + 'static,
259	M: MultilinearPoly<P> + Send + Sync + 'm,
260	Backend: ComputationBackend,
261{
262	fn n_vars(&self) -> usize {
263		self.n_vars
264	}
265
266	fn domain_size(&self, skip_rounds: usize) -> usize {
267		self.compositions
268			.iter()
269			.map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
270			.max()
271			.unwrap_or(0)
272	}
273
274	#[instrument(skip_all, level = "debug")]
275	fn execute_univariate_round(
276		&mut self,
277		skip_rounds: usize,
278		max_domain_size: usize,
279		batch_coeff: F,
280	) -> Result<LagrangeRoundEvals<F>, Error> {
281		if self.univariate_evals_output.is_some() {
282			bail!(Error::ExpectedFold);
283		}
284
285		// Only use base compositions in the univariate round (it's the whole point)
286		let compositions_base = self
287			.compositions
288			.iter()
289			.map(|(_, composition_base, _)| composition_base)
290			.collect::<Vec<_>>();
291
292		// Output contains values that are needed for computations that happen after
293		// the round challenge has been sampled
294		let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
295			&self.multilinears,
296			&compositions_base,
297			&self.zerocheck_challenges,
298			skip_rounds,
299			max_domain_size,
300			self.backend,
301		)?;
302
303		// Batch together Lagrange round evals using powers of batch_coeff
304		let zeros_prefix_len = 1 << skip_rounds;
305		let batched_round_evals = univariate_evals_output
306			.round_evals
307			.iter()
308			.zip(powers(batch_coeff))
309			.map(|(evals, scalar)| {
310				let round_evals = LagrangeRoundEvals {
311					zeros_prefix_len,
312					evals: evals.clone(),
313				};
314				round_evals * scalar
315			})
316			.try_fold(
317				LagrangeRoundEvals::zeros(max_domain_size),
318				|mut accum, evals| -> Result<_, Error> {
319					accum.add_assign_lagrange(&evals)?;
320					Ok(accum)
321				},
322			)?;
323
324		self.univariate_evals_output = Some(univariate_evals_output);
325
326		Ok(batched_round_evals)
327	}
328
329	#[instrument(skip_all, level = "debug")]
330	fn fold_univariate_round(
331		self: Box<Self>,
332		challenge: F,
333	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
334		if self.univariate_evals_output.is_none() {
335			bail!(Error::ExpectedExecution);
336		}
337
338		// Once the challenge is known, values required for the instantiation of the
339		// multilinear prover for the remaining rounds become known.
340		let ZerocheckUnivariateFoldResult {
341			skip_rounds,
342			subcube_lagrange_coeffs,
343			claimed_prime_sums,
344			partial_eq_ind_evals,
345		} = self
346			.univariate_evals_output
347			.expect("validated to be Some")
348			.fold::<FDomain>(challenge)?;
349
350		// For each subcube of size 2**skip_rounds, we need to compute its
351		// inner product with Lagrange coefficients at challenge point in order
352		// to obtain the witness for the remaining multilinear rounds.
353		// REVIEW: Currently MultilinearPoly lacks a method to do that, so we
354		//         hack the needed functionality by overwriting the inner content
355		//         of a MultilinearQuery and performing an evaluate_partial_low,
356		//         which accidentally does what's needed. There should obviously
357		//         be a dedicated method for this someday.
358		let mut packed_subcube_lagrange_coeffs =
359			zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
360		P::unpack_scalars_mut(&mut packed_subcube_lagrange_coeffs)[..1 << skip_rounds]
361			.copy_from_slice(&subcube_lagrange_coeffs);
362		let lagrange_coeffs_query =
363			MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
364
365		let partial_low_multilinears = self
366			.multilinears
367			.into_par_iter()
368			.map(|multilinear| -> Result<_, Error> {
369				let multilinear =
370					multilinear.evaluate_partial_low(lagrange_coeffs_query.to_ref())?;
371				let mle_adapter = Arc::new(MLEDirectAdapter::from(multilinear));
372				Ok(mle_adapter as MultilinearWitness<'static, P>)
373			})
374			.collect::<Result<Vec<_>, _>>()?;
375
376		let switchover_rounds = self
377			.switchover_rounds
378			.into_iter()
379			.map(|switchover_round| switchover_round.saturating_sub(skip_rounds))
380			.collect::<Vec<_>>();
381
382		let zerocheck_challenges = self.zerocheck_challenges.clone();
383
384		let compositions = self
385			.compositions
386			.into_iter()
387			.map(|(_, _, composition)| composition)
388			.collect();
389
390		// This is also regular multilinear zerocheck constructor, but "jump started" in round
391		// `skip_rounds` while using witness with a projected univariate round.
392		// NB: first round evaluator has to be overridden due to issues proving
393		// `P: RepackedExtension<P>` relation in the generic context, as well as the need
394		// to use later round evaluator (as this _is_ a "later" round, albeit numbered at zero)
395		let regular_prover = ZerocheckProver::new(
396			EvaluationOrder::LowToHigh,
397			partial_low_multilinears,
398			&switchover_rounds,
399			compositions,
400			partial_eq_ind_evals,
401			zerocheck_challenges,
402			claimed_prime_sums,
403			self.domains,
404			RegularFirstRound::LaterRound,
405			self.backend,
406		)?;
407
408		Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
409	}
410}
411
/// Which evaluator a [`ZerocheckProver`] uses for its round number zero.
#[derive(Debug, Clone, Copy)]
enum RegularFirstRound {
	/// A genuine first round with a zero claimed sum. Since `r(0) = r(1) = 0`, only the
	/// points `2..=d` have to be evaluated.
	SkipCube,
	/// Round zero continues an earlier reduction (e.g. a folded univariate round), so the
	/// later-round evaluator is used from the start.
	LaterRound,
}
417
418/// A "regular" multilinear zerocheck prover.
419///
420/// The main difference of this prover from a regular sumcheck prover is that it computes
421/// round evaluations of a much simpler "prime" polynomial multiplied by a "higher" portion
422/// of the equality indicator. This "prime" polynomial has the same degree as the underlying
423/// composition, reducing the number of would-be evaluation points by one, and the tensor
424/// expansion of the zerocheck indicator doesn't have to be interpolated. Round evaluations
425/// for the "full" assumed zerocheck composition are computed in monomial form, out of hot loop.
426/// See [Gruen24] Section 3.2 for details.
427///
428/// When "jump starting" a zerocheck prover in a middle of zerocheck, pay attention that
429/// `claimed_prime_sums` are on "prime" polynomial, and not on full zerocheck polynomial.
430///
431/// [Gruen24]: <https://eprint.iacr.org/2024/108>
432#[derive(Debug)]
433pub struct ZerocheckProver<'a, FDomain, P, Composition, M, Backend>
434where
435	FDomain: Field,
436	P: PackedField,
437	M: MultilinearPoly<P> + Send + Sync,
438	Backend: ComputationBackend,
439{
440	n_vars: usize,
441	state: ProverState<'a, FDomain, P, M, Backend>,
442	eq_ind_eval: P::Scalar,
443	partial_eq_ind_evals: Backend::Vec<P>,
444	zerocheck_challenges: Vec<P::Scalar>,
445	compositions: Vec<Composition>,
446	domains: Vec<InterpolationDomain<FDomain>>,
447	first_round: RegularFirstRound,
448}
449
450impl<'a, F, FDomain, P, Composition, M, Backend>
451	ZerocheckProver<'a, FDomain, P, Composition, M, Backend>
452where
453	F: Field,
454	FDomain: Field,
455	P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
456	Composition: CompositionPoly<P>,
457	M: MultilinearPoly<P> + Send + Sync,
458	Backend: ComputationBackend,
459{
460	#[allow(clippy::too_many_arguments)]
461	fn new(
462		// REVIEW: given that high-to-low zerocheck may only be instantiated via
463		//         reduction from high-to-low univariate prover, actual implementation
464		//         of high-to-low zerocheck is deferred until the introduction of high-to-low
465		//         univariate skip.
466		evaluation_order: EvaluationOrder,
467		multilinears: Vec<M>,
468		switchover_rounds: &[usize],
469		compositions: Vec<Composition>,
470		partial_eq_ind_evals: Backend::Vec<P>,
471		zerocheck_challenges: Vec<F>,
472		claimed_prime_sums: Vec<F>,
473		domains: Vec<InterpolationDomain<FDomain>>,
474		first_round: RegularFirstRound,
475		backend: &'a Backend,
476	) -> Result<Self, Error> {
477		if claimed_prime_sums.len() != compositions.len() {
478			bail!(Error::IncorrectClaimedPrimeSumsLength);
479		}
480
481		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
482
483		let state = ProverState::new_with_switchover_rounds(
484			evaluation_order,
485			multilinears,
486			switchover_rounds,
487			claimed_prime_sums,
488			nontrivial_evaluation_points,
489			backend,
490		)?;
491		let n_vars = state.n_vars();
492
493		if zerocheck_challenges.len() != n_vars {
494			bail!(Error::IncorrectZerocheckChallengesLength);
495		}
496
497		// Only one value of the expanded zerocheck equality indicator is used per each
498		// 1-variable subcube, thus it should be twice smaller.
499		if partial_eq_ind_evals.len() != 1 << n_vars.saturating_sub(1 + P::LOG_WIDTH) {
500			bail!(Error::IncorrectZerocheckPartialEqIndSize);
501		}
502
503		let eq_ind_eval = F::ONE;
504
505		Ok(Self {
506			n_vars,
507			state,
508			eq_ind_eval,
509			partial_eq_ind_evals,
510			zerocheck_challenges,
511			compositions,
512			domains,
513			first_round,
514		})
515	}
516
517	fn round(&self) -> usize {
518		self.n_vars - self.n_rounds_remaining()
519	}
520
521	fn n_rounds_remaining(&self) -> usize {
522		self.state.n_vars()
523	}
524
525	fn update_eq_ind_eval(&mut self, challenge: F) {
526		// Update the running eq ind evaluation.
527		let alpha = self.zerocheck_challenges[self.round()];
528		self.eq_ind_eval *= eq(alpha, challenge);
529	}
530
531	#[instrument(skip_all, level = "debug")]
532	fn fold_partial_eq_ind(&mut self) {
533		fold_partial_eq_ind::<P, Backend>(
534			self.state.evaluation_order(),
535			self.n_rounds_remaining(),
536			&mut self.partial_eq_ind_evals,
537		);
538	}
539}
540
541impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
542	for ZerocheckProver<'_, FDomain, P, Composition, M, Backend>
543where
544	F: TowerField + ExtensionField<FDomain>,
545	FDomain: Field,
546	P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
547	Composition: CompositionPoly<P>,
548	M: MultilinearPoly<P> + Send + Sync,
549	Backend: ComputationBackend,
550{
551	fn n_vars(&self) -> usize {
552		self.n_vars
553	}
554
555	fn evaluation_order(&self) -> EvaluationOrder {
556		self.state.evaluation_order()
557	}
558
559	#[instrument(skip_all, name = "ZerocheckProver::fold", level = "debug")]
560	fn fold(&mut self, challenge: F) -> Result<(), Error> {
561		self.update_eq_ind_eval(challenge);
562		self.state.fold(challenge)?;
563
564		// This must happen after state fold, which decrements n_rounds_remaining.
565		self.fold_partial_eq_ind();
566
567		Ok(())
568	}
569
570	#[instrument(skip_all, name = "ZerocheckProver::execute", level = "debug")]
571	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
572		let round = self.round();
573		let skip_cube_first_round =
574			round == 0 && matches!(self.first_round, RegularFirstRound::SkipCube);
575		let coeffs = if skip_cube_first_round {
576			let evaluators = izip!(&self.compositions, &self.domains)
577				.map(|(composition, interpolation_domain)| {
578					let composition_at_infinity =
579						ArithCircuitPoly::new(composition.expression().leading_term());
580
581					ZerocheckFirstRoundEvaluator {
582						composition,
583						composition_at_infinity,
584						interpolation_domain,
585						partial_eq_ind_evals: &self.partial_eq_ind_evals,
586					}
587				})
588				.collect::<Vec<_>>();
589			let evals = self.state.calculate_round_evals(&evaluators)?;
590			self.state
591				.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?
592		} else {
593			let evaluators = izip!(&self.compositions, &self.domains)
594				.map(|(composition, interpolation_domain)| {
595					let composition_at_infinity =
596						ArithCircuitPoly::new(composition.expression().leading_term());
597
598					ZerocheckLaterRoundEvaluator {
599						composition,
600						composition_at_infinity,
601						interpolation_domain,
602						partial_eq_ind_evals: &self.partial_eq_ind_evals,
603						round_zerocheck_challenge: self.zerocheck_challenges[round],
604					}
605				})
606				.collect::<Vec<_>>();
607			let evals = self.state.calculate_round_evals(&evaluators)?;
608			self.state
609				.calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?
610		};
611
612		// Convert v' polynomial into v polynomial
613		let alpha = self.zerocheck_challenges[round];
614
615		// eq(X, α) = (1 − α) + (2 α − 1) X
616		// NB: In binary fields, this expression is simply  eq(X, α) = 1 + α + X
617		// However, we opt to keep this prover generic over all fields.
618		let constant_scalar = F::ONE - alpha;
619		let linear_scalar = alpha.double() - F::ONE;
620
621		let coeffs_scaled_by_constant_term = coeffs.clone() * constant_scalar;
622		let mut coeffs_scaled_by_linear_term = coeffs * linear_scalar;
623		coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); // Multiply polynomial by X
624
625		let sumcheck_coeffs = coeffs_scaled_by_constant_term + &coeffs_scaled_by_linear_term;
626		Ok(sumcheck_coeffs * self.eq_ind_eval)
627	}
628
629	#[instrument(skip_all, name = "ZerocheckProver::finish", level = "debug")]
630	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
631		let mut evals = self.state.finish()?;
632		evals.push(self.eq_ind_eval);
633		Ok(evals)
634	}
635}
636
637struct ZerocheckFirstRoundEvaluator<'a, P, FDomain, Composition>
638where
639	P: PackedField,
640	FDomain: Field,
641{
642	composition: &'a Composition,
643	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
644	interpolation_domain: &'a InterpolationDomain<FDomain>,
645	partial_eq_ind_evals: &'a [P],
646}
647
648impl<P, FDomain, Composition> SumcheckEvaluator<P, Composition>
649	for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition>
650where
651	P: PackedField<Scalar: TowerField + ExtensionField<FDomain>>,
652	FDomain: Field,
653	Composition: CompositionPoly<P>,
654{
655	fn eval_point_indices(&self) -> Range<usize> {
656		// In the first round of zerocheck we can uniquely determine the degree d
657		// univariate round polynomial $R(X)$ with evaluations at X = 2, ..., d
658		// because we know r(0) = r(1) = 0
659		2..self.composition.degree() + 1
660	}
661
662	fn process_subcube_at_eval_point(
663		&self,
664		subcube_vars: usize,
665		subcube_index: usize,
666		is_infinity_point: bool,
667		batch_query: &[&[P]],
668	) -> P {
669		// If the composition is a linear polynomial, then the composite multivariate polynomial
670		// is multilinear. If the prover is honest, then this multilinear is identically zero,
671		// hence the sum over the subcube is zero.
672		if self.composition.degree() == 1 {
673			return P::zero();
674		}
675		let row_len = batch_query.first().map_or(0, |row| row.len());
676
677		stackalloc_with_default(row_len, |evals| {
678			if is_infinity_point {
679				self.composition_at_infinity
680					.batch_evaluate(batch_query, evals)
681					.expect("correct by query construction invariant");
682			} else {
683				self.composition
684					.batch_evaluate(batch_query, evals)
685					.expect("correct by query construction invariant");
686			}
687
688			let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
689			let partial_eq_ind_evals_slice = &self.partial_eq_ind_evals[subcube_start..];
690			let field_sum = PackedField::iter_slice(partial_eq_ind_evals_slice)
691				.zip(PackedField::iter_slice(evals))
692				.map(|(eq_ind_scalar, base_scalar)| eq_ind_scalar * base_scalar)
693				.sum();
694
695			P::set_single(field_sum)
696		})
697	}
698
699	fn composition(&self) -> &Composition {
700		self.composition
701	}
702
703	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
704		Some(self.partial_eq_ind_evals)
705	}
706}
707
708impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
709	for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition>
710where
711	F: Field + ExtensionField<FDomain>,
712	P: PackedField<Scalar = F>,
713	FDomain: Field,
714{
715	fn round_evals_to_coeffs(
716		&self,
717		last_round_sum: F,
718		mut round_evals: Vec<F>,
719	) -> Result<Vec<F>, PolynomialError> {
720		assert_eq!(last_round_sum, F::ZERO);
721
722		// We are given $r(2), \ldots, r(d)$.
723		// From context, we infer that $r(0) = r(1) = 0$.
724		round_evals.insert(0, P::Scalar::ZERO);
725		round_evals.insert(0, P::Scalar::ZERO);
726
727		if round_evals.len() > 3 {
728			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
729			// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
730			// Putting "special" evaluation points at the beginning of domain allows benefitting from
731			// faster/skipped interpolation even in case of mixed degree compositions .
732			let infinity_round_eval = round_evals.remove(2);
733			round_evals.push(infinity_round_eval);
734		}
735
736		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
737		Ok(coeffs)
738	}
739}
740
741struct ZerocheckLaterRoundEvaluator<'a, P, FDomain, Composition>
742where
743	P: PackedField,
744	FDomain: Field,
745{
746	composition: &'a Composition,
747	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
748	interpolation_domain: &'a InterpolationDomain<FDomain>,
749	partial_eq_ind_evals: &'a [P],
750	round_zerocheck_challenge: P::Scalar,
751}
752
753impl<P, FDomain, Composition> SumcheckEvaluator<P, Composition>
754	for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition>
755where
756	P: PackedField<Scalar: TowerField + ExtensionField<FDomain>>,
757	FDomain: Field,
758	Composition: CompositionPoly<P>,
759{
760	fn eval_point_indices(&self) -> Range<usize> {
761		// We can uniquely derive the degree d univariate round polynomial r from evaluations at
762		// X = 1, ..., d because we have an identity that relates r(0), r(1), and the current
763		// round's claimed sum
764		1..self.composition.degree() + 1
765	}
766
767	fn process_subcube_at_eval_point(
768		&self,
769		subcube_vars: usize,
770		subcube_index: usize,
771		is_infinity_point: bool,
772		batch_query: &[&[P]],
773	) -> P {
774		// If the composition is a linear polynomial, then the composite multivariate polynomial
775		// is multilinear. If the prover is honest, then this multilinear is identically zero,
776		// hence the sum over the subcube is zero.
777		if self.composition.degree() == 1 {
778			return P::zero();
779		}
780		let row_len = batch_query.first().map_or(0, |row| row.len());
781
782		stackalloc_with_default(row_len, |evals| {
783			if is_infinity_point {
784				self.composition_at_infinity
785					.batch_evaluate(batch_query, evals)
786					.expect("correct by query construction invariant");
787			} else {
788				self.composition
789					.batch_evaluate(batch_query, evals)
790					.expect("correct by query construction invariant");
791			}
792
793			let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
794			for (i, eval) in evals.iter_mut().enumerate() {
795				*eval *= self.partial_eq_ind_evals[subcube_start + i];
796			}
797
798			evals.iter().copied().sum::<P>()
799		})
800	}
801
802	fn composition(&self) -> &Composition {
803		self.composition
804	}
805
806	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
807		Some(self.partial_eq_ind_evals)
808	}
809}
810
811impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
812	for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition>
813where
814	F: Field,
815	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
816	FDomain: Field,
817{
818	fn round_evals_to_coeffs(
819		&self,
820		last_round_sum: F,
821		mut round_evals: Vec<F>,
822	) -> Result<Vec<F>, PolynomialError> {
823		// This is a subsequent round of a sumcheck that came from zerocheck, given $r(1), \ldots, r(d)$
824		// Letting $s$ be the current round's claimed sum, and $\alpha_i$ the ith zerocheck challenge
825		// we have the identity $r(0) = \frac{1}{1 - \alpha_i} * (s - \alpha_i * r(1))$
826		// which allows us to compute the value of $r(0)$
827
828		let alpha = self.round_zerocheck_challenge;
829		let one_evaluation = round_evals[0]; // r(1)
830		let zero_evaluation_numerator = last_round_sum - one_evaluation * alpha;
831		let zero_evaluation_denominator_inv = (F::ONE - alpha).invert_or_zero();
832		let zero_evaluation = zero_evaluation_numerator * zero_evaluation_denominator_inv;
833
834		round_evals.insert(0, zero_evaluation);
835
836		if round_evals.len() > 3 {
837			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
838			// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
839			// Putting "special" evaluation points at the beginning of domain allows benefitting from
840			// faster/skipped interpolation even in case of mixed degree compositions .
841			let infinity_round_eval = round_evals.remove(2);
842			round_evals.push(infinity_round_eval);
843		}
844
845		let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
846		Ok(coeffs)
847	}
848}