binius_core/protocols/gkr_gpa/
gkr_gpa.rs1use std::slice;
4
5use binius_field::{packed::get_packed_slice, Field, PackedField};
6use binius_maybe_rayon::prelude::*;
7use binius_utils::bail;
8use bytemuck::zeroed_vec;
9use tracing::{debug_span, instrument};
10
11use super::{packed_field_storage::PackedFieldStorage, Error};
12use crate::witness::MultilinearWitness;
13
/// Borrowed packed evaluations of a single product-circuit layer.
type LayerEvals<'a, PW> = &'a [PW];
/// A layer split into its (first half, second half) evaluation storage.
type LayerHalfEvals<'a, PW> = (PackedFieldStorage<'a, PW>, PackedFieldStorage<'a, PW>);
16
17#[derive(Debug, Clone)]
18pub struct GrandProductClaim<F: Field> {
19 pub n_vars: usize,
20 pub product: F,
22}
23
24impl<F: Field> GrandProductClaim<F> {
25 pub fn isomorphic<FI: Field + From<F>>(self) -> GrandProductClaim<FI> {
26 GrandProductClaim {
27 n_vars: self.n_vars,
28 product: self.product.into(),
29 }
30 }
31}
32
33#[derive(Debug, Clone)]
34pub struct GrandProductWitness<PW: PackedField> {
35 n_vars: usize,
36 circuit_evals: Vec<Vec<PW>>,
37}
38
39impl<PW: PackedField> GrandProductWitness<PW> {
40 #[instrument(skip_all, level = "debug", name = "GrandProductWitness::new")]
41 pub fn new(poly: MultilinearWitness<PW>) -> Result<Self, Error> {
42 let mut input_layer = zeroed_vec(1 << poly.n_vars().saturating_sub(PW::LOG_WIDTH));
45
46 if poly.n_vars() >= PW::LOG_WIDTH {
47 const LOG_CHUNK_SIZE: usize = 12;
48 let log_chunk_size = (poly.n_vars() - PW::LOG_WIDTH).min(LOG_CHUNK_SIZE);
49 input_layer
50 .par_chunks_mut(1 << (log_chunk_size - PW::LOG_WIDTH))
51 .enumerate()
52 .for_each(|(i, chunk)| {
53 poly.subcube_evals(log_chunk_size, i, 0, chunk).expect(
54 "index is between 0 and 2^{n_vars - log_chunk_size}; \
55 log_embedding degree is 0",
56 )
57 });
58 } else {
59 poly.subcube_evals(poly.n_vars(), 0, 0, slice::from_mut(&mut input_layer[0]))
60 .expect(
61 "index is between 0 and 2^{n_vars - log_chunk_size}; log_embedding degree is 0",
62 );
63 }
64
65 let mut all_layers = vec![input_layer];
66 debug_span!("constructing_layers").in_scope(|| {
67 for curr_n_vars in (0..poly.n_vars()).rev() {
68 let layer_below = all_layers.last().expect("layers is not empty by invariant");
69 let mut new_layer = zeroed_vec(1 << curr_n_vars.saturating_sub(PW::LOG_WIDTH));
70
71 if curr_n_vars >= PW::LOG_WIDTH {
72 let (left_half, right_half) =
73 layer_below.split_at(1 << (curr_n_vars - PW::LOG_WIDTH));
74
75 new_layer
76 .par_iter_mut()
77 .zip(left_half.par_iter().zip(right_half.par_iter()))
78 .for_each(|(out_i, (left_i, right_i))| {
79 *out_i = *left_i * *right_i;
80 });
81 } else {
82 let new_layer = &mut new_layer[0];
83 let len = 1 << curr_n_vars;
84 for i in 0..len {
85 new_layer.set(
86 i,
87 get_packed_slice(layer_below, i)
88 * get_packed_slice(layer_below, len + i),
89 );
90 }
91 }
92
93 all_layers.push(new_layer);
94 }
95 });
96
97 all_layers.reverse();
99 Ok(Self {
100 n_vars: poly.n_vars(),
101 circuit_evals: all_layers,
102 })
103 }
104
	/// Number of variables of the input-layer multilinear.
	pub const fn n_vars(&self) -> usize {
		self.n_vars
	}
109
110 pub fn grand_product_evaluation(&self) -> PW::Scalar {
112 self.circuit_evals[0][0].get(0)
115 }
116
117 pub fn ith_layer_evals(&self, i: usize) -> Result<LayerEvals<'_, PW>, Error> {
118 let max_layer_idx = self.n_vars();
119 if i > max_layer_idx {
120 bail!(Error::InvalidLayerIndex);
121 }
122 Ok(&self.circuit_evals[i])
123 }
124
	/// Returns layer `i` split into its first and second halves.
	///
	/// Layer `i` holds 2^i scalar evaluations, so each half covers 2^(i-1)
	/// scalars. Layer 0 holds a single value and cannot be split.
	pub fn ith_layer_eval_halves(&self, i: usize) -> Result<LayerHalfEvals<'_, PW>, Error> {
		if i == 0 {
			bail!(Error::CannotSplitOutputLayerIntoHalves);
		}
		let layer = self.ith_layer_evals(i)?;

		if layer.len() > 1 {
			// The layer spans multiple packed elements: split the packed
			// slice down the middle, borrowing both halves.
			let half = layer.len() / 2;
			// Each half must cover exactly 2^(i-1) scalars.
			debug_assert_eq!(half << PW::LOG_WIDTH, 1 << (i - 1));

			Ok((layer[..half].into(), layer[half..].into()))
		} else {
			// Both halves live inside a single packed element: copy each run
			// of 2^(i-1) scalars into inline storage instead of borrowing.
			let layer_size = 1 << (i - 1);

			let first_half = PackedFieldStorage::new_inline(layer[0].iter().take(layer_size))?;
			let second_half =
				PackedFieldStorage::new_inline(layer[0].iter().skip(layer_size).take(layer_size))?;

			Ok((first_half, second_half))
		}
	}
148}
149
150#[derive(Debug, Clone, Default)]
156pub struct LayerClaim<F: Field> {
157 pub eval_point: Vec<F>,
158 pub eval: F,
159}
160
161impl<F: Field> LayerClaim<F> {
162 pub fn isomorphic<FI: Field>(self) -> LayerClaim<FI>
163 where
164 F: Into<FI>,
165 {
166 LayerClaim {
167 eval_point: self.eval_point.into_iter().map(Into::into).collect(),
168 eval: self.eval.into(),
169 }
170 }
171}
172
/// Output of the batched grand-product prover.
#[derive(Debug, Default)]
pub struct GrandProductBatchProveOutput<F: Field> {
	/// Evaluation claims on the final (deepest) circuit layer that remain to
	/// be verified by a later stage. NOTE(review): the exact correspondence
	/// to input claims is inferred from the name — confirm against the
	/// prover that produces this struct.
	pub final_layer_claims: Vec<LayerClaim<F>>,
}