binius_core/protocols/sumcheck/prove/
batch_prove_univariate_zerocheck.rsuse binius_field::{Field, TowerField};
use binius_utils::{bail, sorting::is_sorted_ascending};
use tracing::instrument;
use crate::{
fiat_shamir::CanSample,
protocols::sumcheck::{
prove::{batch_prove::BatchProveStart, SumcheckProver},
univariate::LagrangeRoundEvals,
Error,
},
transcript::CanWrite,
};
/// A prover that takes part in the univariate (skipped-rounds) round of a batched zerocheck
/// and can then be turned into a regular [`SumcheckProver`].
///
/// Driven by `batch_prove_zerocheck_univariate_round` in this module, which calls
/// `n_vars`/`domain_size` for validation and sizing, `execute_univariate_round` once per prover,
/// and finally `fold_univariate_round` with a challenge sampled from the transcript.
pub trait UnivariateZerocheckProver<'a, F: Field> {
/// Number of variables of this prover's claim.
///
/// The batch driver requires provers to be ordered by this value from largest to smallest.
fn n_vars(&self) -> usize;
/// Size of the evaluation domain this prover needs for the univariate round when
/// `skip_rounds` rounds are skipped.
///
/// The batch driver takes the maximum of this value over all provers and passes it back as
/// `max_domain_size` to `execute_univariate_round`.
fn domain_size(&self, skip_rounds: usize) -> usize;
/// Computes this prover's univariate round evaluations.
///
/// * `skip_rounds` - number of rounds this particular prover skips (the batch driver passes
///   `skip_rounds + n_vars - max_n_vars`).
/// * `max_domain_size` - the largest `domain_size` over the whole batch.
/// * `batch_coeff` - the batching coefficient sampled for this prover. NOTE(review): how the
///   prover uses it internally is implementation-defined; the batch driver additionally
///   multiplies the returned evaluations by this same coefficient before accumulating them.
fn execute_univariate_round(
&mut self,
skip_rounds: usize,
max_domain_size: usize,
batch_coeff: F,
) -> Result<LagrangeRoundEvals<F>, Error>;
/// Consumes the prover, folds in the univariate round `challenge`, and returns a regular
/// sumcheck prover (presumably for the remaining multilinear rounds — confirm against
/// implementors).
fn fold_univariate_round(
self: Box<Self>,
challenge: F,
) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error>;
}
/// Forwards every [`UnivariateZerocheckProver`] method to the boxed prover, so a `Box` holding a
/// prover (including a boxed trait object, since `P` may be unsized) is itself a prover.
impl<'a, F, P> UnivariateZerocheckProver<'a, F> for Box<P>
where
    F: Field,
    P: UnivariateZerocheckProver<'a, F> + ?Sized,
{
    fn n_vars(&self) -> usize {
        let inner: &P = self;
        inner.n_vars()
    }

    fn domain_size(&self, skip_rounds: usize) -> usize {
        let inner: &P = self;
        inner.domain_size(skip_rounds)
    }

    fn execute_univariate_round(
        &mut self,
        skip_rounds: usize,
        max_domain_size: usize,
        batch_coeff: F,
    ) -> Result<LagrangeRoundEvals<F>, Error> {
        let inner: &mut P = self;
        inner.execute_univariate_round(skip_rounds, max_domain_size, batch_coeff)
    }

    fn fold_univariate_round(
        self: Box<Self>,
        challenge: F,
    ) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
        // `*self` moves the inner `Box<P>` out of the outer box. Calling the method on a `Box<P>`
        // receiver dispatches to `P`'s implementation, not back into this forwarding impl
        // (whose receiver would be `Box<Box<P>>`).
        let inner: Box<P> = *self;
        inner.fold_univariate_round(challenge)
    }
}
/// Result of `batch_prove_zerocheck_univariate_round`.
///
/// Carries the challenge sampled for the univariate round together with the batching state
/// needed to continue proving the remaining rounds with the regular batch sumcheck prover.
#[derive(Debug)]
pub struct BatchZerocheckUnivariateProveOutput<F: Field, Prover> {
/// Challenge sampled from the transcript after the batched univariate round evaluations
/// were written to it; every prover was folded with this value.
pub univariate_challenge: F,
/// Per-prover batching coefficients and the folded provers, in the original prover order.
/// Presumably consumed by the continuation in the `batch_prove` module — confirm at callers.
pub batch_prove_start: BatchProveStart<F, Prover>,
}
/// Runs the univariate (skipped-rounds) round of a batched zerocheck over several provers.
///
/// The provers must be ordered by `n_vars` from largest to smallest, and the difference between
/// the largest and smallest `n_vars` must not exceed `skip_rounds`. A prover with `n_vars`
/// variables skips `skip_rounds + n_vars - max_n_vars` rounds.
///
/// For each prover, a batching coefficient is sampled from the transcript. The prover's
/// univariate round evaluations are scaled by it and accumulated. The accumulated evaluations
/// are written to the transcript and a univariate challenge is sampled. Every prover is then
/// folded with that challenge into a regular [`SumcheckProver`].
///
/// # Errors
///
/// - [`Error::ClaimsOutOfOrder`] if the provers are not ordered by decreasing `n_vars`.
/// - [`Error::TooManySkippedRounds`] if `max_n_vars - min_n_vars > skip_rounds`.
/// - [`Error::IncorrectZerosPrefixLen`] if the accumulated evaluations do not report a
///   zeros-prefix length of `min(2^(skip_rounds + min_n_vars - max_n_vars), max_domain_size)`.
/// - Any error returned by a prover's `execute_univariate_round` or `fold_univariate_round`,
///   or by `add_assign_lagrange`.
#[allow(clippy::type_complexity)] // the output nests a boxed trait object inside two generic structs
#[instrument(skip_all, level = "debug")]
pub fn batch_prove_zerocheck_univariate_round<'a, F, Prover, Transcript>(
    mut provers: Vec<Prover>,
    skip_rounds: usize,
    mut transcript: Transcript,
) -> Result<BatchZerocheckUnivariateProveOutput<F, Box<dyn SumcheckProver<F> + 'a>>, Error>
where
    F: TowerField,
    Prover: UnivariateZerocheckProver<'a, F>,
    Transcript: CanSample<F> + CanWrite,
{
    // Reversing the n_vars sequence lets the ascending-order helper validate a descending order.
    let n_vars_reversed = provers.iter().map(|prover| prover.n_vars()).rev();
    if !is_sorted_ascending(n_vars_reversed) {
        bail!(Error::ClaimsOutOfOrder);
    }

    let (max_n_vars, min_n_vars) = match (provers.first(), provers.last()) {
        (Some(largest), Some(smallest)) => (largest.n_vars(), smallest.n_vars()),
        _ => (0, 0),
    };
    if skip_rounds < max_n_vars - min_n_vars {
        bail!(Error::TooManySkippedRounds);
    }

    // Provers with fewer variables skip correspondingly fewer rounds. After the checks above,
    // `max_n_vars - n_vars <= skip_rounds` holds for every prover, so this cannot underflow.
    let prover_skip_rounds = |n_vars: usize| skip_rounds + n_vars - max_n_vars;

    let max_domain_size = provers
        .iter()
        .map(|prover| prover.domain_size(prover_skip_rounds(prover.n_vars())))
        .fold(0, usize::max);

    let mut batch_coeffs = Vec::with_capacity(provers.len());
    let mut accumulated = LagrangeRoundEvals::zeros(max_domain_size);
    for prover in &mut provers {
        let batch_coeff = transcript.sample();
        batch_coeffs.push(batch_coeff);

        let prover_skip = prover_skip_rounds(prover.n_vars());
        let prover_evals =
            prover.execute_univariate_round(prover_skip, max_domain_size, batch_coeff)?;
        let scaled_evals = prover_evals * batch_coeff;
        accumulated.add_assign_lagrange(&scaled_evals)?;
    }

    // The expected prefix is 2^(rounds skipped by the smallest prover), clamped to the domain.
    let expected_zeros_prefix_len =
        (1usize << prover_skip_rounds(min_n_vars)).min(max_domain_size);
    if accumulated.zeros_prefix_len != expected_zeros_prefix_len {
        bail!(Error::IncorrectZerosPrefixLen);
    }

    transcript.write_scalar_slice(&accumulated.evals);
    let univariate_challenge = transcript.sample();

    let reduction_provers = provers
        .into_iter()
        .map(|prover| Box::new(prover).fold_univariate_round(univariate_challenge))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(BatchZerocheckUnivariateProveOutput {
        univariate_challenge,
        batch_prove_start: BatchProveStart {
            batch_coeffs,
            reduction_provers,
        },
    })
}