binius_core/protocols/gkr_gpa/
verify.rs1use binius_field::{Field, TowerField};
4use binius_math::{extrapolate_line_scalar, EvaluationOrder};
5use binius_utils::{
6 bail,
7 sorting::{stable_sort, unsort},
8};
9use tracing::instrument;
10
11use super::{gkr_gpa::LayerClaim, Error, GrandProductClaim};
12use crate::{
13 composition::{BivariateProduct, IndexComposition},
14 fiat_shamir::{CanSample, Challenger},
15 polynomial::Error as PolynomialError,
16 protocols::sumcheck::{self, CompositeSumClaim, EqIndSumcheckClaim},
17 transcript::VerifierTranscript,
18};
19
20#[instrument(skip_all, name = "gkr_gpa::batch_verify", level = "debug")]
22pub fn batch_verify<F, Challenger_>(
23 evaluation_order: EvaluationOrder,
24 claims: impl IntoIterator<Item = GrandProductClaim<F>>,
25 transcript: &mut VerifierTranscript<Challenger_>,
26) -> Result<Vec<LayerClaim<F>>, Error>
27where
28 F: TowerField,
29 Challenger_: Challenger,
30{
31 let (original_indices, mut sorted_claims) = stable_sort(claims, |claim| claim.n_vars, true);
32 let max_n_vars = sorted_claims.first().map(|claim| claim.n_vars).unwrap_or(0);
33
34 let mut layer_claims = sorted_claims
36 .iter()
37 .map(|claim| LayerClaim {
38 eval_point: vec![],
39 eval: claim.product,
40 })
41 .collect::<Vec<_>>();
42
43 let n_claims = sorted_claims.len();
45 let mut reverse_sorted_evalcheck_claims = Vec::with_capacity(n_claims);
46
47 for layer_no in 0..max_n_vars {
48 process_finished_claims(
49 n_claims,
50 layer_no,
51 &mut layer_claims,
52 &mut sorted_claims,
53 &mut reverse_sorted_evalcheck_claims,
54 );
55
56 layer_claims = reduce_layer_claim_batch(evaluation_order, &layer_claims, transcript)?;
57 }
58 process_finished_claims(
59 n_claims,
60 max_n_vars,
61 &mut layer_claims,
62 &mut sorted_claims,
63 &mut reverse_sorted_evalcheck_claims,
64 );
65
66 debug_assert!(layer_claims.is_empty());
67 debug_assert_eq!(reverse_sorted_evalcheck_claims.len(), n_claims);
68
69 reverse_sorted_evalcheck_claims.reverse();
70 let sorted_evalcheck_claims = reverse_sorted_evalcheck_claims;
71
72 let final_layer_claims = unsort(original_indices, sorted_evalcheck_claims);
73 Ok(final_layer_claims)
74}
75
76fn process_finished_claims<F: Field>(
77 n_claims: usize,
78 layer_no: usize,
79 layer_claims: &mut Vec<LayerClaim<F>>,
80 sorted_claims: &mut Vec<GrandProductClaim<F>>,
81 reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
82) {
83 while let Some(claim) = sorted_claims.last() {
84 if claim.n_vars != layer_no {
85 break;
86 }
87
88 debug_assert!(layer_no > 0);
89 debug_assert_eq!(sorted_claims.len(), layer_claims.len());
90 let finished_layer_claim = layer_claims.pop().expect("must exist");
91 let _ = sorted_claims.pop().expect("must exist");
92 reverse_sorted_final_layer_claims.push(finished_layer_claim);
93 debug_assert_eq!(sorted_claims.len() + reverse_sorted_final_layer_claims.len(), n_claims);
94 }
95}
96
97fn reduce_layer_claim_batch<F, Challenger_>(
104 evaluation_order: EvaluationOrder,
105 claims: &[LayerClaim<F>],
106 transcript: &mut VerifierTranscript<Challenger_>,
107) -> Result<Vec<LayerClaim<F>>, Error>
108where
109 F: TowerField,
110 Challenger_: Challenger,
111{
112 if claims.is_empty() {
114 return Ok(vec![]);
115 }
116
117 let curr_layer_challenge = &claims[0].eval_point;
118 if !claims
119 .iter()
120 .all(|claim| &claim.eval_point == curr_layer_challenge)
121 {
122 bail!(Error::MismatchedEvalPointLength);
123 }
124
125 let n_vars = curr_layer_challenge.len();
126 let n_multilinears = 2 * claims.len();
127
128 let composite_sums = claims
129 .iter()
130 .enumerate()
131 .map(|(i, claim)| {
132 let composition =
133 IndexComposition::new(n_multilinears, [2 * i, 2 * i + 1], BivariateProduct {})?;
134
135 let composite_sum_claim = CompositeSumClaim {
136 composition,
137 sum: claim.eval,
138 };
139
140 Ok(composite_sum_claim)
141 })
142 .collect::<Result<Vec<_>, PolynomialError>>()?;
143
144 let eq_ind_sumcheck_claim = EqIndSumcheckClaim::new(n_vars, n_multilinears, composite_sums)?;
145
146 let eq_ind_sumcheck_claims = [eq_ind_sumcheck_claim];
147
148 let regular_sumcheck_claims =
149 sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_claims)?;
150
151 let batch_sumcheck_output =
152 sumcheck::batch_verify(evaluation_order, ®ular_sumcheck_claims, transcript)?;
153
154 let batch_sumcheck_output = sumcheck::eq_ind::verify_sumcheck_outputs(
155 &eq_ind_sumcheck_claims,
156 curr_layer_challenge,
157 batch_sumcheck_output,
158 )?;
159
160 let sumcheck_challenge = batch_sumcheck_output.challenges.clone();
162 let gpa_challenge = transcript.sample();
163 let new_layer_challenge = sumcheck_challenge
164 .into_iter()
165 .chain(Some(gpa_challenge))
166 .collect::<Vec<_>>();
167 let new_layer_claims = batch_sumcheck_output.multilinear_evals[0]
168 .chunks_exact(2)
169 .map(|evals| {
170 let new_eval = extrapolate_line_scalar::<_, F>(evals[0], evals[1], gpa_challenge);
171 LayerClaim {
172 eval_point: new_layer_challenge.clone(),
173 eval: new_eval,
174 }
175 })
176 .collect::<Vec<_>>();
177
178 Ok(new_layer_claims)
179}