use super::Error;
use crate::{
oracle::MultilinearPolyOracle,
protocols::{
abstract_sumcheck::ReducedClaim, evalcheck::EvalcheckMultilinearClaim,
gkr_sumcheck::GkrSumcheckBatchProof,
},
witness::MultilinearWitness,
};
use binius_field::{Field, PackedField};
use binius_utils::bail;
use rayon::prelude::*;
/// Borrowed view of every evaluation stored for one layer of the product circuit.
type LayerEvals<'layer, S> = &'layer [S];
/// Borrowed view of one circuit layer split into its lower half and its upper half.
type LayerHalfEvals<'layer, S> = (LayerEvals<'layer, S>, LayerEvals<'layer, S>);
/// A claimed grand-product value attached to a multilinear polynomial oracle.
///
/// NOTE(review): presumably this asserts that the product of `poly`'s evaluations over the
/// boolean hypercube equals `product` — confirm against the prover/verifier code.
#[derive(Clone, Debug)]
pub struct GrandProductClaim<F>
where
    F: Field,
{
    /// Oracle for the multilinear polynomial the claim is about.
    pub poly: MultilinearPolyOracle<F>,
    /// The claimed product value.
    pub product: F,
}
/// Prover-side witness for a grand product: the input multilinear plus every layer of the
/// binary product circuit built over its hypercube evaluations.
#[derive(Clone, Debug)]
pub struct GrandProductWitness<'a, PW>
where
    PW: PackedField,
{
    /// The input multilinear polynomial whose hypercube evaluations form the bottom layer.
    poly: MultilinearWitness<'a, PW>,
    /// Circuit layers as built by `GrandProductWitness::new`: index 0 is the single-element
    /// output layer and the last index holds the input layer's evaluations.
    circuit_evals: Vec<Vec<PW::Scalar>>,
}
impl<'a, PW: PackedField> GrandProductWitness<'a, PW> {
    /// Builds every layer of the product circuit for `poly`.
    ///
    /// The bottom (input) layer holds the `2^n_vars` hypercube evaluations of `poly`. Each
    /// layer above it has half as many entries, with entry `j` equal to the product of entries
    /// `j` and `j + half` of the layer below (where `half` is the new layer's size). This
    /// repeats until one value remains. The layers are stored so that index 0 is the output
    /// layer and index `n_vars` is the input layer.
    ///
    /// # Errors
    /// Propagates any error raised while evaluating `poly` on the hypercube.
    pub fn new(poly: MultilinearWitness<'a, PW>) -> Result<Self, Error> {
        let n_vars = poly.n_vars();

        // Input layer: one evaluation per hypercube vertex, computed in parallel.
        let leaves = (0..1 << n_vars)
            .into_par_iter()
            .map(|vertex| poly.evaluate_on_hypercube(vertex))
            .collect::<Result<Vec<_>, _>>()?;

        // Layers are accumulated input-first and flipped at the end so index 0 is the output.
        let mut circuit_evals: Vec<Vec<PW::Scalar>> = Vec::with_capacity(n_vars + 1);
        circuit_evals.push(leaves);
        for layer_n_vars in (0..n_vars).rev() {
            let half: usize = 1 << layer_n_vars;
            let below = circuit_evals
                .last()
                .expect("layers is not empty by invariant");
            // `below` has exactly `2 * half` entries, so the two halves line up index by index.
            let (lower, upper) = below.split_at(half);
            let layer = lower
                .par_iter()
                .zip(upper.par_iter())
                .map(|(&lo, &hi)| lo * hi)
                .collect::<Vec<_>>();
            circuit_evals.push(layer);
        }
        circuit_evals.reverse();

        Ok(Self {
            poly,
            circuit_evals,
        })
    }

    /// Number of variables of the underlying multilinear polynomial.
    pub fn n_vars(&self) -> usize {
        self.poly.n_vars()
    }

    /// The circuit output: the single entry of layer 0, i.e. the product of all input-layer
    /// evaluations.
    pub fn grand_product_evaluation(&self) -> PW::Scalar {
        let output_layer = &self.circuit_evals[0];
        output_layer[0]
    }

    /// Evaluations of layer `i`, where layer 0 is the output and layer `n_vars` is the input.
    ///
    /// # Errors
    /// Returns [`Error::InvalidLayerIndex`] when `i > n_vars`.
    pub fn ith_layer_evals(&self, i: usize) -> Result<LayerEvals<'_, PW::Scalar>, Error> {
        let in_range = i <= self.n_vars();
        if !in_range {
            bail!(Error::InvalidLayerIndex);
        }
        Ok(self.circuit_evals[i].as_slice())
    }

    /// Splits layer `i` into its lower and upper halves, each of length `2^(i - 1)`.
    ///
    /// # Errors
    /// Returns [`Error::CannotSplitOutputLayerIntoHalves`] for `i == 0` (the output layer has a
    /// single entry), and forwards the error from [`Self::ith_layer_evals`] when `i` is out of
    /// range.
    pub fn ith_layer_eval_halves(&self, i: usize) -> Result<LayerHalfEvals<'_, PW::Scalar>, Error> {
        if i == 0 {
            bail!(Error::CannotSplitOutputLayerIntoHalves);
        }
        let layer = self.ith_layer_evals(i)?;
        let (lower, upper) = layer.split_at(layer.len() / 2);
        debug_assert_eq!(lower.len(), 1 << (i - 1));
        Ok((lower, upper))
    }
}
/// A claim about a single product-circuit layer, expressed as a sumcheck-reduced claim.
///
/// NOTE(review): presumably this is the evaluation point/value pair carried from one layer to
/// the next during the GKR reduction — confirm with the prover.
pub type LayerClaim<Scalar> = ReducedClaim<Scalar>;
/// Proof data for reducing one product-circuit layer across a batch of grand product claims.
#[derive(Clone, Debug)]
pub struct BatchLayerProof<F>
where
    F: Field,
{
    /// The batched GKR sumcheck proof for this layer.
    pub gkr_sumcheck_batch_proof: GkrSumcheckBatchProof<F>,
    /// NOTE(review): presumably one evaluation per batched claim of the layer's lower half at
    /// the sumcheck point — TODO confirm.
    pub zero_evals: Vec<F>,
    /// NOTE(review): presumably one evaluation per batched claim of the layer's upper half at
    /// the sumcheck point — TODO confirm.
    pub one_evals: Vec<F>,
}
/// Proof for a batch of grand product claims, made up of per-layer proofs.
#[derive(Clone, Debug, Default)]
pub struct GrandProductBatchProof<F>
where
    F: Field,
{
    /// One proof per circuit layer. NOTE(review): the layer ordering is not visible here —
    /// confirm against the prover.
    pub batch_layer_proofs: Vec<BatchLayerProof<F>>,
}
/// Result of proving a batch of grand product claims.
#[derive(Default, Debug)]
pub struct GrandProductBatchProveOutput<F>
where
    F: Field,
{
    /// Multilinear evaluation claims left over for a subsequent evalcheck step.
    pub evalcheck_multilinear_claims: Vec<EvalcheckMultilinearClaim<F>>,
    /// The batched grand product proof itself.
    pub proof: GrandProductBatchProof<F>,
}