binius_core/protocols/gkr_gpa/
verify.rs1use binius_field::{Field, TowerField};
4use binius_math::{EvaluationOrder, extrapolate_line_scalar};
5use binius_utils::{
6 bail,
7 sorting::{stable_sort, unsort},
8};
9use tracing::instrument;
10
11use super::{Error, GrandProductClaim, gkr_gpa::LayerClaim};
12use crate::{
13 composition::{BivariateProduct, IndexComposition},
14 fiat_shamir::{CanSample, Challenger},
15 polynomial::Error as PolynomialError,
16 protocols::sumcheck::{
17 self, CompositeSumClaim, EqIndSumcheckClaim, eq_ind::ClaimsSortingOrder, front_loaded,
18 },
19 transcript::VerifierTranscript,
20};
21
22#[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")]
24pub fn batch_verify<F, Challenger_>(
25 evaluation_order: EvaluationOrder,
26 claims: impl IntoIterator<Item = GrandProductClaim<F>>,
27 transcript: &mut VerifierTranscript<Challenger_>,
28) -> Result<Vec<LayerClaim<F>>, Error>
29where
30 F: TowerField,
31 Challenger_: Challenger,
32{
33 let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true);
34 let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0);
35
36 let mut layer_claims = sorted_claims
38 .iter()
39 .map(|claim| LayerClaim {
40 eval_point: vec![],
41 eval: claim.product,
42 })
43 .collect::<Vec<_>>();
44
45 let n_claims = sorted_claims.len();
47 let mut reverse_sorted_evalcheck_claims = Vec::with_capacity(n_claims);
48
49 for layer_no in 0..max_n_vars {
50 process_finished_claims(
51 n_claims,
52 layer_no,
53 &mut layer_claims,
54 &mut sorted_claims,
55 &mut reverse_sorted_evalcheck_claims,
56 );
57
58 layer_claims = reduce_layer_claim_batch(evaluation_order, &layer_claims, transcript)?;
59 }
60 process_finished_claims(
61 n_claims,
62 max_n_vars,
63 &mut layer_claims,
64 &mut sorted_claims,
65 &mut reverse_sorted_evalcheck_claims,
66 );
67
68 debug_assert!(layer_claims.is_empty());
69 debug_assert_eq!(reverse_sorted_evalcheck_claims.len(), n_claims);
70
71 reverse_sorted_evalcheck_claims.reverse();
72 let sorted_evalcheck_claims = reverse_sorted_evalcheck_claims;
73
74 let final_layer_claims = unsort(original_indices, sorted_evalcheck_claims);
75 Ok(final_layer_claims)
76}
77
78fn process_finished_claims<F: Field>(
79 n_claims: usize,
80 layer_no: usize,
81 layer_claims: &mut Vec<LayerClaim<F>>,
82 sorted_claims: &mut Vec<GrandProductClaim<F>>,
83 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
84) {
85 while let Some(claim) = sorted_claims.last() {
86 if claim.n_vars != layer_no {
87 break;
88 }
89
90 debug_assert!(!layer_claims.is_empty());
91 debug_assert_eq!(sorted_claims.len(), layer_claims.len());
92 let finished_layer_claim = layer_claims.pop().expect("must exist");
93 let _ = sorted_claims.pop().expect("must exist");
94 reverse_sorted_final_layer_claims.push(finished_layer_claim);
95 debug_assert_eq!(sorted_claims.len() + reverse_sorted_final_layer_claims.len(), n_claims);
96 }
97}
98
99fn reduce_layer_claim_batch<F, Challenger_>(
107 evaluation_order: EvaluationOrder,
108 claims: &[LayerClaim<F>],
109 transcript: &mut VerifierTranscript<Challenger_>,
110) -> Result<Vec<LayerClaim<F>>, Error>
111where
112 F: TowerField,
113 Challenger_: Challenger,
114{
115 if claims.is_empty() {
117 return Ok(vec![]);
118 }
119
120 let curr_layer_challenge = &claims[0].eval_point;
121 if !claims
122 .iter()
123 .all(|claim| &claim.eval_point == curr_layer_challenge)
124 {
125 bail!(Error::MismatchedEvalPointLength);
126 }
127
128 let n_vars = curr_layer_challenge.len();
129 let n_multilinears = 2 * claims.len();
130
131 let composite_sums = claims
132 .iter()
133 .enumerate()
134 .map(|(i, claim)| {
135 let composition =
136 IndexComposition::new(n_multilinears, [2 * i, 2 * i + 1], BivariateProduct {})?;
137
138 let composite_sum_claim = CompositeSumClaim {
139 composition,
140 sum: claim.eval,
141 };
142
143 Ok(composite_sum_claim)
144 })
145 .collect::<Result<Vec<_>, PolynomialError>>()?;
146
147 let eq_ind_sumcheck_claim = EqIndSumcheckClaim::new(n_vars, n_multilinears, composite_sums)?;
148
149 let eq_ind_sumcheck_claims = [eq_ind_sumcheck_claim];
150
151 let regular_sumcheck_claims =
152 sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
153
154 let batch_sumcheck_verifier =
155 front_loaded::BatchVerifier::new(®ular_sumcheck_claims, transcript)?;
156 let mut batch_sumcheck_output = batch_sumcheck_verifier.run(transcript)?;
157
158 if evaluation_order == EvaluationOrder::HighToLow {
159 batch_sumcheck_output.challenges.reverse();
160 }
161
162 let batch_sumcheck_output = sumcheck::eq_ind::verify_sumcheck_outputs(
163 ClaimsSortingOrder::DescendingVars,
164 &eq_ind_sumcheck_claims,
165 curr_layer_challenge,
166 batch_sumcheck_output,
167 )?;
168
169 let sumcheck_challenge = batch_sumcheck_output.challenges.clone();
171 let gpa_challenge = transcript.sample();
172 let new_layer_challenge = sumcheck_challenge
173 .into_iter()
174 .chain(Some(gpa_challenge))
175 .collect::<Vec<_>>();
176 let new_layer_claims = batch_sumcheck_output.multilinear_evals[0]
177 .chunks_exact(2)
178 .map(|evals| {
179 let new_eval = extrapolate_line_scalar::<_, F>(evals[0], evals[1], gpa_challenge);
180 LayerClaim {
181 eval_point: new_layer_challenge.clone(),
182 eval: new_eval,
183 }
184 })
185 .collect::<Vec<_>>();
186
187 Ok(new_layer_claims)
188}