binius_core/protocols/sumcheck/prove/
batch_zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::sync::Arc;
4
5use binius_field::{ExtensionField, PackedExtension, PackedField, TowerField};
6use binius_hal::{make_portable_backend, CpuBackend};
7use binius_math::{
8	BinarySubspace, EvaluationDomain, EvaluationOrder, IsomorphicEvaluationDomainFactory,
9	MLEDirectAdapter, MultilinearPoly,
10};
11use binius_utils::{bail, sorting::is_sorted_ascending};
12
13use crate::{
14	fiat_shamir::{CanSample, Challenger},
15	protocols::sumcheck::{
16		immediate_switchover_heuristic,
17		prove::{
18			front_loaded, logging::FoldLowDimensionsData, RegularSumcheckProver, SumcheckProver,
19		},
20		zerocheck::{
21			lagrange_evals_multilinear_extension, univariatizing_reduction_claim,
22			BatchZerocheckOutput, ZerocheckRoundEvals,
23		},
24		BatchSumcheckOutput, Error,
25	},
26	transcript::ProverTranscript,
27};
28
29/// A zerocheck prover interface.
30///
31/// The primary reason for providing this logic via a trait is the ability to type erase univariate
32/// round small fields, which may differ between the provers, and to decouple the batch prover implementation
33/// from the relatively complex type signatures of the individual provers.
34///
35/// The batch prover must obey a specific sequence of calls: [`Self::execute_univariate_round`]
36/// should be followed by [`Self::fold_univariate_round`], and then [`Self::project_to_skipped_variables`].
37/// Getters [`Self::n_vars`] and [`Self::domain_size`] are used for alignment and maximal domain size calculation
38/// required by the Lagrange representation of the univariate round polynomial.
39/// Folding univariate round results in a [`SumcheckProver`] instance that can be driven to completion to prove the
40/// remaining multilinear rounds.
41///
42/// This trait is object-safe.
43pub trait ZerocheckProver<'a, P: PackedField> {
44	/// The number of variables in the multivariate polynomial.
45	fn n_vars(&self) -> usize;
46
47	/// Maximal required Lagrange domain size among compositions in this prover.
48	///
49	/// Returns `None` if the current prover state doesn't contain information about the domain size.
50	fn domain_size(&self, skip_rounds: usize) -> Option<usize>;
51
52	/// Computes the prover message for the univariate round as a univariate polynomial.
53	///
54	/// The prover message mixes the univariate polynomials of the underlying composites using
55	/// the same approach as [`SumcheckProver::execute`].
56	///
57	/// Unlike multilinear rounds, the returned univariate is not in monomial basis but in
58	/// Lagrange basis.
59	fn execute_univariate_round(
60		&mut self,
61		skip_rounds: usize,
62		max_domain_size: usize,
63		batch_coeff: P::Scalar,
64	) -> Result<ZerocheckRoundEvals<P::Scalar>, Error>;
65
66	/// Folds into a regular multilinear prover for the remaining rounds.
67	fn fold_univariate_round(
68		&mut self,
69		challenge: P::Scalar,
70	) -> Result<Box<dyn SumcheckProver<P::Scalar> + 'a>, Error>;
71
72	/// Projects witness onto the "skipped" variables for the univariatizing reduction.
73	fn project_to_skipped_variables(
74		self: Box<Self>,
75		challenges: &[P::Scalar],
76	) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error>;
77}
78
79// NB: auto_impl does not currently handle ?Sized bound on Box<Self> receivers correctly.
80impl<'a, P: PackedField, Prover: ZerocheckProver<'a, P> + ?Sized> ZerocheckProver<'a, P>
81	for Box<Prover>
82{
83	fn n_vars(&self) -> usize {
84		(**self).n_vars()
85	}
86
87	fn domain_size(&self, skip_rounds: usize) -> Option<usize> {
88		(**self).domain_size(skip_rounds)
89	}
90
91	fn execute_univariate_round(
92		&mut self,
93		skip_rounds: usize,
94		max_domain_size: usize,
95		batch_coeff: P::Scalar,
96	) -> Result<ZerocheckRoundEvals<P::Scalar>, Error> {
97		(**self).execute_univariate_round(skip_rounds, max_domain_size, batch_coeff)
98	}
99
100	fn fold_univariate_round(
101		&mut self,
102		challenge: P::Scalar,
103	) -> Result<Box<dyn SumcheckProver<P::Scalar> + 'a>, Error> {
104		(**self).fold_univariate_round(challenge)
105	}
106
107	fn project_to_skipped_variables(
108		self: Box<Self>,
109		challenges: &[P::Scalar],
110	) -> Result<Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>, Error> {
111		(*self).project_to_skipped_variables(challenges)
112	}
113}
114
115fn univariatizing_reduction_prover<F, FDomain, P>(
116	mut projected_multilinears: Vec<Arc<dyn MultilinearPoly<P> + Send + Sync>>,
117	skip_rounds: usize,
118	univariatized_multilinear_evals: Vec<Vec<F>>,
119	univariate_challenge: F,
120	backend: &'_ CpuBackend,
121) -> Result<impl SumcheckProver<F> + '_, Error>
122where
123	F: TowerField + ExtensionField<FDomain>,
124	FDomain: TowerField,
125	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
126{
127	let sumcheck_claim =
128		univariatizing_reduction_claim(skip_rounds, &univariatized_multilinear_evals)?;
129
130	let subspace =
131		BinarySubspace::<FDomain::Canonical>::with_dim(skip_rounds)?.isomorphic::<FDomain>();
132	let ntt_domain = EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
133
134	projected_multilinears.push(
135		MLEDirectAdapter::from(lagrange_evals_multilinear_extension(
136			&ntt_domain,
137			univariate_challenge,
138		)?)
139		.upcast_arc_dyn(),
140	);
141
142	// REVIEW: all multilins are large field, we could benefit from "no switchover" constructor, but this sumcheck
143	//         is very small anyway.
144	let prover = RegularSumcheckProver::<FDomain, P, _, _, _>::new(
145		EvaluationOrder::HighToLow,
146		projected_multilinears,
147		sumcheck_claim.composite_sums().iter().cloned(),
148		IsomorphicEvaluationDomainFactory::<FDomain::Canonical>::default(),
149		immediate_switchover_heuristic,
150		backend,
151	)?;
152
153	Ok(prover)
154}
155
156/// Prove a batched zerocheck protocol execution.
157///
158/// See the [`batch_verify_zerocheck`](`super::super::batch_verify_zerocheck`) docstring for
159/// a detailed description of the zerocheck reduction stages. The `provers` in this invocation
160/// should be provided in the same order as the corresponding claims during verification.
161///
162/// Zerocheck challenges (`max_n_vars - skip_rounds` of them) are to be sampled right before this
163/// call and used for [`ZerocheckProver`] instances creation (most likely via calls to
164/// [`ZerocheckProverImpl::new`](`super::zerocheck::ZerocheckProverImpl::new`))
165#[allow(clippy::type_complexity)]
166pub fn batch_prove<'a, F, FDomain, P, Prover, Challenger_>(
167	mut provers: Vec<Prover>,
168	skip_rounds: usize,
169	transcript: &mut ProverTranscript<Challenger_>,
170) -> Result<BatchZerocheckOutput<P::Scalar>, Error>
171where
172	F: TowerField + ExtensionField<FDomain>,
173	FDomain: TowerField,
174	P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
175	Prover: ZerocheckProver<'a, P>,
176	Challenger_: Challenger,
177{
178	// Check that the provers are in non-descending order by n_vars
179	if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
180		bail!(Error::ClaimsOutOfOrder);
181	}
182
183	let max_domain_size = provers
184		.iter()
185		.map(|prover| {
186			prover
187				.domain_size(skip_rounds)
188				.expect("domain size must be known")
189		})
190		.max()
191		.unwrap_or(0);
192
193	// Sample batching coefficients while computing round polynomials per claim, then batch
194	// those in Lagrange domain.
195	let mut batch_coeffs = Vec::with_capacity(provers.len());
196	let mut round_evals =
197		ZerocheckRoundEvals::zeros(max_domain_size.saturating_sub(1 << skip_rounds));
198	for prover in &mut provers {
199		let next_batch_coeff = transcript.sample();
200		batch_coeffs.push(next_batch_coeff);
201
202		let prover_round_evals =
203			prover.execute_univariate_round(skip_rounds, max_domain_size, next_batch_coeff)?;
204
205		round_evals.add_assign_lagrange(&(prover_round_evals * next_batch_coeff))?;
206	}
207
208	// Sample univariate challenge
209	transcript.message().write_scalar_slice(&round_evals.evals);
210	let univariate_challenge = transcript.sample();
211
212	// Prove reduced multilinear eq-ind sumchecks, high-to-low, with front-loaded batching
213	let mut sumcheck_provers = Vec::with_capacity(provers.len());
214	for prover in &mut provers {
215		let sumcheck_prover = prover.fold_univariate_round(univariate_challenge)?;
216		sumcheck_provers.push(sumcheck_prover);
217	}
218
219	let regular_sumcheck_prover =
220		front_loaded::BatchProver::new_prebatched(batch_coeffs, sumcheck_provers)?;
221
222	let BatchSumcheckOutput {
223		challenges: mut unskipped_challenges,
224		multilinear_evals: mut univariatized_multilinear_evals,
225	} = regular_sumcheck_prover.run(transcript)?;
226
227	// Reverse challenges since folding high-to-low
228	unskipped_challenges.reverse();
229
230	// Drop equality indicator evals prior to univariatizing reduction
231	for evals in &mut univariatized_multilinear_evals {
232		evals
233			.pop()
234			.expect("equality indicator evaluation at last position");
235	}
236
237	// Project witness multilinears to "skipped" variables
238	let mut projected_multilinears = Vec::new();
239	let dimensions_data = FoldLowDimensionsData::new(skip_rounds, &provers);
240	let mle_fold_low_span = tracing::debug_span!(
241		"[task] Initial MLE Fold Low",
242		phase = "zerocheck",
243		perfetto_category = "task.main",
244		?dimensions_data,
245	)
246	.entered();
247	for prover in provers {
248		let claim_projected_multilinears =
249			Box::new(prover).project_to_skipped_variables(&unskipped_challenges)?;
250
251		projected_multilinears.extend(claim_projected_multilinears);
252	}
253	drop(mle_fold_low_span);
254
255	// Prove univariatizing reduction sumcheck.
256	// It's small (`skip_rounds` variables), so portable backend is likely fine.
257	let backend = make_portable_backend();
258	let reduction_prover = univariatizing_reduction_prover::<_, FDomain, _>(
259		projected_multilinears,
260		skip_rounds,
261		univariatized_multilinear_evals,
262		univariate_challenge,
263		&backend,
264	)?;
265
266	let batch_reduction_prover =
267		front_loaded::BatchProver::new(vec![reduction_prover], transcript)?;
268
269	let BatchSumcheckOutput {
270		challenges: mut skipped_challenges,
271		multilinear_evals: mut concat_multilinear_evals,
272	} = batch_reduction_prover.run(transcript)?;
273
274	// Reverse challenges since folding high-to-low
275	skipped_challenges.reverse();
276
277	let mut concat_multilinear_evals = concat_multilinear_evals
278		.pop()
279		.expect("multilinear_evals.len() == 1");
280
281	concat_multilinear_evals
282		.pop()
283		.expect("Lagrange coefficients MLE eval at last position");
284
285	// Fin
286	let output = BatchZerocheckOutput {
287		skipped_challenges,
288		unskipped_challenges,
289		concat_multilinear_evals,
290	};
291
292	Ok(output)
293}