binius_core/protocols/gkr_gpa/
verify.rs1use binius_field::{Field, TowerField};
4use binius_math::{extrapolate_line_scalar, EvaluationOrder};
5use binius_utils::{
6 bail,
7 sorting::{stable_sort, unsort},
8};
9use tracing::instrument;
10
11use super::{
12 gkr_gpa::LayerClaim,
13 gpa_sumcheck::verify::{reduce_to_sumcheck, verify_sumcheck_outputs, GPASumcheckClaim},
14 Error, GrandProductClaim,
15};
16use crate::{
17 fiat_shamir::{CanSample, Challenger},
18 protocols::sumcheck,
19 transcript::VerifierTranscript,
20};
21
22#[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")]
24pub fn batch_verify<F, Challenger_>(
25 evaluation_order: EvaluationOrder,
26 claims: impl IntoIterator<Item = GrandProductClaim<F>>,
27 transcript: &mut VerifierTranscript<Challenger_>,
28) -> Result<Vec<LayerClaim<F>>, Error>
29where
30 F: TowerField,
31 Challenger_: Challenger,
32{
33 let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true);
34 let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0);
35
36 let mut layer_claims = sorted_claims
38 .iter()
39 .map(|claim| LayerClaim {
40 eval_point: vec![],
41 eval: claim.product,
42 })
43 .collect::<Vec<_>>();
44
45 let n_claims = sorted_claims.len();
47 let mut reverse_sorted_evalcheck_claims = Vec::with_capacity(n_claims);
48
49 for layer_no in 0..max_n_vars {
50 process_finished_claims(
51 n_claims,
52 layer_no,
53 &mut layer_claims,
54 &mut sorted_claims,
55 &mut reverse_sorted_evalcheck_claims,
56 );
57
58 layer_claims = reduce_layer_claim_batch(evaluation_order, &layer_claims, transcript)?;
59 }
60 process_finished_claims(
61 n_claims,
62 max_n_vars,
63 &mut layer_claims,
64 &mut sorted_claims,
65 &mut reverse_sorted_evalcheck_claims,
66 );
67
68 debug_assert!(layer_claims.is_empty());
69 debug_assert_eq!(reverse_sorted_evalcheck_claims.len(), n_claims);
70
71 reverse_sorted_evalcheck_claims.reverse();
72 let sorted_evalcheck_claims = reverse_sorted_evalcheck_claims;
73
74 let final_layer_claims = unsort(original_indices, sorted_evalcheck_claims);
75 Ok(final_layer_claims)
76}
77
78fn process_finished_claims<F: Field>(
79 n_claims: usize,
80 layer_no: usize,
81 layer_claims: &mut Vec<LayerClaim<F>>,
82 sorted_claims: &mut Vec<GrandProductClaim<F>>,
83 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
84) {
85 while let Some(claim) = sorted_claims.last() {
86 if claim.n_vars != layer_no {
87 break;
88 }
89
90 debug_assert!(layer_no > 0);
91 debug_assert_eq!(sorted_claims.len(), layer_claims.len());
92 let finished_layer_claim = layer_claims.pop().expect("must exist");
93 let _ = sorted_claims.pop().expect("must exist");
94 reverse_sorted_final_layer_claims.push(finished_layer_claim);
95 debug_assert_eq!(sorted_claims.len() + reverse_sorted_final_layer_claims.len(), n_claims);
96 }
97}
98
99fn reduce_layer_claim_batch<F, Challenger_>(
106 evaluation_order: EvaluationOrder,
107 claims: &[LayerClaim<F>],
108 transcript: &mut VerifierTranscript<Challenger_>,
109) -> Result<Vec<LayerClaim<F>>, Error>
110where
111 F: TowerField,
112 Challenger_: Challenger,
113{
114 if claims.is_empty() {
116 return Ok(vec![]);
117 }
118
119 let curr_layer_challenge = &claims[0].eval_point[..];
120 if !claims
121 .iter()
122 .all(|claim| claim.eval_point == curr_layer_challenge)
123 {
124 bail!(Error::MismatchedEvalPointLength);
125 }
126
127 let gpa_sumcheck_claims = claims
129 .iter()
130 .map(|claim| GPASumcheckClaim::new(claim.eval_point.len(), claim.eval))
131 .collect::<Result<Vec<_>, _>>()?;
132
133 let sumcheck_claim = reduce_to_sumcheck(&gpa_sumcheck_claims)?;
134 let sumcheck_claims = [sumcheck_claim];
135
136 let batch_sumcheck_output =
137 sumcheck::batch_verify(evaluation_order, &sumcheck_claims, transcript)?;
138
139 let batch_sumcheck_output =
140 verify_sumcheck_outputs(&gpa_sumcheck_claims, curr_layer_challenge, batch_sumcheck_output)?;
141
142 let sumcheck_challenge = batch_sumcheck_output.challenges.clone();
144 let gpa_challenge = transcript.sample();
145 let new_layer_challenge = sumcheck_challenge
146 .into_iter()
147 .chain(Some(gpa_challenge))
148 .collect::<Vec<_>>();
149 let new_layer_claims = batch_sumcheck_output.multilinear_evals[0]
150 .chunks_exact(2)
151 .map(|evals| {
152 let new_eval = extrapolate_line_scalar::<_, F>(evals[0], evals[1], gpa_challenge);
153 LayerClaim {
154 eval_point: new_layer_challenge.clone(),
155 eval: new_eval,
156 }
157 })
158 .collect::<Vec<_>>();
159
160 Ok(new_layer_claims)
161}