binius_core/protocols/gkr_gpa/verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, TowerField};
4use binius_math::{EvaluationOrder, extrapolate_line_scalar};
5use binius_utils::{
6	bail,
7	sorting::{stable_sort, unsort},
8};
9use tracing::instrument;
10
11use super::{Error, GrandProductClaim, gkr_gpa::LayerClaim};
12use crate::{
13	composition::{BivariateProduct, IndexComposition},
14	fiat_shamir::{CanSample, Challenger},
15	polynomial::Error as PolynomialError,
16	protocols::sumcheck::{
17		self, CompositeSumClaim, EqIndSumcheckClaim, eq_ind::ClaimsSortingOrder, front_loaded,
18	},
19	transcript::VerifierTranscript,
20};
21
22/// Verifies batch reduction turning each GrandProductClaim into an EvalcheckMultilinearClaim
23#[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")]
24pub fn batch_verify<F, Challenger_>(
25	evaluation_order: EvaluationOrder,
26	claims: impl IntoIterator<Item = GrandProductClaim<F>>,
27	transcript: &mut VerifierTranscript<Challenger_>,
28) -> Result<Vec<LayerClaim<F>>, Error>
29where
30	F: TowerField,
31	Challenger_: Challenger,
32{
33	let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true);
34	let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0);
35
36	// Create LayerClaims for each of the claims
37	let mut layer_claims = sorted_claims
38		.iter()
39		.map(|claim| LayerClaim {
40			eval_point: vec![],
41			eval: claim.product,
42		})
43		.collect::<Vec<_>>();
44
45	// Create a vector of evalchecks with the same length as the number of claims
46	let n_claims = sorted_claims.len();
47	let mut reverse_sorted_evalcheck_claims = Vec::with_capacity(n_claims);
48
49	for layer_no in 0..max_n_vars {
50		process_finished_claims(
51			n_claims,
52			layer_no,
53			&mut layer_claims,
54			&mut sorted_claims,
55			&mut reverse_sorted_evalcheck_claims,
56		);
57
58		layer_claims = reduce_layer_claim_batch(evaluation_order, &layer_claims, transcript)?;
59	}
60	process_finished_claims(
61		n_claims,
62		max_n_vars,
63		&mut layer_claims,
64		&mut sorted_claims,
65		&mut reverse_sorted_evalcheck_claims,
66	);
67
68	debug_assert!(layer_claims.is_empty());
69	debug_assert_eq!(reverse_sorted_evalcheck_claims.len(), n_claims);
70
71	reverse_sorted_evalcheck_claims.reverse();
72	let sorted_evalcheck_claims = reverse_sorted_evalcheck_claims;
73
74	let final_layer_claims = unsort(original_indices, sorted_evalcheck_claims);
75	Ok(final_layer_claims)
76}
77
78fn process_finished_claims<F: Field>(
79	n_claims: usize,
80	layer_no: usize,
81	layer_claims: &mut Vec<LayerClaim<F>>,
82	sorted_claims: &mut Vec<GrandProductClaim<F>>,
83	reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
84) {
85	while let Some(claim) = sorted_claims.last() {
86		if claim.n_vars != layer_no {
87			break;
88		}
89
90		debug_assert!(!layer_claims.is_empty());
91		debug_assert_eq!(sorted_claims.len(), layer_claims.len());
92		let finished_layer_claim = layer_claims.pop().expect("must exist");
93		let _ = sorted_claims.pop().expect("must exist");
94		reverse_sorted_final_layer_claims.push(finished_layer_claim);
95		debug_assert_eq!(sorted_claims.len() + reverse_sorted_final_layer_claims.len(), n_claims);
96	}
97}
98
99/// Reduces n kth LayerClaims to n (k+1)th LayerClaims
100///
101/// Arguments
102/// * `claims` - The kth layer LayerClaims
103/// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to
104///   the (k+1)th
105/// * `transcript` - The verifier transcript
106fn reduce_layer_claim_batch<F, Challenger_>(
107	evaluation_order: EvaluationOrder,
108	claims: &[LayerClaim<F>],
109	transcript: &mut VerifierTranscript<Challenger_>,
110) -> Result<Vec<LayerClaim<F>>, Error>
111where
112	F: TowerField,
113	Challenger_: Challenger,
114{
115	// Validation
116	if claims.is_empty() {
117		return Ok(vec![]);
118	}
119
120	let curr_layer_challenge = &claims[0].eval_point;
121	if !claims
122		.iter()
123		.all(|claim| &claim.eval_point == curr_layer_challenge)
124	{
125		bail!(Error::MismatchedEvalPointLength);
126	}
127
128	let n_vars = curr_layer_challenge.len();
129	let n_multilinears = 2 * claims.len();
130
131	let composite_sums = claims
132		.iter()
133		.enumerate()
134		.map(|(i, claim)| {
135			let composition =
136				IndexComposition::new(n_multilinears, [2 * i, 2 * i + 1], BivariateProduct {})?;
137
138			let composite_sum_claim = CompositeSumClaim {
139				composition,
140				sum: claim.eval,
141			};
142
143			Ok(composite_sum_claim)
144		})
145		.collect::<Result<Vec<_>, PolynomialError>>()?;
146
147	let eq_ind_sumcheck_claim = EqIndSumcheckClaim::new(n_vars, n_multilinears, composite_sums)?;
148
149	let eq_ind_sumcheck_claims = [eq_ind_sumcheck_claim];
150
151	let regular_sumcheck_claims =
152		sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
153
154	let batch_sumcheck_verifier =
155		front_loaded::BatchVerifier::new(&regular_sumcheck_claims, transcript)?;
156	let mut batch_sumcheck_output = batch_sumcheck_verifier.run(transcript)?;
157
158	if evaluation_order == EvaluationOrder::HighToLow {
159		batch_sumcheck_output.challenges.reverse();
160	}
161
162	let batch_sumcheck_output = sumcheck::eq_ind::verify_sumcheck_outputs(
163		ClaimsSortingOrder::DescendingVars,
164		&eq_ind_sumcheck_claims,
165		curr_layer_challenge,
166		batch_sumcheck_output,
167	)?;
168
169	// Create the new (k+1)th layer LayerClaims for each grand product circuit
170	let sumcheck_challenge = batch_sumcheck_output.challenges.clone();
171	let gpa_challenge = transcript.sample();
172	let new_layer_challenge = sumcheck_challenge
173		.into_iter()
174		.chain(Some(gpa_challenge))
175		.collect::<Vec<_>>();
176	let new_layer_claims = batch_sumcheck_output.multilinear_evals[0]
177		.chunks_exact(2)
178		.map(|evals| {
179			let new_eval = extrapolate_line_scalar::<_, F>(evals[0], evals[1], gpa_challenge);
180			LayerClaim {
181				eval_point: new_layer_challenge.clone(),
182				eval: new_eval,
183			}
184		})
185		.collect::<Vec<_>>();
186
187	Ok(new_layer_claims)
188}