binius_core/protocols/gkr_gpa/
oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_field::{Field, PackedField, TowerField};
6use binius_utils::bail;
7use tracing::instrument;
8
9use super::{gkr_gpa::LayerClaim, Error, GrandProductClaim, GrandProductWitness};
10use crate::{
11	oracle::{MultilinearOracleSet, OracleId},
12	protocols::evalcheck::EvalcheckMultilinearClaim,
13	witness::MultilinearExtensionIndex,
14};
15
16#[instrument(skip_all, level = "debug")]
17pub fn construct_grand_product_witnesses<P>(
18	ids: &[OracleId],
19	witness_index: &MultilinearExtensionIndex<P>,
20) -> Result<Vec<GrandProductWitness<P>>, Error>
21where
22	P: PackedField,
23{
24	ids.iter()
25		.map(|id| {
26			witness_index
27				.get_multilin_poly(*id)
28				.map_err(|e| e.into())
29				.and_then(GrandProductWitness::new)
30		})
31		.collect::<Result<Vec<_>, _>>()
32}
33
34pub fn get_grand_products_from_witnesses<PW, F>(witnesses: &[GrandProductWitness<PW>]) -> Vec<F>
35where
36	PW: PackedField<Scalar: Into<F>>,
37	F: Field,
38{
39	witnesses
40		.iter()
41		.map(|witness| witness.grand_product_evaluation().into())
42		.collect::<Vec<_>>()
43}
44
45pub fn construct_grand_product_claims<F>(
46	ids: &[OracleId],
47	oracles: &MultilinearOracleSet<F>,
48	products: &[F],
49) -> Result<Vec<GrandProductClaim<F>>, Error>
50where
51	F: TowerField,
52{
53	if ids.len() != products.len() {
54		bail!(Error::MetasProductsMismatch);
55	}
56
57	Ok(iter::zip(ids, products)
58		.map(|(id, product)| GrandProductClaim {
59			n_vars: oracles.n_vars(*id),
60			product: *product,
61		})
62		.collect::<Vec<_>>())
63}
64
65#[instrument(skip_all, level = "debug")]
66pub fn make_eval_claims<F: TowerField>(
67	metas: impl IntoIterator<Item = OracleId>,
68	final_layer_claims: impl IntoIterator<IntoIter: ExactSizeIterator<Item = LayerClaim<F>>>,
69) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
70	let metas = metas.into_iter().collect::<Vec<_>>();
71
72	let final_layer_claims = final_layer_claims.into_iter();
73	if metas.len() != final_layer_claims.len() {
74		bail!(Error::MetasClaimMismatch);
75	}
76
77	Ok(iter::zip(metas, final_layer_claims)
78		.map(|(oracle_id, claim)| EvalcheckMultilinearClaim {
79			id: oracle_id,
80			eval_point: claim.eval_point.into(),
81			eval: claim.eval,
82		})
83		.collect::<Vec<_>>())
84}