binius_core/protocols/sumcheck/prove/batch_prove_univariate_zerocheck.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, TowerField};
4use binius_utils::{bail, sorting::is_sorted_ascending};
5use tracing::instrument;
6
7use crate::{
8	fiat_shamir::{CanSample, Challenger},
9	protocols::sumcheck::{
10		prove::{batch_prove::BatchProveStart, SumcheckProver},
11		univariate::LagrangeRoundEvals,
12		Error,
13	},
14	transcript::ProverTranscript,
15};
16
17/// A univariate zerocheck prover interface.
18///
19/// The primary reason for providing this logic via a trait is the ability to type erase univariate
20/// round small fields, which may differ between the provers, and to decouple the batch prover implementation
21/// from the relatively complex type signatures of the individual provers.
22///
23/// The batch prover must obey a specific sequence of calls: [`Self::execute_univariate_round`]
24/// should be followed by [`Self::fold_univariate_round`]. Getters [`Self::n_vars`] and [`Self::domain_size`]
25/// are used to align claims and determine the maximal domain size, required by the Lagrange representation
26/// of the univariate round polynomial. Folding univariate round results in a [`SumcheckProver`] instance
27/// that can be driven to completion to prove the remaining multilinear rounds.
28///
29/// This trait is object-safe.
30pub trait UnivariateZerocheckProver<'a, F: Field> {
31	/// The number of variables in the multivariate polynomial.
32	fn n_vars(&self) -> usize;
33
34	/// Maximal required Lagrange domain size among compositions in this prover.
35	fn domain_size(&self, skip_rounds: usize) -> usize;
36
37	/// Computes the prover message for the univariate round as a univariate polynomial.
38	///
39	/// The prover message mixes the univariate polynomials of the underlying composites using
40	/// the same approach as [`SumcheckProver::execute`].
41	///
42	/// Unlike multilinear rounds, the returned univariate is not in monomial basis but in
43	/// Lagrange basis.
44	fn execute_univariate_round(
45		&mut self,
46		skip_rounds: usize,
47		max_domain_size: usize,
48		batch_coeff: F,
49	) -> Result<LagrangeRoundEvals<F>, Error>;
50
51	/// Folds into a regular multilinear prover for the remaining rounds.
52	fn fold_univariate_round(
53		self: Box<Self>,
54		challenge: F,
55	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error>;
56}
57
58// NB: auto_impl does not currently handle ?Sized bound on Box<Self> receivers correctly.
59impl<'a, F: Field, Prover: UnivariateZerocheckProver<'a, F> + ?Sized>
60	UnivariateZerocheckProver<'a, F> for Box<Prover>
61{
62	fn n_vars(&self) -> usize {
63		(**self).n_vars()
64	}
65
66	fn domain_size(&self, skip_rounds: usize) -> usize {
67		(**self).domain_size(skip_rounds)
68	}
69
70	fn execute_univariate_round(
71		&mut self,
72		skip_rounds: usize,
73		max_domain_size: usize,
74		batch_coeff: F,
75	) -> Result<LagrangeRoundEvals<F>, Error> {
76		(**self).execute_univariate_round(skip_rounds, max_domain_size, batch_coeff)
77	}
78
79	fn fold_univariate_round(
80		self: Box<Self>,
81		challenge: F,
82	) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
83		(*self).fold_univariate_round(challenge)
84	}
85}
86
87#[derive(Debug)]
88pub struct BatchZerocheckUnivariateProveOutput<F: Field, Prover> {
89	pub univariate_challenge: F,
90	pub batch_prove_start: BatchProveStart<F, Prover>,
91}
92
93/// Prove a batched univariate zerocheck round.
94///
95/// Batching principle is entirely analogous to the multilinear case: all the provers are right aligned
96/// and should all "start" in the first `skip_rounds` rounds; this method fails otherwise. Reduction
97/// to remaining multilinear rounds results in provers for `n_vars - skip_rounds` rounds.
98///
99/// The provers in the `provers` parameter must in the same order as the corresponding claims
100/// provided to [`crate::protocols::sumcheck::batch_verify_zerocheck_univariate_round`] during proof
101/// verification.
102#[allow(clippy::type_complexity)]
103#[instrument(skip_all, level = "debug")]
104pub fn batch_prove_zerocheck_univariate_round<'a, F, Prover, Challenger_>(
105	mut provers: Vec<Prover>,
106	skip_rounds: usize,
107	transcript: &mut ProverTranscript<Challenger_>,
108) -> Result<BatchZerocheckUnivariateProveOutput<F, Box<dyn SumcheckProver<F> + 'a>>, Error>
109where
110	F: TowerField,
111	Prover: UnivariateZerocheckProver<'a, F>,
112	Challenger_: Challenger,
113{
114	// Check that the provers are in descending order by n_vars
115	if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
116		bail!(Error::ClaimsOutOfOrder);
117	}
118
119	let max_n_vars = provers.first().map(|prover| prover.n_vars()).unwrap_or(0);
120	let min_n_vars = provers.last().map(|prover| prover.n_vars()).unwrap_or(0);
121
122	if max_n_vars - min_n_vars > skip_rounds {
123		bail!(Error::TooManySkippedRounds);
124	}
125
126	let max_domain_size = provers
127		.iter()
128		.map(|prover| prover.domain_size(skip_rounds + prover.n_vars() - max_n_vars))
129		.max()
130		.unwrap_or(0);
131
132	let mut batch_coeffs = Vec::with_capacity(provers.len());
133	let mut round_evals = LagrangeRoundEvals::zeros(max_domain_size);
134	for prover in &mut provers {
135		let next_batch_coeff = transcript.sample();
136		batch_coeffs.push(next_batch_coeff);
137
138		let prover_round_evals = prover.execute_univariate_round(
139			skip_rounds + prover.n_vars() - max_n_vars,
140			max_domain_size,
141			next_batch_coeff,
142		)?;
143
144		round_evals.add_assign_lagrange(&(prover_round_evals * next_batch_coeff))?;
145	}
146
147	let zeros_prefix_len = (1 << (skip_rounds + min_n_vars - max_n_vars)).min(max_domain_size);
148	if zeros_prefix_len != round_evals.zeros_prefix_len {
149		bail!(Error::IncorrectZerosPrefixLen);
150	}
151
152	transcript.message().write_scalar_slice(&round_evals.evals);
153	let univariate_challenge = transcript.sample();
154
155	let mut reduction_provers = Vec::with_capacity(provers.len());
156	for prover in provers {
157		let regular_prover = Box::new(prover).fold_univariate_round(univariate_challenge)?;
158		reduction_provers.push(regular_prover);
159	}
160
161	let batch_prove_start = BatchProveStart {
162		batch_coeffs,
163		reduction_provers,
164	};
165
166	let output = BatchZerocheckUnivariateProveOutput {
167		univariate_challenge,
168		batch_prove_start,
169	};
170
171	Ok(output)
172}