binius_core/protocols/sumcheck/prove/
eq_ind.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{cmp::Reverse, marker::PhantomData, ops::Range};
4
5use binius_fast_compute::arith_circuit::ArithCircuitPoly;
6use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField, util::eq};
7use binius_hal::{
8	ComputationBackend, Error as HalError, SumcheckEvaluator, SumcheckMultilinear,
9	make_portable_backend,
10};
11use binius_math::{
12	CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
13	MLEDirectAdapter, MultilinearPoly, RowsBatchRef,
14};
15use binius_maybe_rayon::prelude::*;
16use binius_utils::bail;
17use getset::Getters;
18use itertools::izip;
19use stackalloc::stackalloc_with_default;
20use tracing::instrument;
21
22use crate::{
23	polynomial::{Error as PolynomialError, MultivariatePoly},
24	protocols::sumcheck::{
25		CompositeSumClaim, Error,
26		common::{
27			RoundCoeffs, equal_n_vars_check, get_nontrivial_evaluation_points,
28			interpolation_domains_for_composition_degrees,
29		},
30		prove::{ProverState, SumcheckInterpolator, SumcheckProver, common::fold_partial_eq_ind},
31	},
32	transparent::{eq_ind::EqIndPartialEval, step_up::StepUp},
33};
34
35pub fn validate_witness<F, P, M, Composition>(
36	n_vars: usize,
37	multilinears: &[SumcheckMultilinear<P, M>],
38	eq_ind_challenges: &[F],
39	eq_ind_sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
40) -> Result<(), Error>
41where
42	F: Field,
43	P: PackedField<Scalar = F>,
44	M: MultilinearPoly<P> + Send + Sync,
45	Composition: CompositionPoly<P>,
46{
47	if eq_ind_challenges.len() != n_vars {
48		bail!(Error::IncorrectEqIndChallengesLength);
49	}
50
51	for multilinear in multilinears {
52		match *multilinear {
53			SumcheckMultilinear::Transparent {
54				ref multilinear,
55				const_suffix: (suffix_eval, suffix_len),
56				..
57			} => {
58				if multilinear.n_vars() != n_vars {
59					bail!(Error::NumberOfVariablesMismatch);
60				}
61
62				let first_const = (1usize << n_vars)
63					.checked_sub(suffix_len)
64					.ok_or(Error::IncorrectConstSuffixes)?;
65
66				for i in first_const..1 << n_vars {
67					if multilinear.evaluate_on_hypercube(i)? != suffix_eval {
68						bail!(Error::IncorrectConstSuffixes);
69					}
70				}
71			}
72
73			SumcheckMultilinear::Folded {
74				large_field_folded_evals: ref evals,
75				..
76			} => {
77				if evals.len() > 1 << n_vars.saturating_sub(P::LOG_WIDTH) {
78					bail!(Error::IncorrectConstSuffixes)
79				}
80			}
81		}
82	}
83
84	let backend = make_portable_backend();
85	let eq_ind =
86		EqIndPartialEval::new(eq_ind_challenges).multilinear_extension::<P, _>(&backend)?;
87
88	for (i, claim) in eq_ind_sum_claims.into_iter().enumerate() {
89		let CompositeSumClaim {
90			composition,
91			sum: expected_sum,
92		} = claim;
93		let sum = (0..(1 << n_vars))
94			.into_par_iter()
95			.try_fold(
96				|| (vec![P::zero(); multilinears.len()], F::ZERO),
97				|(mut multilinear_evals, mut running_sum), j| -> Result<_, Error> {
98					for (eval, multilinear) in izip!(&mut multilinear_evals, multilinears) {
99						*eval = P::broadcast(match multilinear {
100							SumcheckMultilinear::Transparent { multilinear, .. } => {
101								multilinear.evaluate_on_hypercube(j)?
102							}
103							SumcheckMultilinear::Folded {
104								large_field_folded_evals,
105								suffix_eval,
106							} => binius_field::packed::get_packed_slice_checked(
107								large_field_folded_evals,
108								j,
109							)
110							.unwrap_or(*suffix_eval),
111						});
112					}
113
114					running_sum += eq_ind.evaluate_on_hypercube(j)?
115						* composition.evaluate(&multilinear_evals)?.get(0);
116					Ok((multilinear_evals, running_sum))
117				},
118			)
119			.map(|fold_state| -> Result<_, Error> { Ok(fold_state?.1) })
120			.try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
121
122		if sum != expected_sum {
123			bail!(Error::SumcheckNaiveValidationFailure {
124				composition_index: i,
125			});
126		}
127	}
128	Ok(())
129}
130
131/// An "eq-ind" sumcheck prover.
132///
133/// The main difference of this prover from the `RegularSumcheckProver` is that
134/// it computes round evaluations of a much simpler "prime" polynomial
135/// multiplied by an already substituted portion of the equality indicator. This
136/// "prime" polynomial has the same degree as the underlying composition,
137/// reducing the number of would-be evaluation points by one, and avoids
138/// interpolating the tensor expansion of the equality indicator.  Round
139/// evaluations for the "full" assumed composition are computed in
140/// monomial form, out of hot loop.  See [Gruen24] Section 3.2 for details.
141///
142/// The rationale behind builder interface is the need to specify the pre-expanded
143/// equality indicator and potentially known evaluations at one in first round.
144///
145/// [Gruen24]: <https://eprint.iacr.org/2024/108>
146pub struct EqIndSumcheckProverBuilder<'a, P, M, Backend>
147where
148	P: PackedField,
149	M: MultilinearPoly<P>,
150	Backend: ComputationBackend,
151{
152	n_vars: usize,
153	eq_ind_partial_evals: Option<Backend::Vec<P>>,
154	first_round_eval_1s: Option<Vec<P::Scalar>>,
155	multilinears: Vec<SumcheckMultilinear<P, M>>,
156	backend: &'a Backend,
157}
158
159impl<'a, F, P, Backend> EqIndSumcheckProverBuilder<'a, P, MLEDirectAdapter<P, Vec<P>>, Backend>
160where
161	F: TowerField,
162	P: PackedField<Scalar = F>,
163	Backend: ComputationBackend,
164{
165	pub fn without_switchover(
166		n_vars: usize,
167		multilinears: Vec<Vec<P>>,
168		backend: &'a Backend,
169	) -> Self {
170		let multilinears = multilinears
171			.into_iter()
172			.map(SumcheckMultilinear::folded)
173			.collect();
174
175		Self {
176			n_vars,
177			eq_ind_partial_evals: None,
178			first_round_eval_1s: None,
179			multilinears,
180			backend,
181		}
182	}
183}
184
185impl<'a, F, P, M, Backend> EqIndSumcheckProverBuilder<'a, P, M, Backend>
186where
187	F: TowerField,
188	P: PackedField<Scalar = F>,
189	M: MultilinearPoly<P> + Send + Sync,
190	Backend: ComputationBackend,
191{
192	pub fn with_switchover(
193		multilinears: Vec<M>,
194		switchover_fn: impl Fn(usize) -> usize,
195		backend: &'a Backend,
196	) -> Result<Self, Error> {
197		let n_vars = equal_n_vars_check(&multilinears)?;
198		let multilinears = multilinears
199			.into_iter()
200			.map(|multilinear| SumcheckMultilinear::transparent(multilinear, &switchover_fn))
201			.collect();
202
203		Ok(Self {
204			n_vars,
205			eq_ind_partial_evals: None,
206			first_round_eval_1s: None,
207			multilinears,
208			backend,
209		})
210	}
211
212	/// Specify an existing tensor expansion for `eq_ind_challenges` in [`Self::build`]. Avoids
213	/// duplicate work.
214	pub fn with_eq_ind_partial_evals(mut self, eq_ind_partial_evals: Backend::Vec<P>) -> Self {
215		self.eq_ind_partial_evals = Some(eq_ind_partial_evals);
216		self
217	}
218
219	/// Specify the value of round polynomial at 1 in the first round if it is available beforehand.
220	///
221	/// Prime example of this is GPA (grand product argument), where the value of the previous GKR
222	/// layer may be used as an advice to compute the round polynomial at 1 directly with less
223	/// effort compared to direct composite evaluation.
224	pub fn with_first_round_eval_1s(mut self, first_round_eval_1s: &[F]) -> Self {
225		self.first_round_eval_1s = Some(first_round_eval_1s.to_vec());
226		self
227	}
228
229	/// Specify the const suffixes for multilinears.
230	///
231	/// The provided array specifies the const suffixes at the end of each multilinear.
232	/// Prover is able to reduce multilinear storage and compute using this information.
233	pub fn with_const_suffixes(mut self, const_suffixes: &[(F, usize)]) -> Result<Self, Error> {
234		if const_suffixes.len() != self.multilinears.len() {
235			bail!(Error::IncorrectConstSuffixes);
236		}
237
238		for (multilinear, &const_suffix) in izip!(&mut self.multilinears, const_suffixes) {
239			let (_, suffix_len) = const_suffix;
240
241			if suffix_len > 1 << self.n_vars {
242				bail!(Error::IncorrectConstSuffixes);
243			}
244
245			multilinear.update_const_suffix(self.n_vars, const_suffix);
246		}
247
248		Ok(self)
249	}
250
251	#[instrument(skip_all, level = "debug", name = "EqIndSumcheckProverBuilder::build")]
252	pub fn build<FDomain, Composition>(
253		self,
254		evaluation_order: EvaluationOrder,
255		eq_ind_challenges: &[F],
256		composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
257		domain_factory: impl EvaluationDomainFactory<FDomain>,
258	) -> Result<EqIndSumcheckProver<'a, FDomain, P, Composition, M, Backend>, Error>
259	where
260		F: ExtensionField<FDomain>,
261		P: PackedExtension<FDomain>,
262		FDomain: Field,
263		Composition: CompositionPoly<P>,
264	{
265		let Self {
266			n_vars,
267			backend,
268			multilinears,
269			..
270		} = self;
271		let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
272
273		#[cfg(feature = "debug_validate_sumcheck")]
274		{
275			let composite_claims = composite_claims
276				.iter()
277				.map(|composite_claim| CompositeSumClaim {
278					composition: &composite_claim.composition,
279					sum: composite_claim.sum,
280				})
281				.collect::<Vec<_>>();
282			validate_witness(n_vars, &multilinears, eq_ind_challenges, composite_claims.clone())?;
283		}
284
285		if eq_ind_challenges.len() != n_vars {
286			bail!(Error::IncorrectEqIndChallengesLength);
287		}
288
289		// Only one value of the expanded equality indicator is used per each
290		// 1-variable subcube, thus it should be twice smaller.
291		let eq_ind_partial_evals = if let Some(eq_ind_partial_evals) = self.eq_ind_partial_evals {
292			if eq_ind_partial_evals.len() != 1 << n_vars.saturating_sub(P::LOG_WIDTH + 1) {
293				bail!(Error::IncorrectEqIndPartialEvalsSize);
294			}
295
296			eq_ind_partial_evals
297		} else {
298			eq_ind_expand(evaluation_order, eq_ind_challenges, backend)?
299		};
300
301		if let Some(ref first_round_eval_1s) = self.first_round_eval_1s {
302			if first_round_eval_1s.len() != composite_claims.len() {
303				bail!(Error::IncorrectFirstRoundEvalOnesLength);
304			}
305		}
306
307		for claim in &composite_claims {
308			if claim.composition.n_vars() != multilinears.len() {
309				bail!(Error::InvalidComposition {
310					expected: multilinears.len(),
311					actual: claim.composition.n_vars(),
312				});
313			}
314		}
315
316		let (compositions, claimed_sums) = determine_const_eval_suffixes(
317			composite_claims,
318			multilinears
319				.iter()
320				.map(|multilinear| multilinear.const_suffix(n_vars)),
321		);
322
323		let domains = interpolation_domains_for_composition_degrees(
324			domain_factory,
325			compositions
326				.iter()
327				.map(|(composition, _)| composition.degree()),
328		)?;
329
330		let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
331
332		let state = ProverState::new(
333			evaluation_order,
334			n_vars,
335			multilinears,
336			claimed_sums,
337			nontrivial_evaluation_points,
338			backend,
339		)?;
340
341		let eq_ind_prefix_eval = F::ONE;
342		let eq_ind_challenges = eq_ind_challenges.to_vec();
343		let first_round_eval_1s = self.first_round_eval_1s;
344
345		Ok(EqIndSumcheckProver {
346			n_vars,
347			state,
348			eq_ind_prefix_eval,
349			eq_ind_partial_evals,
350			eq_ind_challenges,
351			compositions,
352			domains,
353			first_round_eval_1s,
354			backend: PhantomData,
355		})
356	}
357}
358
359#[derive(Default, PartialEq, Eq, Debug)]
360pub struct ConstEvalSuffix<F: Field> {
361	pub suffix: usize,
362	pub value: F,
363	pub value_at_inf: F,
364}
365
366impl<F: Field> ConstEvalSuffix<F> {
367	fn update(&mut self, evaluation_order: EvaluationOrder, n_vars: usize) {
368		let eval_prefix = (1 << n_vars) - self.suffix;
369		let updated_eval_prefix = match evaluation_order {
370			EvaluationOrder::LowToHigh => eval_prefix.div_ceil(2),
371			EvaluationOrder::HighToLow => eval_prefix.min(1 << (n_vars - 1)),
372		};
373		self.suffix = (1 << (n_vars - 1)) - updated_eval_prefix;
374	}
375}
376
377#[derive(Debug, Getters)]
378pub struct EqIndSumcheckProver<'a, FDomain, P, Composition, M, Backend>
379where
380	FDomain: Field,
381	P: PackedField,
382	M: MultilinearPoly<P> + Send + Sync,
383	Backend: ComputationBackend,
384{
385	n_vars: usize,
386	state: ProverState<'a, FDomain, P, M, Backend>,
387	eq_ind_prefix_eval: P::Scalar,
388	eq_ind_partial_evals: Backend::Vec<P>,
389	eq_ind_challenges: Vec<P::Scalar>,
390	#[getset(get = "pub")]
391	compositions: Vec<(Composition, ConstEvalSuffix<P::Scalar>)>,
392	domains: Vec<InterpolationDomain<FDomain>>,
393	first_round_eval_1s: Option<Vec<P::Scalar>>,
394	backend: PhantomData<Backend>,
395}
396
397impl<F, FDomain, P, Composition, M, Backend>
398	EqIndSumcheckProver<'_, FDomain, P, Composition, M, Backend>
399where
400	F: TowerField + ExtensionField<FDomain>,
401	FDomain: Field,
402	P: PackedExtension<FDomain, Scalar = F>,
403	Composition: CompositionPoly<P>,
404	M: MultilinearPoly<P> + Send + Sync,
405	Backend: ComputationBackend,
406{
407	fn round(&self) -> usize {
408		self.n_vars - self.n_rounds_remaining()
409	}
410
411	fn n_rounds_remaining(&self) -> usize {
412		self.state.n_vars()
413	}
414
415	fn eq_ind_round_challenge(&self) -> F {
416		match self.state.evaluation_order() {
417			EvaluationOrder::LowToHigh => self.eq_ind_challenges[self.round()],
418			EvaluationOrder::HighToLow => {
419				self.eq_ind_challenges[self.eq_ind_challenges.len() - 1 - self.round()]
420			}
421		}
422	}
423
424	fn update_eq_ind_prefix_eval(&mut self, challenge: F) {
425		// Update the running eq ind evaluation.
426		self.eq_ind_prefix_eval *= eq(self.eq_ind_round_challenge(), challenge);
427	}
428}
429
430pub fn eq_ind_expand<P, Backend>(
431	evaluation_order: EvaluationOrder,
432	eq_ind_challenges: &[P::Scalar],
433	backend: &Backend,
434) -> Result<Backend::Vec<P>, HalError>
435where
436	P: PackedField,
437	Backend: ComputationBackend,
438{
439	let n_vars = eq_ind_challenges.len();
440	backend.tensor_product_full_query(match evaluation_order {
441		EvaluationOrder::LowToHigh => &eq_ind_challenges[n_vars.min(1)..],
442		EvaluationOrder::HighToLow => &eq_ind_challenges[..n_vars.saturating_sub(1)],
443	})
444}
445
446type CompositionsAndSums<F, Composition> = (Vec<(Composition, ConstEvalSuffix<F>)>, Vec<F>);
447
448// Automatically determine trace suffix which evaluates to constant polynomials during sumcheck.
449//
450// Algorithm outline:
451//  * sort multilinears by non-increasing const suffix length
452//  * processing multilinears in this order, symbolically substitute suffix eval for the current
453//    variable and optimize
454//  * if the remaining expressions at finite points and Karatsuba infinity are constant, assume this
455//    suffix
456fn determine_const_eval_suffixes<F, P, Composition>(
457	composite_claims: Vec<CompositeSumClaim<F, Composition>>,
458	const_suffixes: impl IntoIterator<Item = (F, usize)>,
459) -> CompositionsAndSums<F, Composition>
460where
461	F: Field,
462	P: PackedField<Scalar = F>,
463	Composition: CompositionPoly<P>,
464{
465	let mut const_suffixes = const_suffixes.into_iter().enumerate().collect::<Vec<_>>();
466
467	const_suffixes.sort_by_key(|&(_var, (_suffix_eval, suffix_len))| Reverse(suffix_len));
468
469	composite_claims
470		.into_iter()
471		.map(|claim| {
472			let CompositeSumClaim { composition, sum } = claim;
473			assert_eq!(const_suffixes.len(), composition.n_vars());
474
475			let mut const_eval_suffix = Default::default();
476
477			let mut expr = composition.expression();
478			let mut expr_at_inf = composition.expression().leading_term();
479
480			let expr_vars = expr.vars_usage();
481			let expr_at_inf_vars = expr_at_inf.vars_usage();
482
483			for &(var_index, (suffix_eval, suffix)) in &const_suffixes {
484				if !expr_at_inf_vars.get(var_index).unwrap_or(&false)
485					|| !expr_vars.get(var_index).unwrap_or(&false)
486				{
487					continue;
488				}
489
490				expr = expr
491					.const_subst(var_index, suffix_eval)
492					.optimize_constants();
493				// NB: infinity point has a different interpolation result; in characteristic 2,
494				// it's always zero.
495				expr_at_inf = expr_at_inf
496					.const_subst(var_index, suffix_eval + suffix_eval)
497					.optimize_constants();
498
499				if let Some((value, value_at_inf)) =
500					expr.get_constant().zip(expr_at_inf.get_constant())
501				{
502					const_eval_suffix = ConstEvalSuffix {
503						suffix,
504						value,
505						value_at_inf,
506					};
507					break;
508				}
509			}
510
511			((composition, const_eval_suffix), sum)
512		})
513		.unzip::<_, _, Vec<_>, Vec<_>>()
514}
515
/// [`SumcheckProver`] implementation for the eq-ind prover.
///
/// Each round evaluates the "prime" round polynomial `v'` for every composition. It then
/// converts `v'` into the full round polynomial `v(X) = eq(X, alpha) * v'(X)`, scaled by the
/// running eq-ind prefix evaluation.
impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
	for EqIndSumcheckProver<'_, FDomain, P, Composition, M, Backend>
