binius_core/protocols/gkr_gpa/oracles.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_field::{
6	as_packed_field::{PackScalar, PackedType},
7	underlier::UnderlierType,
8	Field, PackedField, TowerField,
9};
10use binius_utils::bail;
11use tracing::instrument;
12
13use super::{gkr_gpa::LayerClaim, Error, GrandProductClaim, GrandProductWitness};
14use crate::{
15	oracle::{MultilinearOracleSet, OracleId},
16	protocols::evalcheck::EvalcheckMultilinearClaim,
17	witness::MultilinearExtensionIndex,
18};
19
20#[instrument(skip_all, level = "debug")]
21pub fn construct_grand_product_witnesses<U, F>(
22	ids: &[OracleId],
23	witness_index: &MultilinearExtensionIndex<U, F>,
24) -> Result<Vec<GrandProductWitness<PackedType<U, F>>>, Error>
25where
26	U: UnderlierType + PackScalar<F>,
27	F: Field,
28{
29	ids.iter()
30		.map(|id| {
31			witness_index
32				.get_multilin_poly(*id)
33				.map_err(|e| e.into())
34				.and_then(GrandProductWitness::new)
35		})
36		.collect::<Result<Vec<_>, _>>()
37}
38
39pub fn get_grand_products_from_witnesses<PW, F>(witnesses: &[GrandProductWitness<PW>]) -> Vec<F>
40where
41	PW: PackedField<Scalar: Into<F>>,
42	F: Field,
43{
44	witnesses
45		.iter()
46		.map(|witness| witness.grand_product_evaluation().into())
47		.collect::<Vec<_>>()
48}
49
50pub fn construct_grand_product_claims<F>(
51	ids: &[OracleId],
52	oracles: &MultilinearOracleSet<F>,
53	products: &[F],
54) -> Result<Vec<GrandProductClaim<F>>, Error>
55where
56	F: TowerField,
57{
58	if ids.len() != products.len() {
59		bail!(Error::MetasProductsMismatch);
60	}
61
62	Ok(iter::zip(ids, products)
63		.map(|(id, product)| GrandProductClaim {
64			n_vars: oracles.n_vars(*id),
65			product: *product,
66		})
67		.collect::<Vec<_>>())
68}
69
70#[instrument(skip_all, level = "debug")]
71pub fn make_eval_claims<F: TowerField>(
72	metas: impl IntoIterator<Item = OracleId>,
73	final_layer_claims: impl IntoIterator<IntoIter: ExactSizeIterator<Item = LayerClaim<F>>>,
74) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
75	let metas = metas.into_iter().collect::<Vec<_>>();
76
77	let final_layer_claims = final_layer_claims.into_iter();
78	if metas.len() != final_layer_claims.len() {
79		bail!(Error::MetasClaimMismatch);
80	}
81
82	Ok(iter::zip(metas, final_layer_claims)
83		.map(|(oracle_id, claim)| EvalcheckMultilinearClaim {
84			id: oracle_id,
85			eval_point: claim.eval_point.into(),
86			eval: claim.eval,
87		})
88		.collect::<Vec<_>>())
89}