binius_core/protocols/gkr_gpa/
gkr_gpa.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{packed::get_packed_slice, Field, PackedField};
4use binius_maybe_rayon::prelude::*;
5use binius_utils::bail;
6use bytemuck::zeroed_vec;
7use tracing::{debug_span, instrument};
8
9use super::Error;
10use crate::protocols::sumcheck::Error as SumcheckError;
11
12#[derive(Debug, Clone)]
13pub struct GrandProductClaim<F: Field> {
14	pub n_vars: usize,
15	pub product: F,
16}
17
18impl<F: Field> GrandProductClaim<F> {
19	pub fn isomorphic<FI: Field + From<F>>(self) -> GrandProductClaim<FI> {
20		GrandProductClaim {
21			n_vars: self.n_vars,
22			product: self.product.into(),
23		}
24	}
25}
26
27#[derive(Debug, Clone)]
28pub struct GrandProductWitness<P: PackedField> {
29	circuit_layers: Vec<Vec<P>>,
30}
31
32/// Grand product witness.
33///
34/// Constructs GKR multiplication circuit layers from an owned packed field witness,
35/// which may be shorter than `2^n_vars` scalars, in which case the absent values are
36/// assumed to be `P::Scalar::ONE`. There is a total on `n_vars + 1` layers, ordered
37/// by decreasing size, with last layer containing a single grand product scalar.
38impl<P: PackedField> GrandProductWitness<P> {
39	#[instrument(skip_all, level = "debug", name = "GrandProductWitness::new")]
40	pub fn new(n_vars: usize, input_layer: Vec<P>) -> Result<Self, Error> {
41		if input_layer.len() > 1 << n_vars.saturating_sub(P::LOG_WIDTH) {
42			bail!(SumcheckError::NumberOfVariablesMismatch);
43		}
44
45		let mut circuit_layers = Vec::with_capacity(n_vars + 1);
46
47		circuit_layers.push(input_layer);
48		debug_span!("constructing_layers").in_scope(|| {
49			for layer_n_vars in (0..n_vars).rev() {
50				let prev_layer = circuit_layers
51					.last()
52					.expect("all_layers is not empty by invariant");
53				let max_layer_len = 1 << layer_n_vars.saturating_sub(P::LOG_WIDTH);
54				let mut layer = zeroed_vec(prev_layer.len().min(max_layer_len));
55
56				// Specialize the _last_ variable to construct the next layer.
57				if layer_n_vars >= P::LOG_WIDTH {
58					let packed_len = 1 << (layer_n_vars - P::LOG_WIDTH);
59					let pivot = prev_layer.len().saturating_sub(packed_len);
60
61					if pivot > 0 {
62						let (evals_0, evals_1) = prev_layer.split_at(packed_len);
63						(layer.as_mut_slice(), evals_0, evals_1)
64							.into_par_iter()
65							.for_each(|(product, &eval_0, &eval_1)| {
66								*product = eval_0 * eval_1;
67							});
68					}
69
70					// In case of truncated witness, some of the scalars may stay unaltered
71					// due to implicit multiplication by one.
72					layer[pivot..]
73						.copy_from_slice(&prev_layer[pivot..packed_len.min(prev_layer.len())]);
74				} else if !prev_layer.is_empty() {
75					let layer = layer
76						.first_mut()
77						.expect("layer.len() >= 1 iff prev_layer.len() >= 1");
78					for i in 0..1 << layer_n_vars {
79						let product = get_packed_slice(prev_layer, i)
80							* get_packed_slice(prev_layer, i | 1 << layer_n_vars);
81						layer.set(i, product);
82					}
83				}
84
85				circuit_layers.push(layer);
86			}
87		});
88
89		Ok(Self { circuit_layers })
90	}
91
92	/// Base-two logarithm of the number of inputs to the GKR grand product circuit
93	pub fn n_vars(&self) -> usize {
94		self.circuit_layers.len() - 1
95	}
96
97	/// Final evaluation of the GKR grand product circuit
98	pub fn grand_product_evaluation(&self) -> P::Scalar {
99		let first_layer = self.circuit_layers.last().expect("always n_vars+1 layers");
100		let first_packed = first_layer.first().copied().unwrap_or_else(P::one);
101		first_packed.get(0)
102	}
103
104	/// Consume the witness, returning the vector of layer multilinears in non-ascending length order.
105	pub fn into_circuit_layers(self) -> Vec<Vec<P>> {
106		self.circuit_layers
107	}
108}
109
110/// LayerClaim is a claim about the evaluation of the kth layer-multilinear at a specific evaluation point
111///
112/// Notation:
113/// * The kth layer-multilinear is the multilinear polynomial whose evaluations are the intermediate values of the kth
114///   layer of the evaluated product circuit.
115#[derive(Debug, Clone, Default)]
116pub struct LayerClaim<F: Field> {
117	pub eval_point: Vec<F>,
118	pub eval: F,
119}
120
121impl<F: Field> LayerClaim<F> {
122	pub fn isomorphic<FI: Field>(self) -> LayerClaim<FI>
123	where
124		F: Into<FI>,
125	{
126		LayerClaim {
127			eval_point: self.eval_point.into_iter().map(Into::into).collect(),
128			eval: self.eval.into(),
129		}
130	}
131}
132
133#[derive(Debug, Default)]
134pub struct GrandProductBatchProveOutput<F: Field> {
135	// Reduced evalcheck claims for all the initial grand product claims
136	pub final_layer_claims: Vec<LayerClaim<F>>,
137}