use super::lasso::{
reduce_lasso_claim, LassoBatches, LassoClaim, LassoProof, LassoProveOutput,
LassoReducedClaimOracleIds, LassoWitness,
};
use crate::{
oracle::OracleId,
polynomial::Error as PolynomialError,
protocols::{
gkr_gpa::{GrandProductClaim, GrandProductWitness},
lasso::Error,
},
};
use crate::{
oracle::MultilinearOracleSet,
witness::{MultilinearExtensionIndex, MultilinearWitness},
};
use binius_field::{
as_packed_field::{PackScalar, PackedType},
underlier::{UnderlierType, WithUnderlier},
ExtensionField, Field, PackedField, PackedFieldIndexable, TowerField,
};
use binius_hal::ComputationBackend;
use binius_utils::bail;
use itertools::izip;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use std::{array, sync::Arc};
use tracing::instrument;
/// Proves a Lasso lookup claim by reducing it to a batch of grand product (GKR) claims.
///
/// The prover:
/// 1. computes, for every looker table `u_i`, the count recorded at each lookup and the final
///    count of every entry of the lookup table `t`. Counts are kept as powers of the
///    multiplicative generator of `FC`: they start at `FC::ONE` and are multiplied by the
///    generator once per lookup;
/// 2. registers the derived oracles via [`reduce_lasso_claim`] and fills their witnesses
///    (`counts`, `final_counts`, `ones_repeating` and the mixed columns
///    `gamma + poly + alpha * count`);
/// 3. evaluates the grand products of the left, right and counts oracles. It checks that the
///    left and right products agree and that no counts product is zero.
///
/// NOTE(review): the witness sizes are computed as `1 << (n_vars - LOG_WIDTH)` of the packed
/// count type, so every table is assumed to have at least `LOG_WIDTH` variables; otherwise the
/// subtraction underflows.
///
/// # Errors
/// - [`Error::ClaimWitnessTablesLenMismatch`] if the claim and the witness disagree on the number
///   of looker tables.
/// - [`Error::LassoCountTypeTooSmall`] if the total number of lookups reaches `2^FC::N_BITS`.
/// - [`Error::ProductsDiffer`] if the left and right grand products differ.
/// - [`Error::ZeroCounts`] if any counts grand product is zero.
/// - Any error propagated from the claim reduction, witness index updates, polynomial evaluation
///   or grand product witness construction.
#[allow(clippy::too_many_arguments)]
#[instrument(skip_all, name = "lasso::prove", level = "debug")]
pub fn prove<'a, FC, U, F, FW, L, Backend>(
    oracles: &mut MultilinearOracleSet<F>,
    witness_index: MultilinearExtensionIndex<'a, U, FW>,
    lasso_claim: &LassoClaim<F>,
    lasso_witness: LassoWitness<'a, PackedType<U, FW>, L>,
    lasso_batches: &LassoBatches,
    gamma: F,
    alpha: F,
    backend: Backend,
) -> Result<LassoProveOutput<'a, U, FW, F>, Error>
where
    U: UnderlierType + PackScalar<FW> + PackScalar<FC>,
    FC: TowerField,
    PackedType<U, FC>: PackedFieldIndexable,
    PackedType<U, FW>: PackedFieldIndexable,
    F: TowerField + From<FW> + ExtensionField<FC>,
    FW: TowerField + ExtensionField<FC> + From<F>,
    L: AsRef<[usize]>,
    Backend: ComputationBackend + 'static,
{
    let t_n_vars = lasso_claim.t_oracle().n_vars();

    if lasso_claim.u_oracles().len() != lasso_witness.u_polynomials().len() {
        bail!(Error::ClaimWitnessTablesLenMismatch);
    }

    let bit_packing_log_width = PackedType::<U, FC>::LOG_WIDTH;

    // Two FC-valued columns the size of `t`: the running/final counts and a column of all ones.
    let mut t_sized_underlier_vecs: [_; 2] =
        array::from_fn(|_| vec![U::default(); 1 << (t_n_vars - bit_packing_log_width)]);
    let [final_counts, ones_repeating] = t_sized_underlier_vecs.each_mut().map(|underliers| {
        let packed_slice = PackedType::<U, FC>::from_underliers_ref_mut(underliers.as_mut_slice());
        PackedType::<U, FC>::unpack_scalars_mut(packed_slice)
    });
    final_counts.fill(FC::ONE);
    ones_repeating.fill(FC::ONE);

    let common_counts_len = lasso_claim
        .u_oracles()
        .iter()
        .map(|oracle| 1 << oracle.n_vars())
        .sum::<usize>();

    // The total number of lookups must stay below 2^FC::N_BITS. A plain `1 << FC::N_BITS`
    // overflows the shift when N_BITS >= usize::BITS; in that case no usize count can reach the
    // bound, so the check is skipped.
    let max_counts_len = u32::try_from(FC::N_BITS)
        .ok()
        .and_then(|bits| 1usize.checked_shl(bits));
    if max_counts_len.is_some_and(|limit| common_counts_len >= limit) {
        bail!(Error::LassoCountTypeTooSmall);
    }

    let (gkr_claim_oracle_ids, reduced_claim_oracle_ids) =
        reduce_lasso_claim::<FC, _, _>(oracles, lasso_claim, lasso_batches, gamma, alpha, backend)?;

    let LassoReducedClaimOracleIds {
        ones_repeating_oracle_id,
        mixed_t_final_counts_oracle_id,
        mixed_t_one_oracle_id,
        mixed_u_counts_oracle_ids,
        mixed_u_counts_plus_one_oracle_ids,
    } = reduced_claim_oracle_ids;

    let mut witness_index = witness_index;

    // alpha * (count * g) == (alpha * g) * count, so mixing with `alpha_gen` yields the
    // "counts plus one" column without materializing the bumped counts.
    let alpha_gen = alpha * FC::MULTIPLICATIVE_GENERATOR;

    for (i, (u_polynomial, u_to_t_mapping)) in lasso_witness
        .u_polynomials()
        .iter()
        .zip(lasso_witness.u_to_t_mappings())
        .enumerate()
    {
        let u_n_vars = u_polynomial.n_vars();
        let mut counts_underlier_vec = vec![U::default(); 1 << (u_n_vars - bit_packing_log_width)];
        let counts = PackedType::<U, FC>::unpack_scalars_mut(
            PackedType::<U, FC>::from_underliers_ref_mut(counts_underlier_vec.as_mut_slice()),
        );

        // Each lookup records the current count of its table entry, then advances that entry's
        // count by one step (multiplication by the generator).
        for (&t_index, count) in izip!(u_to_t_mapping.as_ref(), counts.iter_mut()) {
            let current = final_counts[t_index];
            final_counts[t_index] = current * FC::MULTIPLICATIVE_GENERATOR;
            *count = current;
        }
        let counts: &[FC] = counts;

        let mixed_u_counts = lincom::<U, FC, FW, _, F>(u_polynomial, counts, gamma, alpha)?;
        let mixed_u_counts_plus_one = lincom(u_polynomial, counts, gamma, alpha_gen)?;

        witness_index = witness_index.update_owned::<FW, _>([
            (mixed_u_counts_oracle_ids[i], mixed_u_counts),
            (mixed_u_counts_plus_one_oracle_ids[i], mixed_u_counts_plus_one),
        ])?;

        witness_index = witness_index
            .update_owned::<FC, _>([(lasso_batches.counts[i], counts_underlier_vec)])?;
    }

    let mixed_t_final_counts = lincom(lasso_witness.t_polynomial(), final_counts, gamma, alpha)?;
    let mixed_t_ones = lincom(lasso_witness.t_polynomial(), ones_repeating, gamma, alpha)?;

    let [final_counts_underliers, ones_repeating_underliers] = t_sized_underlier_vecs;

    witness_index = witness_index.update_owned::<FC, _>([
        (lasso_batches.final_counts, final_counts_underliers),
        (ones_repeating_oracle_id, ones_repeating_underliers),
    ])?;

    witness_index = witness_index.update_owned::<FW, _>([
        (mixed_t_final_counts_oracle_id, mixed_t_final_counts),
        (mixed_t_one_oracle_id, mixed_t_ones),
    ])?;

    let left_grand_product_witness_claims =
        gkr_product_witness_claims(&gkr_claim_oracle_ids.left, &witness_index, oracles)?;
    let right_grand_product_witness_claims =
        gkr_product_witness_claims(&gkr_claim_oracle_ids.right, &witness_index, oracles)?;
    let counts_grand_product_witness_claims =
        gkr_product_witness_claims(&gkr_claim_oracle_ids.counts, &witness_index, oracles)?;

    let left_product: F = left_grand_product_witness_claims
        .grand_products
        .iter()
        .product();
    let right_product: F = right_grand_product_witness_claims
        .grand_products
        .iter()
        .product();
    if left_product != right_product {
        bail!(Error::ProductsDiffer);
    }

    if counts_grand_product_witness_claims
        .grand_products
        .iter()
        .any(|count| *count == F::ZERO)
    {
        bail!(Error::ZeroCounts);
    }

    let lasso_proof = LassoProof {
        left_grand_products: left_grand_product_witness_claims.grand_products,
        right_grand_products: right_grand_product_witness_claims.grand_products,
        counts_grand_products: counts_grand_product_witness_claims.grand_products,
    };

    let reduced_gpa_claims = [
        left_grand_product_witness_claims.gpa_claims,
        right_grand_product_witness_claims.gpa_claims,
        counts_grand_product_witness_claims.gpa_claims,
    ]
    .concat();

    let reduced_gpa_witnesses = [
        left_grand_product_witness_claims.gpa_witnesses,
        right_grand_product_witness_claims.gpa_witnesses,
        counts_grand_product_witness_claims.gpa_witnesses,
    ]
    .concat();

    Ok(LassoProveOutput {
        reduced_gpa_claims,
        reduced_gpa_witnesses,
        lasso_proof,
        witness_index,
    })
}
/// Builds the packed witness column `gamma + trace(i) + alpha * counts[i]` over the hypercube
/// of `trace`.
///
/// For each hypercube index `i`, `alpha * counts[i] + gamma` is computed in `F` and converted
/// into `FW`. The evaluation of `trace` at index `i` is then added in `FW`. The scalars are
/// stored in a freshly allocated buffer of `PackedType<U, FW>` underliers.
///
/// NOTE(review): assumes `trace.n_vars()` is at least the packing log-width of
/// `PackedType<U, FW>`; otherwise the buffer-size subtraction underflows.
///
/// # Errors
/// Propagates any [`PolynomialError`] from evaluating `trace` on the hypercube.
///
/// # Panics
/// Panics if `counts` is shorter than the number of unpacked scalars (`2^trace.n_vars()`).
fn lincom<U, FC, FW, PW, F>(
    trace: &MultilinearWitness<PW>,
    counts: &[FC],
    gamma: F,
    alpha: F,
) -> Result<Arc<[U]>, Error>
where
    U: UnderlierType + PackScalar<FW>,
    PackedType<U, FW>: PackedFieldIndexable,
    PW: PackedField<Scalar = FW>,
    FW: Field + From<F>,
    F: Field + ExtensionField<FC>,
    FC: Field,
{
    let n_vars = trace.n_vars();
    let packing_log_width = PackedType::<U, FW>::LOG_WIDTH;
    let mut underliers = vec![U::default(); 1 << (n_vars - packing_log_width)];
    let values = PackedType::<U, FW>::unpack_scalars_mut(
        PackedType::<U, FW>::from_underliers_ref_mut(underliers.as_mut_slice()),
    );

    // One parallel pass computes each entry completely, instead of writing the mixed count
    // and then walking the whole buffer again to add the trace evaluation.
    values.par_iter_mut().enumerate().try_for_each(
        |(i, value)| -> Result<_, PolynomialError> {
            let mixed_count = FW::from(alpha * counts[i] + gamma);
            *value = mixed_count + trace.evaluate_on_hypercube(i)?;
            Ok(())
        },
    )?;

    Ok(underliers.into())
}
/// Grand product data collected for a list of oracles by `gkr_product_witness_claims`.
///
/// The three vectors are pushed in lockstep (one entry per oracle id), so index `k` of each
/// vector refers to the same oracle.
#[derive(Debug, Default)]
struct GrandProductWitnessClaim<'a, U, FW, F>
where
U: UnderlierType + PackScalar<FW>,
F: TowerField + From<FW>,
FW: Field,
{
// Grand product evaluation of each oracle's witness, converted into `F`.
grand_products: Vec<F>,
// Grand product witnesses built from the witness-index polynomials, one per oracle.
gpa_witnesses: Vec<GrandProductWitness<'a, PackedType<U, FW>>>,
// Claims pairing each oracle with its grand product.
gpa_claims: Vec<GrandProductClaim<F>>,
}
fn gkr_product_witness_claims<'a, U, F, FW>(
ids: &[OracleId],
witness_index: &MultilinearExtensionIndex<'a, U, FW>,
oracles: &MultilinearOracleSet<F>,
) -> Result<GrandProductWitnessClaim<'a, U, FW, F>, Error>
where
U: UnderlierType + PackScalar<FW>,
F: TowerField + From<FW>,
FW: Field,
{
let mut grand_product_witness_claims = GrandProductWitnessClaim::default();
for id in ids {
let poly = witness_index.get_multilin_poly(*id)?;
let oracle = oracles.oracle(*id);
let gpa_witness = GrandProductWitness::new(poly)?;
let grand_product = gpa_witness.grand_product_evaluation().into();
let gpa_claim = GrandProductClaim {
poly: oracle,
product: grand_product,
};
grand_product_witness_claims
.grand_products
.push(grand_product);
grand_product_witness_claims.gpa_witnesses.push(gpa_witness);
grand_product_witness_claims.gpa_claims.push(gpa_claim);
}
Ok(grand_product_witness_claims)
}