// binius_core/protocols/gkr_gpa/prove.rs
use super::{
gkr_gpa::{GrandProductBatchProveOutput, LayerClaim},
gpa_sumcheck::prove::GPAProver,
Error, GrandProductBatchProof, GrandProductClaim, GrandProductWitness,
};
use crate::{
polynomial::MultilinearExtension, protocols::sumcheck_v2, witness::MultilinearWitness,
};
use binius_field::{
ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
};
use binius_hal::ComputationBackend;
use binius_math::{extrapolate_line_scalar, EvaluationDomainFactory};
use binius_utils::{
bail,
sorting::{stable_sort, unsort},
};
use p3_challenger::{CanObserve, CanSample};
use std::sync::Arc;
use tracing::instrument;
/// Proves a batch of grand product claims layer by layer via GKR.
///
/// Provers are stable-sorted by descending number of input variables so that at
/// every layer round only the still-active provers participate in one batched
/// GPA sumcheck; provers whose input layer has been reached are peeled off the
/// tail and contribute their final layer claim.
///
/// Returns the batched layer proofs together with the final-layer claims,
/// restored to the caller's original claim order.
///
/// # Errors
/// Fails with [`Error::MismatchedWitnessClaimLength`] when the number of
/// witnesses differs from the number of claims, or propagates any prover /
/// sumcheck error.
#[instrument(skip_all, name = "gkr_gpa::batch_prove", level = "debug")]
pub fn batch_prove<'a, F, P, FDomain, Challenger, Backend>(
    witnesses: impl IntoIterator<Item = GrandProductWitness<'a, P>>,
    claims: &[GrandProductClaim<F>],
    evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
    mut challenger: Challenger,
    backend: Backend,
) -> Result<GrandProductBatchProveOutput<F>, Error>
where
    F: TowerField,
    P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
    FDomain: Field,
    P::Scalar: Field + ExtensionField<FDomain>,
    Challenger: CanSample<F> + CanObserve<F>,
    Backend: ComputationBackend,
{
    let witness_vec = witnesses.into_iter().collect::<Vec<_>>();
    let n_claims = claims.len();
    // Empty batch: nothing to prove, return the trivial output.
    if n_claims == 0 {
        return Ok(GrandProductBatchProveOutput::default());
    }
    if witness_vec.len() != n_claims {
        bail!(Error::MismatchedWitnessClaimLength);
    }
    // One prover state per (witness, claim) pair; `new` validates consistency.
    let provers_vec = witness_vec
        .into_iter()
        .zip(claims)
        .map(|(witness, claim)| GrandProductProverState::new(claim, witness, backend.clone()))
        .collect::<Result<Vec<_>, _>>()?;
    // Sort descending by input variable count; keep the permutation so final
    // claims can be unsorted back into the caller's order at the end.
    let (original_indices, mut sorted_provers) =
        stable_sort(provers_vec, |prover| prover.input_vars(), true);
    let max_n_vars = sorted_provers
        .first()
        .expect("sorted_provers is not empty by invariant")
        .input_vars();
    let mut batch_layer_proofs = Vec::with_capacity(max_n_vars);
    let mut reverse_sorted_final_layer_claims = Vec::with_capacity(n_claims);
    for layer_no in 0..max_n_vars {
        // Retire provers whose input layer equals the current layer number;
        // they take no part in the remaining sumcheck rounds.
        process_finished_provers(
            layer_no,
            &mut sorted_provers,
            &mut reverse_sorted_final_layer_claims,
        )?;
        // Run one batched GPA sumcheck over all still-active provers.
        // NOTE: the challenger interaction order here (batch sumcheck, then
        // sampling `gpa_challenge`) is part of the Fiat-Shamir transcript and
        // must match the verifier exactly.
        let (gpa_sumcheck_batch_proof, sumcheck_challenge) = {
            let stage_gpa_sumcheck_provers = sorted_provers
                .iter_mut()
                .map(|p| p.stage_gpa_sumcheck_prover(evaluation_domain_factory.clone()))
                .collect::<Result<Vec<_>, _>>()?;
            let (batch_sumcheck_output, proof) =
                sumcheck_v2::batch_prove(stage_gpa_sumcheck_provers, &mut challenger)?;
            let sumcheck_challenge = batch_sumcheck_output.challenges;
            (proof, sumcheck_challenge)
        };
        let gpa_challenge = challenger.sample();
        // Fold each prover's (0-eval, 1-eval) pair with the sampled challenge
        // into its next-layer claim. Evals are indexed per-prover in the same
        // order the provers were staged above.
        for (i, prover) in sorted_provers.iter_mut().enumerate() {
            prover.finalize_batch_layer_proof(
                gpa_sumcheck_batch_proof.multilinear_evals[i][0],
                gpa_sumcheck_batch_proof.multilinear_evals[i][1],
                sumcheck_challenge.clone(),
                gpa_challenge,
            )?;
        }
        batch_layer_proofs.push(gpa_sumcheck_batch_proof);
    }
    // Retire any provers finishing at the last layer.
    process_finished_provers(
        max_n_vars,
        &mut sorted_provers,
        &mut reverse_sorted_final_layer_claims,
    )?;
    debug_assert!(sorted_provers.is_empty());
    debug_assert_eq!(reverse_sorted_final_layer_claims.len(), n_claims);
    // Claims were collected in reverse sorted order (popped from the tail);
    // reverse, then undo the initial sort permutation.
    reverse_sorted_final_layer_claims.reverse();
    let sorted_final_layer_claim = reverse_sorted_final_layer_claims;
    let final_layer_claims = unsort(original_indices, sorted_final_layer_claim);
    Ok(GrandProductBatchProveOutput {
        final_layer_claims,
        proof: GrandProductBatchProof { batch_layer_proofs },
    })
}
/// Pops every prover from the tail of `sorted_provers` whose input layer count
/// equals `layer_no`, finalizes it, and pushes its final layer claim onto
/// `reverse_sorted_final_layer_claims`.
///
/// Relies on `sorted_provers` being sorted by descending input variable count,
/// so all finished provers sit contiguously at the tail.
fn process_finished_provers<F, P, Backend>(
    layer_no: usize,
    sorted_provers: &mut Vec<GrandProductProverState<'_, F, P, Backend>>,
    reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
) -> Result<(), Error>
where
    P: PackedFieldIndexable<Scalar = F>,
    F: Field + From<P::Scalar>,
    P::Scalar: Field + From<F>,
    Backend: ComputationBackend,
{
    loop {
        // Stop as soon as the tail prover is not finished (or none remain).
        let tail_is_finished = match sorted_provers.last() {
            Some(prover) => prover.input_vars() == layer_no,
            None => false,
        };
        if !tail_is_finished {
            return Ok(());
        }
        debug_assert!(layer_no > 0);
        let finished = sorted_provers.pop().expect("tail checked non-empty above");
        reverse_sorted_final_layer_claims.push(finished.finalize()?);
    }
}
/// Per-claim prover state for the layered grand product argument.
#[derive(Debug)]
struct GrandProductProverState<'a, F, P, Backend>
where
    F: Field + From<P::Scalar>,
    P: PackedField,
    P::Scalar: Field + From<F>,
    Backend: ComputationBackend,
{
    // Number of variables of the input (last) layer.
    n_vars: usize,
    // Multilinear extension of each layer's evaluations; `n_vars + 1` entries,
    // layer 0 being the single grand-product value.
    layers: Vec<MultilinearWitness<'a, P>>,
    // [left, right] halves of each layer 1..=n_vars, used as the sumcheck
    // multilinears when reducing the previous layer's claim.
    next_layer_halves: Vec<[MultilinearWitness<'a, P>; 2]>,
    // Evaluation claim on the layer currently being reduced; the length of its
    // eval_point doubles as the current layer index.
    current_layer_claim: LayerClaim<F>,
    backend: Backend,
}
impl<'a, F, P, Backend> GrandProductProverState<'a, F, P, Backend>
where
F: Field + From<P::Scalar>,
P: PackedFieldIndexable<Scalar = F>,
P::Scalar: Field + From<F>,
Backend: ComputationBackend,
{
fn new(
claim: &GrandProductClaim<F>,
witness: GrandProductWitness<'a, P>,
backend: Backend,
) -> Result<Self, Error> {
let n_vars = claim.n_vars;
if n_vars != witness.n_vars() || witness.grand_product_evaluation() != claim.product {
bail!(Error::ProverClaimWitnessMismatch);
}
let n_layers = n_vars + 1;
let next_layer_halves = (1..n_layers)
.map(|i| {
let (left_evals, right_evals) = witness.ith_layer_eval_halves(i)?;
let left = MultilinearExtension::from_values_generic(Arc::from(left_evals))?
.specialize_arc_dyn();
let right = MultilinearExtension::from_values_generic(Arc::from(right_evals))?
.specialize_arc_dyn();
Ok([left, right])
})
.collect::<Result<Vec<_>, Error>>()?;
let layers = (0..n_layers)
.map(|i| {
let ith_layer_evals = witness.ith_layer_evals(i)?;
let mle = MultilinearExtension::from_values_generic(Arc::from(ith_layer_evals))?
.specialize_arc_dyn();
Ok(mle)
})
.collect::<Result<Vec<_>, Error>>()?;
debug_assert_eq!(next_layer_halves.len(), n_vars);
debug_assert_eq!(layers.len(), n_vars + 1);
let layer_claim = LayerClaim {
eval_point: vec![],
eval: claim.product,
};
Ok(Self {
n_vars,
next_layer_halves,
layers,
current_layer_claim: layer_claim,
backend,
})
}
fn input_vars(&self) -> usize {
self.n_vars
}
fn current_layer_no(&self) -> usize {
self.current_layer_claim.eval_point.len()
}
#[allow(clippy::type_complexity)]
fn stage_gpa_sumcheck_prover<FDomain>(
&self,
evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
) -> Result<GPAProver<FDomain, P, MultilinearWitness<'a, P>, Backend>, Error>
where
FDomain: Field,
P: PackedExtension<FDomain>,
F: ExtensionField<FDomain>,
{
if self.current_layer_no() >= self.input_vars() {
bail!(Error::TooManyRounds);
}
let current_layer = self.layers[self.current_layer_no()].clone();
let multilinears = self.next_layer_halves[self.current_layer_no()].clone();
GPAProver::new(
multilinears,
current_layer,
self.current_layer_claim.eval,
evaluation_domain_factory,
&self.current_layer_claim.eval_point,
self.backend.clone(),
)
.map_err(|e| e.into())
}
fn finalize_batch_layer_proof(
&mut self,
zero_eval: F,
one_eval: F,
sumcheck_challenge: Vec<F>,
gpa_challenge: F,
) -> Result<(), Error> {
if self.current_layer_no() >= self.input_vars() {
bail!(Error::TooManyRounds);
}
let new_eval = extrapolate_line_scalar(zero_eval, one_eval, gpa_challenge);
let mut layer_challenge = sumcheck_challenge;
layer_challenge.push(gpa_challenge);
self.current_layer_claim = LayerClaim {
eval_point: layer_challenge,
eval: new_eval,
};
Ok(())
}
fn finalize(self) -> Result<LayerClaim<F>, Error> {
if self.current_layer_no() != self.input_vars() {
bail!(Error::PrematureFinalize);
}
let final_layer_claim = LayerClaim {
eval_point: self.current_layer_claim.eval_point,
eval: self.current_layer_claim.eval,
};
Ok(final_layer_claim)
}
}