binius_core/protocols/gkr_gpa/
verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, TowerField};
4use binius_math::{extrapolate_line_scalar, EvaluationOrder};
5use binius_utils::{
6	bail,
7	sorting::{stable_sort, unsort},
8};
9use tracing::instrument;
10
11use super::{gkr_gpa::LayerClaim, Error, GrandProductClaim};
12use crate::{
13	composition::{BivariateProduct, IndexComposition},
14	fiat_shamir::{CanSample, Challenger},
15	polynomial::Error as PolynomialError,
16	protocols::sumcheck::{self, CompositeSumClaim, EqIndSumcheckClaim},
17	transcript::VerifierTranscript,
18};
19
20/// Verifies batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim
21#[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")]
22pub fn batch_verify<F, Challenger_>(
23	evaluation_order: EvaluationOrder,
24	claims: impl IntoIterator<Item = GrandProductClaim<F>>,
25	transcript: &mut VerifierTranscript<Challenger_>,
26) -> Result<Vec<LayerClaim<F>>, Error>
27where
28	F: TowerField,
29	Challenger_: Challenger,
30{
31	let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true);
32	let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0);
33
34	// Create LayerClaims for each of the claims
35	let mut layer_claims = sorted_claims
36		.iter()
37		.map(|claim| LayerClaim {
38			eval_point: vec![],
39			eval: claim.product,
40		})
41		.collect::<Vec<_>>();
42
43	// Create a vector of evalchecks with the same length as the number of claims
44	let n_claims = sorted_claims.len();
45	let mut reverse_sorted_evalcheck_claims = Vec::with_capacity(n_claims);
46
47	for layer_no in 0..max_n_vars {
48		process_finished_claims(
49			n_claims,
50			layer_no,
51			&mut layer_claims,
52			&mut sorted_claims,
53			&mut reverse_sorted_evalcheck_claims,
54		);
55
56		layer_claims = reduce_layer_claim_batch(evaluation_order, &layer_claims, transcript)?;
57	}
58	process_finished_claims(
59		n_claims,
60		max_n_vars,
61		&mut layer_claims,
62		&mut sorted_claims,
63		&mut reverse_sorted_evalcheck_claims,
64	);
65
66	debug_assert!(layer_claims.is_empty());
67	debug_assert_eq!(reverse_sorted_evalcheck_claims.len(), n_claims);
68
69	reverse_sorted_evalcheck_claims.reverse();
70	let sorted_evalcheck_claims = reverse_sorted_evalcheck_claims;
71
72	let final_layer_claims = unsort(original_indices, sorted_evalcheck_claims);
73	Ok(final_layer_claims)
74}
75
76fn process_finished_claims<F: Field>(
77	n_claims: usize,
78	layer_no: usize,
79	layer_claims: &mut Vec<LayerClaim<F>>,
80	sorted_claims: &mut Vec<GrandProductClaim<F>>,
81	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
82) {
83	while let Some(claim) = sorted_claims.last() {
84		if claim.n_vars != layer_no {
85			break;
86		}
87
88		debug_assert!(layer_no > 0);
89		debug_assert_eq!(sorted_claims.len(), layer_claims.len());
90		let finished_layer_claim = layer_claims.pop().expect("must exist");
91		let _ = sorted_claims.pop().expect("must exist");
92		reverse_sorted_final_layer_claims.push(finished_layer_claim);
93		debug_assert_eq!(sorted_claims.len() + reverse_sorted_final_layer_claims.len(), n_claims);
94	}
95}
96
97/// Reduces n kth LayerClaims to n (k+1)th LayerClaims
98///
99/// Arguments
100/// * `claims` - The kth layer LayerClaims
101/// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to the (k+1)th
102/// * `transcript` - The verifier transcript
103fn reduce_layer_claim_batch<F, Challenger_>(
104	evaluation_order: EvaluationOrder,
105	claims: &[LayerClaim<F>],
106	transcript: &mut VerifierTranscript<Challenger_>,
107) -> Result<Vec<LayerClaim<F>>, Error>
108where
109	F: TowerField,
110	Challenger_: Challenger,
111{
112	// Validation
113	if claims.is_empty() {
114		return Ok(vec![]);
115	}
116
117	let curr_layer_challenge = &claims[0].eval_point;
118	if !claims
119		.iter()
120		.all(|claim| &claim.eval_point == curr_layer_challenge)
121	{
122		bail!(Error::MismatchedEvalPointLength);
123	}
124
125	let n_vars = curr_layer_challenge.len();
126	let n_multilinears = 2 * claims.len();
127
128	let composite_sums = claims
129		.iter()
130		.enumerate()
131		.map(|(i, claim)| {
132			let composition =
133				IndexComposition::new(n_multilinears, [2 * i, 2 * i + 1], BivariateProduct {})?;
134
135			let composite_sum_claim = CompositeSumClaim {
136				composition,
137				sum: claim.eval,
138			};
139
140			Ok(composite_sum_claim)
141		})
142		.collect::<Result<Vec<_>, PolynomialError>>()?;
143
144	let eq_ind_sumcheck_claim = EqIndSumcheckClaim::new(n_vars, n_multilinears, composite_sums)?;
145
146	let eq_ind_sumcheck_claims = [eq_ind_sumcheck_claim];
147
148	let regular_sumcheck_claims =
149		sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
150
151	let batch_sumcheck_output =
152		sumcheck::batch_verify(evaluation_order, &regular_sumcheck_claims, transcript)?;
153
154	let batch_sumcheck_output = sumcheck::eq_ind::verify_sumcheck_outputs(
155		&eq_ind_sumcheck_claims,
156		curr_layer_challenge,
157		batch_sumcheck_output,
158	)?;
159
160	// Create the new (k+1)th layer LayerClaims for each grand product circuit
161	let sumcheck_challenge = batch_sumcheck_output.challenges.clone();
162	let gpa_challenge = transcript.sample();
163	let new_layer_challenge = sumcheck_challenge
164		.into_iter()
165		.chain(Some(gpa_challenge))
166		.collect::<Vec<_>>();
167	let new_layer_claims = batch_sumcheck_output.multilinear_evals[0]
168		.chunks_exact(2)
169		.map(|evals| {
170			let new_eval = extrapolate_line_scalar::<_, F>(evals[0], evals[1], gpa_challenge);
171			LayerClaim {
172				eval_point: new_layer_challenge.clone(),
173				eval: new_eval,
174			}
175		})
176		.collect::<Vec<_>>();
177
178	Ok(new_layer_claims)
179}