1use std::{marker::PhantomData, ops::Range, sync::Arc};
4
5use binius_field::{
6 util::{eq, powers},
7 ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, PackedSubfield,
8 TowerField,
9};
10use binius_hal::{ComputationBackend, SumcheckEvaluator};
11use binius_math::{
12 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
13 MLEDirectAdapter, MultilinearPoly, MultilinearQuery,
14};
15use binius_maybe_rayon::prelude::*;
16use binius_utils::bail;
17use bytemuck::zeroed_vec;
18use getset::Getters;
19use itertools::izip;
20use stackalloc::stackalloc_with_default;
21use tracing::instrument;
22
23use crate::{
24 polynomial::{ArithCircuitPoly, Error as PolynomialError, MultilinearComposite},
25 protocols::sumcheck::{
26 common::{determine_switchovers, equal_n_vars_check, get_nontrivial_evaluation_points},
27 prove::{
28 common::fold_partial_eq_ind,
29 univariate::{
30 zerocheck_univariate_evals, ZerocheckUnivariateEvalsOutput,
31 ZerocheckUnivariateFoldResult,
32 },
33 ProverState, SumcheckInterpolator, SumcheckProver, UnivariateZerocheckProver,
34 },
35 univariate::LagrangeRoundEvals,
36 univariate_zerocheck::domain_size,
37 Error, RoundCoeffs,
38 },
39 witness::MultilinearWitness,
40};
41
42pub fn validate_witness<'a, F, P, M, Composition>(
43 multilinears: &[M],
44 zero_claims: impl IntoIterator<Item = &'a (String, Composition)>,
45) -> Result<(), Error>
46where
47 F: Field,
48 P: PackedField<Scalar = F>,
49 M: MultilinearPoly<P> + Send + Sync,
50 Composition: CompositionPoly<P> + 'a,
51{
52 let n_vars = multilinears
53 .first()
54 .map(|multilinear| multilinear.n_vars())
55 .unwrap_or_default();
56 for multilinear in multilinears {
57 if multilinear.n_vars() != n_vars {
58 bail!(Error::NumberOfVariablesMismatch);
59 }
60 }
61
62 let multilinears = multilinears.iter().collect::<Vec<_>>();
63
64 for (name, composition) in zero_claims {
65 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
66 (0..(1 << n_vars)).into_par_iter().try_for_each(|j| {
67 if witness.evaluate_on_hypercube(j)? != F::ZERO {
68 return Err(Error::ZerocheckNaiveValidationFailure {
69 composition_name: name.to_string(),
70 vertex_index: j,
71 });
72 }
73 Ok(())
74 })?;
75 }
76 Ok(())
77}
78
/// Zerocheck prover state before (and during) an optional univariate skip round.
///
/// Every zero claim carries two versions of the same composition:
/// - `CompositionBase`, used by `execute_univariate_round` over the `FBase` packed subfield;
/// - `Composition`, used by the regular [`ZerocheckProver`] over `P` in later rounds.
///
/// The prover can be turned directly into a regular prover via `into_regular_zerocheck`, or
/// driven through `execute_univariate_round` followed by `fold_univariate_round`.
#[derive(Debug, Getters)]
pub struct UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
where
	FDomain: Field,
	FBase: Field,
	P: PackedField,
	Backend: ComputationBackend,
{
	// Common number of variables of all `multilinears`.
	n_vars: usize,
	/// Witness multilinears, one per composition argument.
	#[getset(get = "pub")]
	multilinears: Vec<M>,
	// Per-multilinear switchover rounds, as computed by `determine_switchovers`.
	switchover_rounds: Vec<usize>,
	// `(name, base composition, extension composition)` for every zero claim.
	compositions: Vec<(String, CompositionBase, Composition)>,
	// Zerocheck challenges; the required length depends on which proving path is taken.
	zerocheck_challenges: Vec<P::Scalar>,
	// One interpolation domain per composition, sized from that composition's degree.
	domains: Vec<InterpolationDomain<FDomain>>,
	backend: &'a Backend,
	// `Some` after `execute_univariate_round`; consumed by `fold_univariate_round`.
	univariate_evals_output: Option<ZerocheckUnivariateEvalsOutput<P::Scalar, P, Backend>>,
	_p_base_marker: PhantomData<FBase>,
	_m_marker: PhantomData<&'m ()>,
}
110
111impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
112 UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
113where
114 F: Field,
115 FDomain: Field,
116 FBase: ExtensionField<FDomain>,
117 P: PackedFieldIndexable<Scalar = F>
118 + PackedExtension<F, PackedSubfield = P>
119 + PackedExtension<FBase>
120 + PackedExtension<FDomain>,
121 CompositionBase: CompositionPoly<<P as PackedExtension<FBase>>::PackedSubfield>,
122 Composition: CompositionPoly<P>,
123 M: MultilinearPoly<P> + Send + Sync + 'm,
124 Backend: ComputationBackend,
125{
126 pub fn new(
127 multilinears: Vec<M>,
128 zero_claims: impl IntoIterator<Item = (String, CompositionBase, Composition)>,
129 zerocheck_challenges: &[F],
130 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
131 switchover_fn: impl Fn(usize) -> usize,
132 backend: &'a Backend,
133 ) -> Result<Self, Error> {
134 let n_vars = equal_n_vars_check(&multilinears)?;
135
136 let compositions = zero_claims.into_iter().collect::<Vec<_>>();
137 for (_, composition_base, composition) in &compositions {
138 if composition_base.n_vars() != multilinears.len()
139 || composition.n_vars() != multilinears.len()
140 || composition_base.degree() != composition.degree()
141 {
142 bail!(Error::InvalidComposition {
143 actual: composition.n_vars(),
144 expected: multilinears.len(),
145 });
146 }
147 }
148 #[cfg(feature = "debug_validate_sumcheck")]
149 {
150 let compositions = compositions
151 .iter()
152 .map(|(name, _, a)| (name.clone(), a))
153 .collect::<Vec<_>>();
154 validate_witness(&multilinears, &compositions)?;
155 }
156
157 let switchover_rounds = determine_switchovers(&multilinears, switchover_fn);
158 let zerocheck_challenges = zerocheck_challenges.to_vec();
159
160 let domains = compositions
161 .iter()
162 .map(|(_, _, composition)| {
163 let degree = composition.degree();
164 let domain =
165 evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
166 Ok(domain.into())
167 })
168 .collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
169 .map_err(Error::MathError)?;
170
171 Ok(Self {
172 n_vars,
173 multilinears,
174 switchover_rounds,
175 compositions,
176 zerocheck_challenges,
177 domains,
178 backend,
179 univariate_evals_output: None,
180 _p_base_marker: PhantomData,
181 _m_marker: PhantomData,
182 })
183 }
184
185 #[instrument(skip_all, level = "debug")]
186 #[allow(clippy::type_complexity)]
187 pub fn into_regular_zerocheck(
188 self,
189 ) -> Result<
190 ZerocheckProver<'a, FDomain, P, Composition, MultilinearWitness<'m, P>, Backend>,
191 Error,
192 > {
193 if self.univariate_evals_output.is_some() {
194 bail!(Error::ExpectedFold);
195 }
196
197 let multilinears = self
202 .multilinears
203 .into_iter()
204 .map(|multilinear| Arc::new(multilinear) as MultilinearWitness<'_, P>)
205 .collect::<Vec<_>>();
206
207 #[cfg(feature = "debug_validate_sumcheck")]
208 {
209 let compositions = self
210 .compositions
211 .iter()
212 .map(|(name, _, a)| (name.clone(), a))
213 .collect::<Vec<_>>();
214 validate_witness(&multilinears, &compositions)?;
215 }
216
217 let compositions = self
218 .compositions
219 .into_iter()
220 .map(|(_, _, composition)| composition)
221 .collect::<Vec<_>>();
222
223 let start = self.n_vars.min(1);
225 let partial_eq_ind_evals = self
226 .backend
227 .tensor_product_full_query(&self.zerocheck_challenges[start..])?;
228 let claimed_sums = vec![F::ZERO; compositions.len()];
229
230 ZerocheckProver::new(
232 EvaluationOrder::LowToHigh,
233 multilinears,
234 &self.switchover_rounds,
235 compositions,
236 partial_eq_ind_evals,
237 self.zerocheck_challenges,
238 claimed_sums,
239 self.domains,
240 RegularFirstRound::SkipCube,
241 self.backend,
242 )
243 }
244}
245
246impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
247 UnivariateZerocheckProver<'a, F>
248 for UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend>
249where
250 F: TowerField,
251 FDomain: TowerField,
252 FBase: ExtensionField<FDomain>,
253 P: PackedFieldIndexable<Scalar = F>
254 + PackedExtension<F, PackedSubfield = P>
255 + PackedExtension<FBase, PackedSubfield: PackedFieldIndexable>
256 + PackedExtension<FDomain, PackedSubfield: PackedFieldIndexable>,
257 CompositionBase: CompositionPoly<PackedSubfield<P, FBase>> + 'static,
258 Composition: CompositionPoly<P> + 'static,
259 M: MultilinearPoly<P> + Send + Sync + 'm,
260 Backend: ComputationBackend,
261{
262 fn n_vars(&self) -> usize {
263 self.n_vars
264 }
265
266 fn domain_size(&self, skip_rounds: usize) -> usize {
267 self.compositions
268 .iter()
269 .map(|(_, composition, _)| domain_size(composition.degree(), skip_rounds))
270 .max()
271 .unwrap_or(0)
272 }
273
274 #[instrument(skip_all, level = "debug")]
275 fn execute_univariate_round(
276 &mut self,
277 skip_rounds: usize,
278 max_domain_size: usize,
279 batch_coeff: F,
280 ) -> Result<LagrangeRoundEvals<F>, Error> {
281 if self.univariate_evals_output.is_some() {
282 bail!(Error::ExpectedFold);
283 }
284
285 let compositions_base = self
287 .compositions
288 .iter()
289 .map(|(_, composition_base, _)| composition_base)
290 .collect::<Vec<_>>();
291
292 let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
295 &self.multilinears,
296 &compositions_base,
297 &self.zerocheck_challenges,
298 skip_rounds,
299 max_domain_size,
300 self.backend,
301 )?;
302
303 let zeros_prefix_len = 1 << skip_rounds;
305 let batched_round_evals = univariate_evals_output
306 .round_evals
307 .iter()
308 .zip(powers(batch_coeff))
309 .map(|(evals, scalar)| {
310 let round_evals = LagrangeRoundEvals {
311 zeros_prefix_len,
312 evals: evals.clone(),
313 };
314 round_evals * scalar
315 })
316 .try_fold(
317 LagrangeRoundEvals::zeros(max_domain_size),
318 |mut accum, evals| -> Result<_, Error> {
319 accum.add_assign_lagrange(&evals)?;
320 Ok(accum)
321 },
322 )?;
323
324 self.univariate_evals_output = Some(univariate_evals_output);
325
326 Ok(batched_round_evals)
327 }
328
329 #[instrument(skip_all, level = "debug")]
330 fn fold_univariate_round(
331 self: Box<Self>,
332 challenge: F,
333 ) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
334 if self.univariate_evals_output.is_none() {
335 bail!(Error::ExpectedExecution);
336 }
337
338 let ZerocheckUnivariateFoldResult {
341 skip_rounds,
342 subcube_lagrange_coeffs,
343 claimed_prime_sums,
344 partial_eq_ind_evals,
345 } = self
346 .univariate_evals_output
347 .expect("validated to be Some")
348 .fold::<FDomain>(challenge)?;
349
350 let mut packed_subcube_lagrange_coeffs =
359 zeroed_vec::<P>(1 << skip_rounds.saturating_sub(P::LOG_WIDTH));
360 P::unpack_scalars_mut(&mut packed_subcube_lagrange_coeffs)[..1 << skip_rounds]
361 .copy_from_slice(&subcube_lagrange_coeffs);
362 let lagrange_coeffs_query =
363 MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;
364
365 let partial_low_multilinears = self
366 .multilinears
367 .into_par_iter()
368 .map(|multilinear| -> Result<_, Error> {
369 let multilinear =
370 multilinear.evaluate_partial_low(lagrange_coeffs_query.to_ref())?;
371 let mle_adapter = Arc::new(MLEDirectAdapter::from(multilinear));
372 Ok(mle_adapter as MultilinearWitness<'static, P>)
373 })
374 .collect::<Result<Vec<_>, _>>()?;
375
376 let switchover_rounds = self
377 .switchover_rounds
378 .into_iter()
379 .map(|switchover_round| switchover_round.saturating_sub(skip_rounds))
380 .collect::<Vec<_>>();
381
382 let zerocheck_challenges = self.zerocheck_challenges.clone();
383
384 let compositions = self
385 .compositions
386 .into_iter()
387 .map(|(_, _, composition)| composition)
388 .collect();
389
390 let regular_prover = ZerocheckProver::new(
396 EvaluationOrder::LowToHigh,
397 partial_low_multilinears,
398 &switchover_rounds,
399 compositions,
400 partial_eq_ind_evals,
401 zerocheck_challenges,
402 claimed_prime_sums,
403 self.domains,
404 RegularFirstRound::LaterRound,
405 self.backend,
406 )?;
407
408 Ok(Box::new(regular_prover) as Box<dyn SumcheckProver<F> + 'a>)
409 }
410}
411
/// Selects which evaluator the regular [`ZerocheckProver`] uses in its round 0.
#[derive(Debug, Clone, Copy)]
enum RegularFirstRound {
	/// Round 0 uses `ZerocheckFirstRoundEvaluator`: points 0 and 1 are not computed and are
	/// taken to be zero, and the claimed sum must be zero. Used by `into_regular_zerocheck`.
	SkipCube,
	/// Every round, including round 0, uses `ZerocheckLaterRoundEvaluator`. Used after a
	/// univariate fold, where the claimed prime sums come from the fold result.
	LaterRound,
}
417
/// Regular (multilinear, round-by-round) zerocheck prover.
///
/// Each round's polynomial is computed over the remaining variables, weighted by the partial
/// eq-indicator. The polynomial is then multiplied by `eq(alpha, X)` for the current
/// challenge `alpha` and by the running `eq_ind_eval` (see `execute`).
#[derive(Debug)]
pub struct ZerocheckProver<'a, FDomain, P, Composition, M, Backend>
where
	FDomain: Field,
	P: PackedField,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	// Number of variables at construction time.
	n_vars: usize,
	state: ProverState<'a, FDomain, P, M, Backend>,
	// Running product of `eq(alpha_i, r_i)` over folded rounds; starts at one.
	eq_ind_eval: P::Scalar,
	// Packed partial eq-indicator evaluations, folded after every round.
	partial_eq_ind_evals: Backend::Vec<P>,
	// One challenge per round, indexed by `round()`.
	zerocheck_challenges: Vec<P::Scalar>,
	compositions: Vec<Composition>,
	// One interpolation domain per composition.
	domains: Vec<InterpolationDomain<FDomain>>,
	// Which evaluator round 0 uses.
	first_round: RegularFirstRound,
}
449
450impl<'a, F, FDomain, P, Composition, M, Backend>
451 ZerocheckProver<'a, FDomain, P, Composition, M, Backend>
452where
453 F: Field,
454 FDomain: Field,
455 P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
456 Composition: CompositionPoly<P>,
457 M: MultilinearPoly<P> + Send + Sync,
458 Backend: ComputationBackend,
459{
460 #[allow(clippy::too_many_arguments)]
461 fn new(
462 evaluation_order: EvaluationOrder,
467 multilinears: Vec<M>,
468 switchover_rounds: &[usize],
469 compositions: Vec<Composition>,
470 partial_eq_ind_evals: Backend::Vec<P>,
471 zerocheck_challenges: Vec<F>,
472 claimed_prime_sums: Vec<F>,
473 domains: Vec<InterpolationDomain<FDomain>>,
474 first_round: RegularFirstRound,
475 backend: &'a Backend,
476 ) -> Result<Self, Error> {
477 if claimed_prime_sums.len() != compositions.len() {
478 bail!(Error::IncorrectClaimedPrimeSumsLength);
479 }
480
481 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
482
483 let state = ProverState::new_with_switchover_rounds(
484 evaluation_order,
485 multilinears,
486 switchover_rounds,
487 claimed_prime_sums,
488 nontrivial_evaluation_points,
489 backend,
490 )?;
491 let n_vars = state.n_vars();
492
493 if zerocheck_challenges.len() != n_vars {
494 bail!(Error::IncorrectZerocheckChallengesLength);
495 }
496
497 if partial_eq_ind_evals.len() != 1 << n_vars.saturating_sub(1 + P::LOG_WIDTH) {
500 bail!(Error::IncorrectZerocheckPartialEqIndSize);
501 }
502
503 let eq_ind_eval = F::ONE;
504
505 Ok(Self {
506 n_vars,
507 state,
508 eq_ind_eval,
509 partial_eq_ind_evals,
510 zerocheck_challenges,
511 compositions,
512 domains,
513 first_round,
514 })
515 }
516
517 fn round(&self) -> usize {
518 self.n_vars - self.n_rounds_remaining()
519 }
520
521 fn n_rounds_remaining(&self) -> usize {
522 self.state.n_vars()
523 }
524
525 fn update_eq_ind_eval(&mut self, challenge: F) {
526 let alpha = self.zerocheck_challenges[self.round()];
528 self.eq_ind_eval *= eq(alpha, challenge);
529 }
530
531 #[instrument(skip_all, level = "debug")]
532 fn fold_partial_eq_ind(&mut self) {
533 fold_partial_eq_ind::<P, Backend>(
534 self.state.evaluation_order(),
535 self.n_rounds_remaining(),
536 &mut self.partial_eq_ind_evals,
537 );
538 }
539}
540
541impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
542 for ZerocheckProver<'_, FDomain, P, Composition, M, Backend>
543where
544 F: TowerField + ExtensionField<FDomain>,
545 FDomain: Field,
546 P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
547 Composition: CompositionPoly<P>,
548 M: MultilinearPoly<P> + Send + Sync,
549 Backend: ComputationBackend,
550{
551 fn n_vars(&self) -> usize {
552 self.n_vars
553 }
554
555 fn evaluation_order(&self) -> EvaluationOrder {
556 self.state.evaluation_order()
557 }
558
559 #[instrument(skip_all, name = "ZerocheckProver::fold", level = "debug")]
560 fn fold(&mut self, challenge: F) -> Result<(), Error> {
561 self.update_eq_ind_eval(challenge);
562 self.state.fold(challenge)?;
563
564 self.fold_partial_eq_ind();
566
567 Ok(())
568 }
569
570 #[instrument(skip_all, name = "ZerocheckProver::execute", level = "debug")]
571 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
572 let round = self.round();
573 let skip_cube_first_round =
574 round == 0 && matches!(self.first_round, RegularFirstRound::SkipCube);
575 let coeffs = if skip_cube_first_round {
576 let evaluators = izip!(&self.compositions, &self.domains)
577 .map(|(composition, interpolation_domain)| {
578 let composition_at_infinity =
579 ArithCircuitPoly::new(composition.expression().leading_term());
580
581 ZerocheckFirstRoundEvaluator {
582 composition,
583 composition_at_infinity,
584 interpolation_domain,
585 partial_eq_ind_evals: &self.partial_eq_ind_evals,
586 }
587 })
588 .collect::<Vec<_>>();
589 let evals = self.state.calculate_round_evals(&evaluators)?;
590 self.state
591 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?
592 } else {
593 let evaluators = izip!(&self.compositions, &self.domains)
594 .map(|(composition, interpolation_domain)| {
595 let composition_at_infinity =
596 ArithCircuitPoly::new(composition.expression().leading_term());
597
598 ZerocheckLaterRoundEvaluator {
599 composition,
600 composition_at_infinity,
601 interpolation_domain,
602 partial_eq_ind_evals: &self.partial_eq_ind_evals,
603 round_zerocheck_challenge: self.zerocheck_challenges[round],
604 }
605 })
606 .collect::<Vec<_>>();
607 let evals = self.state.calculate_round_evals(&evaluators)?;
608 self.state
609 .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?
610 };
611
612 let alpha = self.zerocheck_challenges[round];
614
615 let constant_scalar = F::ONE - alpha;
619 let linear_scalar = alpha.double() - F::ONE;
620
621 let coeffs_scaled_by_constant_term = coeffs.clone() * constant_scalar;
622 let mut coeffs_scaled_by_linear_term = coeffs * linear_scalar;
623 coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); let sumcheck_coeffs = coeffs_scaled_by_constant_term + &coeffs_scaled_by_linear_term;
626 Ok(sumcheck_coeffs * self.eq_ind_eval)
627 }
628
629 #[instrument(skip_all, name = "ZerocheckProver::finish", level = "debug")]
630 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
631 let mut evals = self.state.finish()?;
632 evals.push(self.eq_ind_eval);
633 Ok(evals)
634 }
635}
636
/// Round-0 evaluator for `RegularFirstRound::SkipCube`.
///
/// It computes the round polynomial only at evaluation points 2 and above. The points 0 and 1
/// are filled in as zero by its `round_evals_to_coeffs`.
struct ZerocheckFirstRoundEvaluator<'a, P, FDomain, Composition>
where
	P: PackedField,
	FDomain: Field,
{
	composition: &'a Composition,
	// Circuit of the leading term of `composition`, evaluated at the infinity point.
	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
	interpolation_domain: &'a InterpolationDomain<FDomain>,
	// Packed partial eq-indicator evaluations, indexed per subcube.
	partial_eq_ind_evals: &'a [P],
}
647
648impl<P, FDomain, Composition> SumcheckEvaluator<P, Composition>
649 for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition>
650where
651 P: PackedField<Scalar: TowerField + ExtensionField<FDomain>>,
652 FDomain: Field,
653 Composition: CompositionPoly<P>,
654{
655 fn eval_point_indices(&self) -> Range<usize> {
656 2..self.composition.degree() + 1
660 }
661
662 fn process_subcube_at_eval_point(
663 &self,
664 subcube_vars: usize,
665 subcube_index: usize,
666 is_infinity_point: bool,
667 batch_query: &[&[P]],
668 ) -> P {
669 if self.composition.degree() == 1 {
673 return P::zero();
674 }
675 let row_len = batch_query.first().map_or(0, |row| row.len());
676
677 stackalloc_with_default(row_len, |evals| {
678 if is_infinity_point {
679 self.composition_at_infinity
680 .batch_evaluate(batch_query, evals)
681 .expect("correct by query construction invariant");
682 } else {
683 self.composition
684 .batch_evaluate(batch_query, evals)
685 .expect("correct by query construction invariant");
686 }
687
688 let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
689 let partial_eq_ind_evals_slice = &self.partial_eq_ind_evals[subcube_start..];
690 let field_sum = PackedField::iter_slice(partial_eq_ind_evals_slice)
691 .zip(PackedField::iter_slice(evals))
692 .map(|(eq_ind_scalar, base_scalar)| eq_ind_scalar * base_scalar)
693 .sum();
694
695 P::set_single(field_sum)
696 })
697 }
698
699 fn composition(&self) -> &Composition {
700 self.composition
701 }
702
703 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
704 Some(self.partial_eq_ind_evals)
705 }
706}
707
708impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
709 for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition>
710where
711 F: Field + ExtensionField<FDomain>,
712 P: PackedField<Scalar = F>,
713 FDomain: Field,
714{
715 fn round_evals_to_coeffs(
716 &self,
717 last_round_sum: F,
718 mut round_evals: Vec<F>,
719 ) -> Result<Vec<F>, PolynomialError> {
720 assert_eq!(last_round_sum, F::ZERO);
721
722 round_evals.insert(0, P::Scalar::ZERO);
725 round_evals.insert(0, P::Scalar::ZERO);
726
727 if round_evals.len() > 3 {
728 let infinity_round_eval = round_evals.remove(2);
733 round_evals.push(infinity_round_eval);
734 }
735
736 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
737 Ok(coeffs)
738 }
739}
740
/// Evaluator for every regular round except a skip-cube round 0.
///
/// The evaluation at point 0 is not computed; `round_evals_to_coeffs` recovers it from the
/// previous round's sum and `round_zerocheck_challenge`.
struct ZerocheckLaterRoundEvaluator<'a, P, FDomain, Composition>
where
	P: PackedField,
	FDomain: Field,
{
	composition: &'a Composition,
	// Circuit of the leading term of `composition`, evaluated at the infinity point.
	composition_at_infinity: ArithCircuitPoly<P::Scalar>,
	interpolation_domain: &'a InterpolationDomain<FDomain>,
	// Packed partial eq-indicator evaluations, indexed per subcube.
	partial_eq_ind_evals: &'a [P],
	// Zerocheck challenge `alpha` for the current round.
	round_zerocheck_challenge: P::Scalar,
}
752
753impl<P, FDomain, Composition> SumcheckEvaluator<P, Composition>
754 for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition>
755where
756 P: PackedField<Scalar: TowerField + ExtensionField<FDomain>>,
757 FDomain: Field,
758 Composition: CompositionPoly<P>,
759{
760 fn eval_point_indices(&self) -> Range<usize> {
761 1..self.composition.degree() + 1
765 }
766
767 fn process_subcube_at_eval_point(
768 &self,
769 subcube_vars: usize,
770 subcube_index: usize,
771 is_infinity_point: bool,
772 batch_query: &[&[P]],
773 ) -> P {
774 if self.composition.degree() == 1 {
778 return P::zero();
779 }
780 let row_len = batch_query.first().map_or(0, |row| row.len());
781
782 stackalloc_with_default(row_len, |evals| {
783 if is_infinity_point {
784 self.composition_at_infinity
785 .batch_evaluate(batch_query, evals)
786 .expect("correct by query construction invariant");
787 } else {
788 self.composition
789 .batch_evaluate(batch_query, evals)
790 .expect("correct by query construction invariant");
791 }
792
793 let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
794 for (i, eval) in evals.iter_mut().enumerate() {
795 *eval *= self.partial_eq_ind_evals[subcube_start + i];
796 }
797
798 evals.iter().copied().sum::<P>()
799 })
800 }
801
802 fn composition(&self) -> &Composition {
803 self.composition
804 }
805
806 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
807 Some(self.partial_eq_ind_evals)
808 }
809}
810
811impl<F, P, FDomain, Composition> SumcheckInterpolator<F>
812 for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition>
813where
814 F: Field,
815 P: PackedField<Scalar = F> + PackedExtension<FDomain>,
816 FDomain: Field,
817{
818 fn round_evals_to_coeffs(
819 &self,
820 last_round_sum: F,
821 mut round_evals: Vec<F>,
822 ) -> Result<Vec<F>, PolynomialError> {
823 let alpha = self.round_zerocheck_challenge;
829 let one_evaluation = round_evals[0]; let zero_evaluation_numerator = last_round_sum - one_evaluation * alpha;
831 let zero_evaluation_denominator_inv = (F::ONE - alpha).invert_or_zero();
832 let zero_evaluation = zero_evaluation_numerator * zero_evaluation_denominator_inv;
833
834 round_evals.insert(0, zero_evaluation);
835
836 if round_evals.len() > 3 {
837 let infinity_round_eval = round_evals.remove(2);
842 round_evals.push(infinity_round_eval);
843 }
844
845 let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
846 Ok(coeffs)
847 }
848}