binius_core/protocols/sumcheck/prove/
concrete_prover.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// Copyright 2024-2025 Irreducible Inc.

use binius_field::{
	ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, RepackedExtension,
};
use binius_hal::ComputationBackend;
use binius_math::{CompositionPolyOS, MultilinearPoly};

use super::{batch_prove::SumcheckProver, RegularSumcheckProver, ZerocheckProver};
use crate::protocols::sumcheck::{common::RoundCoeffs, error::Error};

/// A sum type holding either a [`RegularSumcheckProver`] or a [`ZerocheckProver`].
///
/// Wrapping both kinds of prover in one concrete type lets regular sumchecks and zerochecks be
/// put into the same `batch_prove` call.
pub enum ConcreteProver<
	'a,
	FDomain: Field,
	PBase: PackedField,
	P: PackedField,
	CompositionBase,
	Composition,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
> {
	/// Wraps a prover for a regular sumcheck.
	Sumcheck(RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>),
	/// Wraps a prover for a zerocheck.
	Zerocheck(ZerocheckProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend>),
}

/// Implements [`SumcheckProver`] by forwarding each call to the prover held in the active variant.
impl<F, FDomain, PBase, P, CompositionBase, Composition, M, Backend> SumcheckProver<F>
	for ConcreteProver<'_, FDomain, PBase, P, CompositionBase, Composition, M, Backend>
where
	F: Field + ExtensionField<PBase::Scalar> + ExtensionField<FDomain>,
	FDomain: Field,
	PBase: PackedField<Scalar: ExtensionField<FDomain>> + PackedExtension<FDomain>,
	P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain> + RepackedExtension<PBase>,
	CompositionBase: CompositionPolyOS<PBase>,
	Composition: CompositionPolyOS<P>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	/// Returns the variable count reported by the wrapped prover.
	fn n_vars(&self) -> usize {
		match self {
			Self::Sumcheck(inner) => inner.n_vars(),
			Self::Zerocheck(inner) => inner.n_vars(),
		}
	}

	/// Runs the wrapped prover's `execute` with the given batching coefficient.
	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
		match self {
			Self::Sumcheck(inner) => inner.execute(batch_coeff),
			Self::Zerocheck(inner) => inner.execute(batch_coeff),
		}
	}

	/// Passes the verifier challenge to the wrapped prover's `fold`.
	fn fold(&mut self, challenge: F) -> Result<(), Error> {
		match self {
			Self::Sumcheck(inner) => inner.fold(challenge),
			Self::Zerocheck(inner) => inner.fold(challenge),
		}
	}

	/// Consumes this prover and returns the wrapped prover's `finish` output.
	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
		// The trait's `finish` takes `Box<Self>`, so the wrapped prover is moved out of
		// this box and placed in a fresh box of its own type before delegating.
		let unboxed = *self;
		match unboxed {
			Self::Sumcheck(inner) => Box::new(inner).finish(),
			Self::Zerocheck(inner) => Box::new(inner).finish(),
		}
	}
}