binius_core/protocols/gkr_exp/
oracles.rs1use binius_field::{BinaryField, PackedField, TowerField};
4use binius_math::{MultilinearPoly, MultilinearQuery, MultilinearQueryRef};
5use binius_maybe_rayon::prelude::*;
6use binius_utils::bail;
7use itertools::izip;
8use tracing::instrument;
9
10use super::{error::Error, BaseExpReductionOutput, BaseExpWitness, ExpClaim};
11use crate::{
12 oracle::{MultilinearOracleSet, OracleId},
13 protocols::{evalcheck::EvalcheckMultilinearClaim, gkr_exp::LayerClaim},
14};
15
16pub fn get_evals_in_point_from_witnesses<P>(
17 witnesses: &[BaseExpWitness<P>],
18 eval_point: &[P::Scalar],
19) -> Result<Vec<P::Scalar>, Error>
20where
21 P: PackedField,
22 P::Scalar: BinaryField,
23{
24 witnesses
25 .into_par_iter()
26 .map(|witness| {
27 let query = MultilinearQuery::expand(&eval_point[0..witness.n_vars()]);
28
29 witness
30 .exponentiation_result_witness()
31 .evaluate(MultilinearQueryRef::new(&query))
32 .map_err(Error::from)
33 })
34 .collect::<Result<Vec<_>, Error>>()
35}
36
37pub fn construct_gkr_exp_claims<F>(
38 exponents_ids: &[Vec<OracleId>],
39 evals: &[F],
40 static_bases: Vec<Option<F>>,
41 oracles: &MultilinearOracleSet<F>,
42 eval_point: &[F],
43) -> Result<Vec<ExpClaim<F>>, Error>
44where
45 F: TowerField,
46{
47 for exponent_ids in exponents_ids {
48 if exponent_ids.is_empty() {
49 bail!(Error::EmptyExp)
50 }
51 }
52
53 let claims = izip!(exponents_ids, evals, static_bases)
54 .map(|(exponents_ids, &eval, static_base)| {
55 let id = *exponents_ids.last().expect("exponents_ids not empty");
56 let n_vars = oracles.n_vars(id);
57
58 ExpClaim {
59 eval_point: eval_point[..n_vars].to_vec(),
60 eval,
61 exponent_bit_width: exponents_ids.len(),
62 n_vars,
63 static_base,
64 }
65 })
66 .collect::<Vec<_>>();
67
68 Ok(claims)
69}
70
71#[instrument(skip_all, level = "debug")]
72pub fn make_eval_claims<F: TowerField>(
73 metas: Vec<Vec<OracleId>>,
74 mut base_exp_output: BaseExpReductionOutput<F>,
75 dynamic_base_ids: Vec<Option<OracleId>>,
76) -> Result<Vec<EvalcheckMultilinearClaim<F>>, Error> {
77 let max_exponent_bit_number = metas.iter().map(|meta| meta.len()).max().unwrap_or(0);
78
79 let mut evalcheck_claims = Vec::new();
80
81 for layer_no in 0..max_exponent_bit_number {
82 for (&dynamic_base, meta) in dynamic_base_ids.iter().zip(&metas).rev() {
83 if layer_no > meta.len() - 1 {
84 continue;
85 }
86
87 let LayerClaim { eval_point, eval } = base_exp_output.layers_claims[layer_no]
88 .pop()
89 .ok_or(Error::MetasClaimMismatch)?;
90
91 if let Some(base_id) = dynamic_base {
92 let base_claim = EvalcheckMultilinearClaim {
93 id: base_id,
94 eval_point: eval_point.into(),
95 eval,
96 };
97
98 let LayerClaim { eval_point, eval } = base_exp_output.layers_claims[layer_no]
99 .pop()
100 .ok_or(Error::MetasClaimMismatch)?;
101
102 let exponent_bit_id = meta[layer_no];
103
104 let exponent_bit_claim = EvalcheckMultilinearClaim {
105 id: exponent_bit_id,
106 eval_point: eval_point.into(),
107 eval,
108 };
109
110 evalcheck_claims.push(exponent_bit_claim);
111 evalcheck_claims.push(base_claim);
112 } else {
113 let exponent_bit_id = meta[meta.len() - 1 - layer_no];
114
115 let exponent_bit_claim = EvalcheckMultilinearClaim {
116 id: exponent_bit_id,
117 eval_point: eval_point.into(),
118 eval,
119 };
120 evalcheck_claims.push(exponent_bit_claim);
121 }
122 }
123 }
124
125 Ok(evalcheck_claims)
126}