binius_core/protocols/gkr_gpa/
oracles.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright 2024 Ulvetanna Inc.

use std::iter;

use super::{gkr_gpa::LayerClaim, Error, GrandProductClaim, GrandProductWitness};
use crate::{
	oracle::{MultilinearOracleSet, OracleId},
	protocols::evalcheck_v2::EvalcheckMultilinearClaim,
	witness::MultilinearExtensionIndex,
};
use binius_field::{
	as_packed_field::{PackScalar, PackedType},
	underlier::UnderlierType,
	Field, PackedField, TowerField,
};
use binius_utils::bail;

pub fn construct_grand_product_witnesses<'a, U, F>(
	ids: &[OracleId],
	witness_index: &MultilinearExtensionIndex<'a, U, F>,
) -> Result<Vec<GrandProductWitness<'a, PackedType<U, F>>>, Error>
where
	U: UnderlierType + PackScalar<F>,
	F: Field,
{
	ids.iter()
		.map(|id| {
			witness_index
				.get_multilin_poly(*id)
				.map_err(|e| e.into())
				.and_then(GrandProductWitness::new)
		})
		.collect::<Result<Vec<_>, _>>()
}

/// Collects the grand product value of every witness, converting each one
/// from the packed field's scalar type into `F`.
///
/// The output has one entry per witness, in the same order as `witnesses`.
pub fn get_grand_products_from_witnesses<PW, F>(witnesses: &[GrandProductWitness<PW>]) -> Vec<F>
where
	PW: PackedField,
	F: Field + From<PW::Scalar>,
{
	let mut products = Vec::with_capacity(witnesses.len());
	for witness in witnesses {
		let product: F = witness.grand_product_evaluation().into();
		products.push(product);
	}
	products
}

/// Pairs each oracle id with its claimed product to form [`GrandProductClaim`]s.
///
/// Each claim's `n_vars` is read from the corresponding oracle in `oracles`.
/// The output preserves the order of `ids` / `products`.
///
/// # Errors
///
/// Returns [`Error::MetasProductsMismatch`] when `ids` and `products` have
/// different lengths.
pub fn construct_grand_product_claims<F>(
	ids: &[OracleId],
	oracles: &MultilinearOracleSet<F>,
	products: &[F],
) -> Result<Vec<GrandProductClaim<F>>, Error>
where
	F: TowerField,
{
	if ids.len() != products.len() {
		bail!(Error::MetasProductsMismatch);
	}

	let claims: Vec<GrandProductClaim<F>> = ids
		.iter()
		.zip(products)
		.map(|(&oracle_id, &product)| GrandProductClaim {
			n_vars: oracles.oracle(oracle_id).n_vars(),
			product,
		})
		.collect();
	Ok(claims)
}

/// Turns the final-layer claims of the grand product protocol into
/// [`EvalcheckMultilinearClaim`]s, one per oracle id in `metas`.
///
/// Each output claim takes its polynomial from `oracles`. It copies the
/// evaluation point and value from the matching entry of `final_layer_claims`,
/// and sets `is_random_point` to `true`. Order follows `metas`.
///
/// # Errors
///
/// Returns [`Error::MetasClaimMismatch`] when `metas` yields a different number
/// of ids than there are entries in `final_layer_claims`.
pub fn make_eval_claims<F: TowerField>(
	oracles: &MultilinearOracleSet<F>,
	metas: impl IntoIterator<Item = OracleId>,
	final_layer_claims: &[LayerClaim<F>],
) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
	// Materialize the ids so their count can be checked before any claim is built.
	let oracle_ids: Vec<OracleId> = metas.into_iter().collect();
	if oracle_ids.len() != final_layer_claims.len() {
		bail!(Error::MetasClaimMismatch);
	}

	let mut eval_claims = Vec::with_capacity(oracle_ids.len());
	for (oracle_id, layer_claim) in oracle_ids.into_iter().zip(final_layer_claims) {
		eval_claims.push(EvalcheckMultilinearClaim {
			poly: oracles.oracle(oracle_id),
			// The layer claim is only borrowed, so its point must be cloned.
			eval_point: layer_claim.eval_point.clone(),
			eval: layer_claim.eval,
			// NOTE(review): presumably the final-layer point is sampled at random
			// by the protocol; confirm this against the gkr_gpa prover/verifier.
			is_random_point: true,
		});
	}
	Ok(eval_claims)
}