where
	F: TowerField + ExtensionField<FDomain>,
	FDomain: Field,
	P: PackedExtension<FDomain, Scalar = F>,
	Composition: CompositionPoly<P>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	fn n_vars(&self) -> usize {
		self.n_vars
	}

	fn evaluation_order(&self) -> EvaluationOrder {
		self.state.evaluation_order()
	}

	/// Computes the coefficients of this round's polynomial, batched with `batch_coeff`.
	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
		let round = self.round();
		let n_rounds_remaining = self.n_rounds_remaining();

		let alpha = self.eq_ind_round_challenge();
		let eq_ind_partial_evals = &self.eq_ind_partial_evals;

		// Known first-round values at one are consumed here; later rounds evaluate at one
		// directly.
		let first_round_eval_1s = self.first_round_eval_1s.take();
		let have_first_round_eval_1s = first_round_eval_1s.is_some();

		// Challenges of the variables still unbound after this round (the current round's
		// challenge is `alpha`).
		let eq_ind_challenges = match self.state.evaluation_order() {
			EvaluationOrder::LowToHigh => &self.eq_ind_challenges[self.n_vars.min(round + 1)..],
			EvaluationOrder::HighToLow => {
				&self.eq_ind_challenges[..self.n_vars.saturating_sub(round + 1)]
			}
		};

		let evaluators = self
			.compositions
			.iter_mut()
			.map(|(composition, const_eval_suffix)| {
				let composition_at_infinity =
					ArithCircuitPoly::new(composition.expression().leading_term());

				// Rescales the stored suffix to this round's hypercube. This mutates state, so
				// it must run exactly once per composition per round.
				const_eval_suffix.update(self.state.evaluation_order(), n_rounds_remaining);

				Evaluator {
					n_rounds_remaining,
					composition,
					composition_at_infinity,
					have_first_round_eval_1s,
					eq_ind_challenges,
					eq_ind_partial_evals,
					const_eval_suffix,
				}
			})
			.collect::<Vec<_>>();

		// One interpolator per composition; `first_round_eval_1s` is indexed in the same order.
		let interpolators = self
			.domains
			.iter()
			.enumerate()
			.map(|(index, interpolation_domain)| Interpolator {
				interpolation_domain,
				alpha,
				first_round_eval_1: first_round_eval_1s
					.as_ref()
					.map(|first_round_eval_1s| first_round_eval_1s[index]),
			})
			.collect::<Vec<_>>();

		let round_evals = self.state.calculate_round_evals(&evaluators)?;

		let prime_coeffs = self.state.calculate_round_coeffs_from_evals(
			&interpolators,
			batch_coeff,
			round_evals,
		)?;

		// Convert v' polynomial into v polynomial

		// eq(X, α) = (1 − α) + (2 α − 1) X
		// NB: In characteristic 2, 2α − 1 = 1, so this simplifies to (1 + α) + X.
		let (prime_coeffs_scaled_by_constant_term, mut prime_coeffs_scaled_by_linear_term) =
			if F::CHARACTERISTIC == 2 {
				(prime_coeffs.clone() * (F::ONE + alpha), prime_coeffs)
			} else {
				(prime_coeffs.clone() * (F::ONE - alpha), prime_coeffs * (alpha.double() - F::ONE))
			};

		prime_coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); // Multiply prime polynomial by X

		// Scale by eq(alpha_j, r_j) accumulated over the rounds already folded.
		let coeffs = (prime_coeffs_scaled_by_constant_term + &prime_coeffs_scaled_by_linear_term)
			* self.eq_ind_prefix_eval;

		Ok(coeffs)
	}

	/// Binds the current variable to `challenge`.
	#[instrument(skip_all, name = "EqIndSumcheckProver::fold", level = "debug")]
	fn fold(&mut self, challenge: F) -> Result<(), Error> {
		// Must precede `state.fold`: the round challenge is derived from the state's number of
		// remaining variables.
		self.update_eq_ind_prefix_eval(challenge);

		let evaluation_order = self.state.evaluation_order();
		let n_rounds_remaining = self.n_rounds_remaining();

		let Self {
			state,
			eq_ind_partial_evals,
			..
		} = self;

		// Fold the prover state and the eq-ind expansion in parallel; only the state fold
		// returns an error.
		binius_maybe_rayon::join(
			|| state.fold(challenge),
			|| {
				fold_partial_eq_ind::<P, Backend>(
					evaluation_order,
					n_rounds_remaining - 1,
					eq_ind_partial_evals,
				);
			},
		)
		.0?;
		Ok(())
	}

	/// Returns the state's final multilinear evaluations, followed by the eq-ind prefix
	/// evaluation.
	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
		let mut evals = self.state.finish()?;
		evals.push(self.eq_ind_prefix_eval);
		Ok(evals)
	}
}
645
646struct Evaluator<'a, P, Composition>
647where
648	P: PackedField,
649{
650	n_rounds_remaining: usize,
651	composition: &'a Composition,
652	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
653	have_first_round_eval_1s: bool,
654	eq_ind_challenges: &'a [P::Scalar],
655	eq_ind_partial_evals: &'a [P],
656	const_eval_suffix: &'a ConstEvalSuffix<P::Scalar>,
657}
658
659impl<P, Composition> SumcheckEvaluator<P, Composition> for Evaluator<'_, P, Composition>
660where
661	P: PackedField<Scalar: TowerField>,
662	Composition: CompositionPoly<P>,
663{
664	fn eval_point_indices(&self) -> Range<usize> {
665		// Do not evaluate r(1) in first round when its value is known
666		let start_index = if self.have_first_round_eval_1s { 2 } else { 1 };
667		start_index..self.composition.degree() + 1
668	}
669
670	fn process_subcube_at_eval_point(
671		&self,
672		subcube_vars: usize,
673		subcube_index: usize,
674		is_infinity_point: bool,
675		batch_query: &RowsBatchRef<P>,
676	) -> P {
677		let row_len = batch_query.row_len();
678
679		stackalloc_with_default(row_len, |evals| {
680			if is_infinity_point {
681				self.composition_at_infinity
682					.batch_evaluate(batch_query, evals)
683					.expect("correct by query construction invariant");
684			} else {
685				self.composition
686					.batch_evaluate(batch_query, evals)
687					.expect("correct by query construction invariant");
688			};
689
690			let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
691			for (i, eval) in evals.iter_mut().enumerate() {
692				// REVIEW: investigate whether its possible to access a subcube smaller than
693				//         the packing width and unaligned on the packed field binary; in that
694				//         case spread multiplication may be needed.
695				*eval *= self.eq_ind_partial_evals[subcube_start + i];
696			}
697			evals.iter().copied().sum::<P>()
698		})
699	}
700
701	fn process_constant_eval_suffix(
702		&self,
703		const_eval_suffix: usize,
704		is_infinity_point: bool,
705	) -> P::Scalar {
706		let eval_prefix = (1 << self.n_rounds_remaining) - const_eval_suffix;
707		let eq_ind_suffix_sum = StepUp::new(self.eq_ind_challenges.len(), eval_prefix)
708			.expect("eval_prefix does not exceed the equality indicator size")
709			.evaluate(self.eq_ind_challenges)
710			.expect("StepUp is initialized with eq_ind_challenges.len()");
711
712		eq_ind_suffix_sum
713			* if is_infinity_point {
714				self.const_eval_suffix.value_at_inf
715			} else {
716				self.const_eval_suffix.value
717			}
718	}
719
720	fn composition(&self) -> &Composition {
721		self.composition
722	}
723
724	fn eq_ind_partial_eval(&self) -> Option<&[P]> {
725		Some(self.eq_ind_partial_evals)
726	}
727
728	fn const_eval_suffix(&self) -> usize {
729		self.const_eval_suffix.suffix
730	}
731}
732
733struct Interpolator<'a, F, FDomain>
734where
735	F: Field,
736	FDomain: Field,
737{
738	interpolation_domain: &'a InterpolationDomain<FDomain>,
739	alpha: F,
740	first_round_eval_1: Option<F>,
741}
742
743impl<F, FDomain> SumcheckInterpolator<F> for Interpolator<'_, F, FDomain>
744where
745	F: ExtensionField<FDomain>,
746	FDomain: Field,
747{
748	#[instrument(
749		skip_all,
750		name = "eq_ind::Interpolator::round_evals_to_coeffs",
751		level = "debug"
752	)]
753	fn round_evals_to_coeffs(
754		&self,
755		last_round_sum: F,
756		mut round_evals: Vec<F>,
757	) -> Result<Vec<F>, PolynomialError> {
758		if let Some(first_round_eval_1) = self.first_round_eval_1 {
759			round_evals.insert(0, first_round_eval_1);
760		}
761
762		let one_evaluation = round_evals[0];
763		let zero_evaluation_numerator = last_round_sum - one_evaluation * self.alpha;
764		let zero_evaluation_denominator_inv = (F::ONE - self.alpha).invert_or_zero();
765		let zero_evaluation = zero_evaluation_numerator * zero_evaluation_denominator_inv;
766		round_evals.insert(0, zero_evaluation);
767
768		if round_evals.len() > 3 {
769			// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then
770			// subspace points. InterpolationDomain expects "infinity" at the last position, thus
771			// reordering is needed. Putting "special" evaluation points at the beginning of
772			// domain allows benefitting from faster/skipped interpolation even in case of mixed
773			// degree compositions .
774			let infinity_round_eval = round_evals.remove(2);
775			round_evals.push(infinity_round_eval);
776		}
777
778		Ok(self.interpolation_domain.interpolate(&round_evals)?)
779	}
780}