binius_core/protocols/gkr_exp/
oracles.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{BinaryField, PackedField, TowerField};
4use binius_math::{MultilinearPoly, MultilinearQuery, MultilinearQueryRef};
5use binius_maybe_rayon::prelude::*;
6use binius_utils::bail;
7use itertools::izip;
8use tracing::instrument;
9
10use super::{error::Error, BaseExpReductionOutput, BaseExpWitness, ExpClaim};
11use crate::{
12	oracle::{MultilinearOracleSet, OracleId},
13	protocols::{evalcheck::EvalcheckMultilinearClaim, gkr_exp::LayerClaim},
14};
15
16pub fn get_evals_in_point_from_witnesses<P>(
17	witnesses: &[BaseExpWitness<P>],
18	eval_point: &[P::Scalar],
19) -> Result<Vec<P::Scalar>, Error>
20where
21	P: PackedField,
22	P::Scalar: BinaryField,
23{
24	witnesses
25		.into_par_iter()
26		.map(|witness| {
27			let query = MultilinearQuery::expand(&eval_point[0..witness.n_vars()]);
28
29			witness
30				.exponentiation_result_witness()
31				.evaluate(MultilinearQueryRef::new(&query))
32				.map_err(Error::from)
33		})
34		.collect::<Result<Vec<_>, Error>>()
35}
36
37pub fn construct_gkr_exp_claims<F>(
38	exponents_ids: &[Vec<OracleId>],
39	evals: &[F],
40	static_bases: Vec<Option<F>>,
41	oracles: &MultilinearOracleSet<F>,
42	eval_point: &[F],
43) -> Result<Vec<ExpClaim<F>>, Error>
44where
45	F: TowerField,
46{
47	for exponent_ids in exponents_ids {
48		if exponent_ids.is_empty() {
49			bail!(Error::EmptyExp)
50		}
51	}
52
53	let claims = izip!(exponents_ids, evals, static_bases)
54		.map(|(exponents_ids, &eval, static_base)| {
55			let id = *exponents_ids.last().expect("exponents_ids not empty");
56			let n_vars = oracles.n_vars(id);
57
58			ExpClaim {
59				eval_point: eval_point[..n_vars].to_vec(),
60				eval,
61				exponent_bit_width: exponents_ids.len(),
62				n_vars,
63				static_base,
64			}
65		})
66		.collect::<Vec<_>>();
67
68	Ok(claims)
69}
70
71#[instrument(skip_all, level = "debug")]
72pub fn make_eval_claims<F: TowerField>(
73	metas: Vec<Vec<OracleId>>,
74	mut base_exp_output: BaseExpReductionOutput<F>,
75	dynamic_base_ids: Vec<Option<OracleId>>,
76) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
77	let max_exponent_bit_number = metas.iter().map(|meta| meta.len()).max().unwrap_or(0);
78
79	let mut evalcheck_claims = Vec::new();
80
81	for layer_no in 0..max_exponent_bit_number {
82		for (&dynamic_base, meta) in dynamic_base_ids.iter().zip(&metas).rev() {
83			if layer_no > meta.len() - 1 {
84				continue;
85			}
86
87			let LayerClaim { eval_point, eval } = base_exp_output.layers_claims[layer_no]
88				.pop()
89				.ok_or(Error::MetasClaimMismatch)?;
90
91			if let Some(base_id) = dynamic_base {
92				let base_claim = EvalcheckMultilinearClaim {
93					id: base_id,
94					eval_point: eval_point.into(),
95					eval,
96				};
97
98				let LayerClaim { eval_point, eval } = base_exp_output.layers_claims[layer_no]
99					.pop()
100					.ok_or(Error::MetasClaimMismatch)?;
101
102				let exponent_bit_id = meta[layer_no];
103
104				let exponent_bit_claim = EvalcheckMultilinearClaim {
105					id: exponent_bit_id,
106					eval_point: eval_point.into(),
107					eval,
108				};
109
110				evalcheck_claims.push(exponent_bit_claim);
111				evalcheck_claims.push(base_claim);
112			} else {
113				let exponent_bit_id = meta[meta.len() - 1 - layer_no];
114
115				let exponent_bit_claim = EvalcheckMultilinearClaim {
116					id: exponent_bit_id,
117					eval_point: eval_point.into(),
118					eval,
119				};
120				evalcheck_claims.push(exponent_bit_claim);
121			}
122		}
123	}
124
125	Ok(evalcheck_claims)
126}