binius_core/protocols/gkr_gpa/verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, TowerField};
4use binius_math::{extrapolate_line_scalar, EvaluationOrder};
5use binius_utils::{
6	bail,
7	sorting::{stable_sort, unsort},
8};
9use tracing::instrument;
10
11use super::{
12	gkr_gpa::LayerClaim,
13	gpa_sumcheck::verify::{reduce_to_sumcheck, verify_sumcheck_outputs, GPASumcheckClaim},
14	Error, GrandProductClaim,
15};
16use crate::{
17	fiat_shamir::{CanSample, Challenger},
18	protocols::sumcheck,
19	transcript::VerifierTranscript,
20};
21
22/// Verifies batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim
23#[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")]
24pub fn batch_verify<F, Challenger_>(
25	evaluation_order: EvaluationOrder,
26	claims: impl IntoIterator<Item = GrandProductClaim<F>>,
27	transcript: &mut VerifierTranscript<Challenger_>,
28) -> Result<Vec<LayerClaim<F>>, Error>
29where
30	F: TowerField,
31	Challenger_: Challenger,
32{
33	let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true);
34	let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0);
35
36	// Create LayerClaims for each of the claims
37	let mut layer_claims = sorted_claims
38		.iter()
39		.map(|claim| LayerClaim {
40			eval_point: vec![],
41			eval: claim.product,
42		})
43		.collect::<Vec<_>>();
44
45	// Create a vector of evalchecks with the same length as the number of claims
46	let n_claims = sorted_claims.len();
47	let mut reverse_sorted_evalcheck_claims = Vec::with_capacity(n_claims);
48
49	for layer_no in 0..max_n_vars {
50		process_finished_claims(
51			n_claims,
52			layer_no,
53			&mut layer_claims,
54			&mut sorted_claims,
55			&mut reverse_sorted_evalcheck_claims,
56		);
57
58		layer_claims = reduce_layer_claim_batch(evaluation_order, &layer_claims, transcript)?;
59	}
60	process_finished_claims(
61		n_claims,
62		max_n_vars,
63		&mut layer_claims,
64		&mut sorted_claims,
65		&mut reverse_sorted_evalcheck_claims,
66	);
67
68	debug_assert!(layer_claims.is_empty());
69	debug_assert_eq!(reverse_sorted_evalcheck_claims.len(), n_claims);
70
71	reverse_sorted_evalcheck_claims.reverse();
72	let sorted_evalcheck_claims = reverse_sorted_evalcheck_claims;
73
74	let final_layer_claims = unsort(original_indices, sorted_evalcheck_claims);
75	Ok(final_layer_claims)
76}
77
78fn process_finished_claims<F: Field>(
79	n_claims: usize,
80	layer_no: usize,
81	layer_claims: &mut Vec<LayerClaim<F>>,
82	sorted_claims: &mut Vec<GrandProductClaim<F>>,
83	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
84) {
85	while let Some(claim) = sorted_claims.last() {
86		if claim.n_vars != layer_no {
87			break;
88		}
89
90		debug_assert!(layer_no > 0);
91		debug_assert_eq!(sorted_claims.len(), layer_claims.len());
92		let finished_layer_claim = layer_claims.pop().expect("must exist");
93		let _ = sorted_claims.pop().expect("must exist");
94		reverse_sorted_final_layer_claims.push(finished_layer_claim);
95		debug_assert_eq!(sorted_claims.len() + reverse_sorted_final_layer_claims.len(), n_claims);
96	}
97}
98
99/// Reduces n kth LayerClaims to n (k+1)th LayerClaims
100///
101/// Arguments
102/// * `claims` - The kth layer LayerClaims
103/// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to the (k+1)th
104/// * `transcript` - The verifier transcript
105fn reduce_layer_claim_batch<F, Challenger_>(
106	evaluation_order: EvaluationOrder,
107	claims: &[LayerClaim<F>],
108	transcript: &mut VerifierTranscript<Challenger_>,
109) -> Result<Vec<LayerClaim<F>>, Error>
110where
111	F: TowerField,
112	Challenger_: Challenger,
113{
114	// Validation
115	if claims.is_empty() {
116		return Ok(vec![]);
117	}
118
119	let curr_layer_challenge = &claims[0].eval_point[..];
120	if !claims
121		.iter()
122		.all(|claim| claim.eval_point == curr_layer_challenge)
123	{
124		bail!(Error::MismatchedEvalPointLength);
125	}
126
127	// Verify the gpa sumcheck batch proof and receive the corresponding reduced claims
128	let gpa_sumcheck_claims = claims
129		.iter()
130		.map(|claim| GPASumcheckClaim::new(claim.eval_point.len(), claim.eval))
131		.collect::<Result<Vec<_>, _>>()?;
132
133	let sumcheck_claim = reduce_to_sumcheck(&gpa_sumcheck_claims)?;
134	let sumcheck_claims = [sumcheck_claim];
135
136	let batch_sumcheck_output =
137		sumcheck::batch_verify(evaluation_order, &sumcheck_claims, transcript)?;
138
139	let batch_sumcheck_output =
140		verify_sumcheck_outputs(&gpa_sumcheck_claims, curr_layer_challenge, batch_sumcheck_output)?;
141
142	// Create the new (k+1)th layer LayerClaims for each grand product circuit
143	let sumcheck_challenge = batch_sumcheck_output.challenges.clone();
144	let gpa_challenge = transcript.sample();
145	let new_layer_challenge = sumcheck_challenge
146		.into_iter()
147		.chain(Some(gpa_challenge))
148		.collect::<Vec<_>>();
149	let new_layer_claims = batch_sumcheck_output.multilinear_evals[0]
150		.chunks_exact(2)
151		.map(|evals| {
152			let new_eval = extrapolate_line_scalar::<_, F>(evals[0], evals[1], gpa_challenge);
153			LayerClaim {
154				eval_point: new_layer_challenge.clone(),
155				eval: new_eval,
156			}
157		})
158		.collect::<Vec<_>>();
159
160	Ok(new_layer_claims)
161}