1use std::{cmp::Reverse, marker::PhantomData, ops::Range};
4
5use binius_field::{util::eq, ExtensionField, Field, PackedExtension, PackedField, TowerField};
6use binius_hal::{make_portable_backend, ComputationBackend, Error as HalError, SumcheckEvaluator};
7use binius_math::{
8 CompositionPoly, EvaluationDomainFactory, EvaluationOrder, InterpolationDomain,
9 MultilinearPoly, RowsBatchRef,
10};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::bail;
13use getset::Getters;
14use itertools::izip;
15use stackalloc::stackalloc_with_default;
16use tracing::instrument;
17
18use crate::{
19 polynomial::{
20 ArithCircuitPoly, Error as PolynomialError, MultilinearComposite, MultivariatePoly,
21 },
22 protocols::sumcheck::{
23 common::{
24 equal_n_vars_check, get_nontrivial_evaluation_points,
25 interpolation_domains_for_composition_degrees, RoundCoeffs,
26 },
27 prove::{
28 common::fold_partial_eq_ind, MultilinearInput, ProverState, SumcheckInterpolator,
29 SumcheckProver,
30 },
31 CompositeSumClaim, Error,
32 },
33 transparent::{eq_ind::EqIndPartialEval, step_up::StepUp},
34};
35
36pub fn validate_witness<F, P, M, Composition>(
37 multilinears: &[M],
38 eq_ind_challenges: &[F],
39 eq_ind_sum_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
40) -> Result<(), Error>
41where
42 F: Field,
43 P: PackedField<Scalar = F>,
44 M: MultilinearPoly<P> + Send + Sync,
45 Composition: CompositionPoly<P>,
46{
47 let n_vars = equal_n_vars_check(multilinears)?;
48 let multilinears = multilinears.iter().collect::<Vec<_>>();
49
50 if eq_ind_challenges.len() != n_vars {
51 bail!(Error::IncorrectEqIndChallengesLength);
52 }
53
54 let backend = make_portable_backend();
55 let eq_ind =
56 EqIndPartialEval::new(eq_ind_challenges).multilinear_extension::<P, _>(&backend)?;
57
58 for (i, claim) in eq_ind_sum_claims.into_iter().enumerate() {
59 let CompositeSumClaim {
60 composition,
61 sum: expected_sum,
62 } = claim;
63 let witness = MultilinearComposite::new(n_vars, composition, multilinears.clone())?;
64 let sum = (0..(1 << n_vars))
65 .into_par_iter()
66 .map(|j| -> Result<F, Error> {
67 Ok(eq_ind.evaluate_on_hypercube(j)? * witness.evaluate_on_hypercube(j)?)
68 })
69 .try_reduce(|| F::ZERO, |a, b| Ok(a + b))?;
70
71 if sum != expected_sum {
72 bail!(Error::SumcheckNaiveValidationFailure {
73 composition_index: i,
74 });
75 }
76 }
77 Ok(())
78}
79
80pub struct EqIndSumcheckProverBuilder<'a, P, Backend>
96where
97 P: PackedField,
98 Backend: ComputationBackend,
99{
100 eq_ind_partial_evals: Option<Backend::Vec<P>>,
101 nonzero_scalars_prefixes: Option<Vec<usize>>,
102 first_round_eval_1s: Option<Vec<P::Scalar>>,
103 backend: &'a Backend,
104}
105
106impl<'a, F, P, Backend> EqIndSumcheckProverBuilder<'a, P, Backend>
107where
108 F: TowerField,
109 P: PackedField<Scalar = F>,
110 Backend: ComputationBackend,
111{
112 pub fn new(backend: &'a Backend) -> Self {
113 Self {
114 backend,
115 eq_ind_partial_evals: None,
116 nonzero_scalars_prefixes: None,
117 first_round_eval_1s: None,
118 }
119 }
120
121 pub fn with_eq_ind_partial_evals(mut self, eq_ind_partial_evals: Backend::Vec<P>) -> Self {
123 self.eq_ind_partial_evals = Some(eq_ind_partial_evals);
124 self
125 }
126
127 pub fn with_first_round_eval_1s(mut self, first_round_eval_1s: &[F]) -> Self {
133 self.first_round_eval_1s = Some(first_round_eval_1s.to_vec());
134 self
135 }
136
137 pub fn with_nonzero_scalars_prefixes(mut self, nonzero_scalars_prefixes: &[usize]) -> Self {
142 self.nonzero_scalars_prefixes = Some(nonzero_scalars_prefixes.to_vec());
143 self
144 }
145
146 #[instrument(skip_all, level = "debug", name = "EqIndSumcheckProverBuilder::build")]
147 pub fn build<FDomain, Composition, M>(
148 self,
149 evaluation_order: EvaluationOrder,
150 multilinears: Vec<M>,
151 eq_ind_challenges: &[F],
152 composite_claims: impl IntoIterator<Item = CompositeSumClaim<F, Composition>>,
153 domain_factory: impl EvaluationDomainFactory<FDomain>,
154 switchover_fn: impl Fn(usize) -> usize,
155 ) -> Result<EqIndSumcheckProver<'a, FDomain, P, Composition, M, Backend>, Error>
156 where
157 F: ExtensionField<FDomain>,
158 P: PackedExtension<FDomain>,
159 FDomain: Field,
160 M: MultilinearPoly<P> + Send + Sync,
161 Composition: CompositionPoly<P>,
162 {
163 let n_vars = equal_n_vars_check(&multilinears)?;
164 let composite_claims = composite_claims.into_iter().collect::<Vec<_>>();
165 let backend = self.backend;
166
167 #[cfg(feature = "debug_validate_sumcheck")]
168 {
169 let composite_claims = composite_claims
170 .iter()
171 .map(|composite_claim| CompositeSumClaim {
172 composition: &composite_claim.composition,
173 sum: composite_claim.sum,
174 })
175 .collect::<Vec<_>>();
176 validate_witness(&multilinears, eq_ind_challenges, composite_claims.clone())?;
177 }
178
179 if eq_ind_challenges.len() != n_vars {
180 bail!(Error::IncorrectEqIndChallengesLength);
181 }
182
183 let eq_ind_partial_evals = if let Some(eq_ind_partial_evals) = self.eq_ind_partial_evals {
186 if eq_ind_partial_evals.len() != 1 << n_vars.saturating_sub(P::LOG_WIDTH + 1) {
187 bail!(Error::IncorrectEqIndPartialEvalsSize);
188 }
189
190 eq_ind_partial_evals
191 } else {
192 eq_ind_expand(evaluation_order, n_vars, eq_ind_challenges, backend)?
193 };
194
195 if let Some(ref first_round_eval_1s) = self.first_round_eval_1s {
196 if first_round_eval_1s.len() != composite_claims.len() {
197 bail!(Error::IncorrectFirstRoundEvalOnesLength);
198 }
199 }
200
201 for claim in &composite_claims {
202 if claim.composition.n_vars() != multilinears.len() {
203 bail!(Error::InvalidComposition {
204 expected: multilinears.len(),
205 actual: claim.composition.n_vars(),
206 });
207 }
208 }
209
210 let zero_scalars_suffixes = self
211 .nonzero_scalars_prefixes
212 .unwrap_or_else(|| vec![1 << n_vars; multilinears.len()])
213 .into_iter()
214 .map(|prefix| (1 << n_vars) - prefix)
215 .collect::<Vec<_>>();
216
217 let (compositions, claimed_sums) =
218 determine_const_eval_suffixes(composite_claims, &zero_scalars_suffixes);
219
220 let domains = interpolation_domains_for_composition_degrees(
221 domain_factory,
222 compositions
223 .iter()
224 .map(|(composition, _)| composition.degree()),
225 )?;
226
227 let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;
228
229 let multilinears_input = izip!(multilinears, &zero_scalars_suffixes)
230 .map(|(multilinear, &zero_scalars_suffix)| MultilinearInput {
231 multilinear,
232 zero_scalars_suffix,
233 })
234 .collect();
235
236 let state = ProverState::new(
237 evaluation_order,
238 multilinears_input,
239 claimed_sums,
240 nontrivial_evaluation_points,
241 switchover_fn,
242 backend,
243 )?;
244
245 let eq_ind_prefix_eval = F::ONE;
246 let eq_ind_challenges = eq_ind_challenges.to_vec();
247 let first_round_eval_1s = self.first_round_eval_1s;
248
249 Ok(EqIndSumcheckProver {
250 n_vars,
251 state,
252 eq_ind_prefix_eval,
253 eq_ind_partial_evals,
254 eq_ind_challenges,
255 compositions,
256 domains,
257 first_round_eval_1s,
258 backend: PhantomData,
259 })
260 }
261}
262
263#[derive(Default, PartialEq, Eq, Debug)]
264pub struct ConstEvalSuffix<F: Field> {
265 pub suffix: usize,
266 pub value: F,
267 pub value_at_inf: F,
268}
269
270impl<F: Field> ConstEvalSuffix<F> {
271 fn update(&mut self, evaluation_order: EvaluationOrder, n_vars: usize) {
272 let eval_prefix = (1 << n_vars) - self.suffix;
273 let updated_eval_prefix = match evaluation_order {
274 EvaluationOrder::LowToHigh => eval_prefix.div_ceil(2),
275 EvaluationOrder::HighToLow => eval_prefix.min(1 << (n_vars - 1)),
276 };
277 self.suffix = (1 << (n_vars - 1)) - updated_eval_prefix;
278 }
279}
280
281#[derive(Debug, Getters)]
282pub struct EqIndSumcheckProver<'a, FDomain, P, Composition, M, Backend>
283where
284 FDomain: Field,
285 P: PackedField,
286 M: MultilinearPoly<P> + Send + Sync,
287 Backend: ComputationBackend,
288{
289 n_vars: usize,
290 state: ProverState<'a, FDomain, P, M, Backend>,
291 eq_ind_prefix_eval: P::Scalar,
292 eq_ind_partial_evals: Backend::Vec<P>,
293 eq_ind_challenges: Vec<P::Scalar>,
294 #[getset(get = "pub")]
295 compositions: Vec<(Composition, ConstEvalSuffix<P::Scalar>)>,
296 domains: Vec<InterpolationDomain<FDomain>>,
297 first_round_eval_1s: Option<Vec<P::Scalar>>,
298 backend: PhantomData<Backend>,
299}
300
301impl<F, FDomain, P, Composition, M, Backend>
302 EqIndSumcheckProver<'_, FDomain, P, Composition, M, Backend>
303where
304 F: TowerField + ExtensionField<FDomain>,
305 FDomain: Field,
306 P: PackedExtension<FDomain, Scalar = F>,
307 Composition: CompositionPoly<P>,
308 M: MultilinearPoly<P> + Send + Sync,
309 Backend: ComputationBackend,
310{
311 fn round(&self) -> usize {
312 self.n_vars - self.n_rounds_remaining()
313 }
314
315 fn n_rounds_remaining(&self) -> usize {
316 self.state.n_vars()
317 }
318
319 fn eq_ind_round_challenge(&self) -> F {
320 match self.state.evaluation_order() {
321 EvaluationOrder::LowToHigh => self.eq_ind_challenges[self.round()],
322 EvaluationOrder::HighToLow => {
323 self.eq_ind_challenges[self.eq_ind_challenges.len() - 1 - self.round()]
324 }
325 }
326 }
327
328 fn update_eq_ind_prefix_eval(&mut self, challenge: F) {
329 self.eq_ind_prefix_eval *= eq(self.eq_ind_round_challenge(), challenge);
331 }
332}
333
334pub fn eq_ind_expand<P, Backend>(
335 evaluation_order: EvaluationOrder,
336 n_vars: usize,
337 eq_ind_challenges: &[P::Scalar],
338 backend: &Backend,
339) -> Result<Backend::Vec<P>, HalError>
340where
341 P: PackedField,
342 Backend: ComputationBackend,
343{
344 if n_vars != eq_ind_challenges.len() {
345 bail!(HalError::IncorrectQuerySize { expected: n_vars });
346 }
347
348 backend.tensor_product_full_query(match evaluation_order {
349 EvaluationOrder::LowToHigh => &eq_ind_challenges[n_vars.min(1)..],
350 EvaluationOrder::HighToLow => &eq_ind_challenges[..n_vars.saturating_sub(1)],
351 })
352}
353
354type CompositionsAndSums<F, Composition> = (Vec<(Composition, ConstEvalSuffix<F>)>, Vec<F>);
355
356fn determine_const_eval_suffixes<F, P, Composition>(
363 composite_claims: Vec<CompositeSumClaim<F, Composition>>,
364 zero_scalars_suffixes: &[usize],
365) -> CompositionsAndSums<F, Composition>
366where
367 F: Field,
368 P: PackedField<Scalar = F>,
369 Composition: CompositionPoly<P>,
370{
371 let mut zero_scalars_suffixes = zero_scalars_suffixes
372 .iter()
373 .copied()
374 .enumerate()
375 .collect::<Vec<_>>();
376
377 zero_scalars_suffixes.sort_by_key(|(_var, zero_scalars_suffix)| Reverse(*zero_scalars_suffix));
378
379 composite_claims
380 .into_iter()
381 .map(|claim| {
382 let CompositeSumClaim { composition, sum } = claim;
383
384 let mut const_eval_suffix = Default::default();
385
386 let mut expr = composition.expression();
387 let mut expr_at_inf = composition.expression().leading_term();
388
389 for &(var_index, suffix) in &zero_scalars_suffixes {
390 expr = expr.const_subst(var_index, F::ZERO).optimize();
391 expr_at_inf = expr_at_inf.const_subst(var_index, F::ZERO).optimize();
392
393 if let Some((value, value_at_inf)) = expr.constant().zip(expr_at_inf.constant()) {
394 const_eval_suffix = ConstEvalSuffix {
395 suffix,
396 value,
397 value_at_inf,
398 };
399
400 break;
401 }
402 }
403
404 ((composition, const_eval_suffix), sum)
405 })
406 .unzip::<_, _, Vec<_>, Vec<_>>()
407}
408
409impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
410 for EqIndSumcheckProver<'_, FDomain, P, Composition, M, Backend>
411where
412 F: TowerField + ExtensionField<FDomain>,
413 FDomain: Field,
414 P: PackedExtension<FDomain, Scalar = F>,
415 Composition: CompositionPoly<P>,
416 M: MultilinearPoly<P> + Send + Sync,
417 Backend: ComputationBackend,
418{
419 fn n_vars(&self) -> usize {
420 self.n_vars
421 }
422
423 fn evaluation_order(&self) -> EvaluationOrder {
424 self.state.evaluation_order()
425 }
426
427 #[instrument(skip_all, name = "EqIndSumcheckProver::execute", level = "debug")]
428 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
429 let round = self.round();
430 let n_rounds_remaining = self.n_rounds_remaining();
431
432 let alpha = self.eq_ind_round_challenge();
433 let eq_ind_partial_evals = &self.eq_ind_partial_evals;
434
435 let first_round_eval_1s = self.first_round_eval_1s.take();
436 let have_first_round_eval_1s = first_round_eval_1s.is_some();
437
438 let eq_ind_challenges = match self.state.evaluation_order() {
439 EvaluationOrder::LowToHigh => &self.eq_ind_challenges[self.n_vars.min(round + 1)..],
440 EvaluationOrder::HighToLow => {
441 &self.eq_ind_challenges[..self.n_vars.saturating_sub(round + 1)]
442 }
443 };
444
445 let evaluators = self
446 .compositions
447 .iter_mut()
448 .map(|(composition, const_eval_suffix)| {
449 let composition_at_infinity =
450 ArithCircuitPoly::new(composition.expression().leading_term());
451
452 const_eval_suffix.update(self.state.evaluation_order(), n_rounds_remaining);
453
454 Evaluator {
455 n_rounds_remaining,
456 composition,
457 composition_at_infinity,
458 have_first_round_eval_1s,
459 eq_ind_challenges,
460 eq_ind_partial_evals,
461 const_eval_suffix,
462 }
463 })
464 .collect::<Vec<_>>();
465
466 let interpolators = self
467 .domains
468 .iter()
469 .enumerate()
470 .map(|(index, interpolation_domain)| Interpolator {
471 interpolation_domain,
472 alpha,
473 first_round_eval_1: first_round_eval_1s
474 .as_ref()
475 .map(|first_round_eval_1s| first_round_eval_1s[index]),
476 })
477 .collect::<Vec<_>>();
478
479 let round_evals = self.state.calculate_round_evals(&evaluators)?;
480
481 let prime_coeffs = self.state.calculate_round_coeffs_from_evals(
482 &interpolators,
483 batch_coeff,
484 round_evals,
485 )?;
486
487 let (prime_coeffs_scaled_by_constant_term, mut prime_coeffs_scaled_by_linear_term) =
492 if F::CHARACTERISTIC == 2 {
493 (prime_coeffs.clone() * (F::ONE + alpha), prime_coeffs)
494 } else {
495 (prime_coeffs.clone() * (F::ONE - alpha), prime_coeffs * (alpha.double() - F::ONE))
496 };
497
498 prime_coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); let coeffs = (prime_coeffs_scaled_by_constant_term + &prime_coeffs_scaled_by_linear_term)
501 * self.eq_ind_prefix_eval;
502
503 Ok(coeffs)
504 }
505
506 #[instrument(skip_all, name = "EqIndSumcheckProver::fold", level = "debug")]
507 fn fold(&mut self, challenge: F) -> Result<(), Error> {
508 self.update_eq_ind_prefix_eval(challenge);
509
510 let evaluation_order = self.state.evaluation_order();
511 let n_rounds_remaining = self.n_rounds_remaining();
512
513 let Self {
514 state,
515 eq_ind_partial_evals,
516 ..
517 } = self;
518
519 binius_maybe_rayon::join(
520 || state.fold(challenge),
521 || {
522 fold_partial_eq_ind::<P, Backend>(
523 evaluation_order,
524 n_rounds_remaining - 1,
525 eq_ind_partial_evals,
526 );
527 },
528 )
529 .0?;
530 Ok(())
531 }
532
533 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
534 let mut evals = self.state.finish()?;
535 evals.push(self.eq_ind_prefix_eval);
536 Ok(evals)
537 }
538}
539
540struct Evaluator<'a, P, Composition>
541where
542 P: PackedField,
543{
544 n_rounds_remaining: usize,
545 composition: &'a Composition,
546 composition_at_infinity: ArithCircuitPoly<P::Scalar>,
547 have_first_round_eval_1s: bool,
548 eq_ind_challenges: &'a [P::Scalar],
549 eq_ind_partial_evals: &'a [P],
550 const_eval_suffix: &'a ConstEvalSuffix<P::Scalar>,
551}
552
553impl<P, Composition> SumcheckEvaluator<P, Composition> for Evaluator<'_, P, Composition>
554where
555 P: PackedField<Scalar: TowerField>,
556 Composition: CompositionPoly<P>,
557{
558 fn eval_point_indices(&self) -> Range<usize> {
559 let start_index = if self.have_first_round_eval_1s { 2 } else { 1 };
561 start_index..self.composition.degree() + 1
562 }
563
564 fn process_subcube_at_eval_point(
565 &self,
566 subcube_vars: usize,
567 subcube_index: usize,
568 is_infinity_point: bool,
569 batch_query: &RowsBatchRef<P>,
570 ) -> P {
571 let row_len = batch_query.row_len();
572
573 stackalloc_with_default(row_len, |evals| {
574 if is_infinity_point {
575 self.composition_at_infinity
576 .batch_evaluate(batch_query, evals)
577 .expect("correct by query construction invariant");
578 } else {
579 self.composition
580 .batch_evaluate(batch_query, evals)
581 .expect("correct by query construction invariant");
582 };
583
584 let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
585 for (i, eval) in evals.iter_mut().enumerate() {
586 *eval *= self.eq_ind_partial_evals[subcube_start + i];
590 }
591 evals.iter().copied().sum::<P>()
592 })
593 }
594
595 fn process_constant_eval_suffix(
596 &self,
597 const_eval_suffix: usize,
598 is_infinity_point: bool,
599 ) -> P::Scalar {
600 let eval_prefix = (1 << self.n_rounds_remaining) - const_eval_suffix;
601 let eq_ind_suffix_sum = StepUp::new(self.eq_ind_challenges.len(), eval_prefix)
602 .expect("eval_prefix does not exceed the equality indicator size")
603 .evaluate(self.eq_ind_challenges)
604 .expect("StepUp is initialized with eq_ind_challenges.len()");
605
606 eq_ind_suffix_sum
607 * if is_infinity_point {
608 self.const_eval_suffix.value_at_inf
609 } else {
610 self.const_eval_suffix.value
611 }
612 }
613
614 fn composition(&self) -> &Composition {
615 self.composition
616 }
617
618 fn eq_ind_partial_eval(&self) -> Option<&[P]> {
619 Some(self.eq_ind_partial_evals)
620 }
621
622 fn const_eval_suffix(&self) -> usize {
623 self.const_eval_suffix.suffix
624 }
625}
626
627struct Interpolator<'a, F, FDomain>
628where
629 F: Field,
630 FDomain: Field,
631{
632 interpolation_domain: &'a InterpolationDomain<FDomain>,
633 alpha: F,
634 first_round_eval_1: Option<F>,
635}
636
637impl<F, FDomain> SumcheckInterpolator<F> for Interpolator<'_, F, FDomain>
638where
639 F: ExtensionField<FDomain>,
640 FDomain: Field,
641{
642 #[instrument(
643 skip_all,
644 name = "eq_ind::Interpolator::round_evals_to_coeffs",
645 level = "debug"
646 )]
647 fn round_evals_to_coeffs(
648 &self,
649 last_round_sum: F,
650 mut round_evals: Vec<F>,
651 ) -> Result<Vec<F>, PolynomialError> {
652 if let Some(first_round_eval_1) = self.first_round_eval_1 {
653 round_evals.insert(0, first_round_eval_1);
654 }
655
656 let one_evaluation = round_evals[0];
657 let zero_evaluation_numerator = last_round_sum - one_evaluation * self.alpha;
658 let zero_evaluation_denominator_inv = (F::ONE - self.alpha).invert_or_zero();
659 let zero_evaluation = zero_evaluation_numerator * zero_evaluation_denominator_inv;
660 round_evals.insert(0, zero_evaluation);
661
662 if round_evals.len() > 3 {
663 let infinity_round_eval = round_evals.remove(2);
668 round_evals.push(infinity_round_eval);
669 }
670
671 Ok(self.interpolation_domain.interpolate(&round_evals)?)
672 }
673}