binius_core/protocols/gkr_gpa/gpa_sumcheck/
verify.rs1use std::iter;
4
5use binius_field::{util::eq, Field};
6use binius_utils::{bail, sorting::is_sorted_ascending};
7use getset::CopyGetters;
8
9use crate::{
10 composition::{IndexComposition, TrivariateProduct},
11 protocols::sumcheck::{
12 BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, VerificationError,
13 },
14};
15
16#[derive(Debug, CopyGetters)]
17pub struct GPASumcheckClaim<F: Field> {
18 #[getset(get_copy = "pub")]
19 n_vars: usize,
20 sum: F,
21}
22
23impl<F: Field> GPASumcheckClaim<F> {
24 pub const fn new(n_vars: usize, sum: F) -> Result<Self, Error> {
25 Ok(Self { n_vars, sum })
26 }
27}
28
29pub fn reduce_to_sumcheck<F: Field>(
30 claims: &[GPASumcheckClaim<F>],
31) -> Result<SumcheckClaim<F, IndexComposition<TrivariateProduct, 3>>, Error> {
32 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
34 bail!(Error::ClaimsOutOfOrder);
35 }
36
37 let n_vars = claims.first().map_or(0, |claim| claim.n_vars);
38
39 if claims.iter().any(|claim| claim.n_vars != n_vars) {
40 bail!(Error::NumberOfVariablesMismatch);
41 }
42
43 let n_claims = claims.len();
44 let n_multilinears = 2 * n_claims + 1;
45
46 let composite_sums = claims
47 .iter()
48 .enumerate()
49 .map(|(i, claim)| {
50 let composition = IndexComposition::new(
51 n_multilinears,
52 [2 * i, 2 * i + 1, n_multilinears - 1],
53 TrivariateProduct {},
54 )?;
55 let composite_sum_claim = CompositeSumClaim {
56 composition,
57 sum: claim.sum,
58 };
59 Ok(composite_sum_claim)
60 })
61 .collect::<Result<Vec<_>, Error>>()?;
62
63 let sumcheck_claim = SumcheckClaim::new(n_vars, n_multilinears, composite_sums)?;
64
65 Ok(sumcheck_claim)
66}
67
68pub fn verify_sumcheck_outputs<F: Field>(
74 claims: &[GPASumcheckClaim<F>],
75 gpa_challenges: &[F],
76 sumcheck_output: BatchSumcheckOutput<F>,
77) -> Result<BatchSumcheckOutput<F>, Error> {
78 let BatchSumcheckOutput {
79 challenges: sumcheck_challenges,
80 mut multilinear_evals,
81 } = sumcheck_output;
82
83 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
85 bail!(Error::ClaimsOutOfOrder);
86 }
87
88 if multilinear_evals.len() != 1 || multilinear_evals[0].len() != 2 * claims.len() + 1 {
89 return Err(VerificationError::NumberOfFinalEvaluations.into());
90 }
91
92 let max_n_vars = claims
93 .first()
94 .map(|claim| claim.n_vars())
95 .unwrap_or_default();
96
97 assert_eq!(gpa_challenges.len(), max_n_vars);
98 assert_eq!(sumcheck_challenges.len(), max_n_vars);
99
100 let eq_ind_eval = iter::zip(gpa_challenges, &sumcheck_challenges)
101 .map(|(&gpa_challenge, &sumcheck_challenge)| eq(gpa_challenge, sumcheck_challenge))
102 .product::<F>();
103
104 let multilinear_evals_last = multilinear_evals[0]
105 .pop()
106 .expect("checked above that multilinear_evals length is at least 1");
107
108 if eq_ind_eval != multilinear_evals_last {
109 return Err(VerificationError::IncorrectEqIndEvaluation.into());
110 }
111
112 Ok(BatchSumcheckOutput {
113 challenges: sumcheck_challenges,
114 multilinear_evals,
115 })
116}
117
#[cfg(test)]
mod tests {
    use std::iter;

    use binius_field::{
        arch::OptimalUnderlier128b, as_packed_field::PackedType, BinaryField128b, BinaryField32b,
        BinaryField8b, PackedField,
    };
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use binius_math::{EvaluationOrder, IsomorphicEvaluationDomainFactory, MultilinearExtension};
    use groestl_crypto::Groestl256;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use crate::{
        composition::BivariateProduct,
        fiat_shamir::{CanSample, HasherChallenger},
        protocols::{
            gkr_gpa::gpa_sumcheck::{
                prove::GPAProver,
                verify::{reduce_to_sumcheck, verify_sumcheck_outputs, GPASumcheckClaim},
            },
            sumcheck,
        },
        transcript::ProverTranscript,
    };

    /// Samples `n_multilinears` random multilinears over `n_vars` variables, packed as `P`.
    ///
    /// Each multilinear is built from `2^(n_vars - P::LOG_WIDTH)` random packed elements, so this
    /// assumes `n_vars >= P::LOG_WIDTH`.
    fn generate_poly_helper<P>(
        mut rng: impl Rng,
        n_vars: usize,
        n_multilinears: usize,
    ) -> Vec<MultilinearExtension<P>>
    where
        P: PackedField,
    {
        let n_packed = 1 << (n_vars - P::LOG_WIDTH);
        let mut multilinears: Vec<MultilinearExtension<P>> = Vec::with_capacity(n_multilinears);
        for _ in 0..n_multilinears {
            let packed_evals: Vec<P> = iter::repeat_with(|| PackedField::random(&mut rng))
                .take(n_packed)
                .collect();
            multilinears.push(MultilinearExtension::from_values(packed_evals).unwrap());
        }
        multilinears
    }

    /// Round trip: prove one GPA sumcheck claim about the product of two random multilinears,
    /// then verify it with `reduce_to_sumcheck` + `batch_verify` + `verify_sumcheck_outputs`.
    #[test]
    fn test_prove_verify_gpa_sumcheck() {
        type U = OptimalUnderlier128b;
        type F = BinaryField32b;
        type FDomain = BinaryField8b;
        type FE = BinaryField128b;

        const N_VARS: usize = 4;

        let mut rng = StdRng::seed_from_u64(0);
        let backend = make_portable_backend();
        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();

        // Two random factors over F and their pointwise product.
        let factors = generate_poly_helper::<PackedType<U, F>>(&mut rng, N_VARS, 2);
        let product_evals: Vec<PackedType<U, F>> = factors[0]
            .evals()
            .iter()
            .zip(factors[1].evals())
            .map(|(&lhs, &rhs)| lhs * rhs)
            .collect();
        let product = MultilinearExtension::from_values(product_evals).unwrap();

        // Specialize every multilinear to packed FE for the prover.
        let multilins: Vec<_> = factors
            .into_iter()
            .map(|factor| factor.specialize_arc_dyn::<PackedType<U, FE>>())
            .collect();
        let prod_multilin = product.specialize_arc_dyn::<PackedType<U, FE>>();

        // Prover side: the GPA challenges come first from the transcript.
        let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let gpa_challenges: Vec<FE> = prove_transcript.sample_vec(N_VARS);

        // The claimed sum is the product multilinear evaluated at the GPA challenge point.
        let sum = prod_multilin
            .evaluate(backend.multilinear_query(&gpa_challenges).unwrap().to_ref())
            .unwrap();

        let composite_claims = [sumcheck::CompositeSumClaim {
            composition: BivariateProduct {},
            sum,
        }];

        let prover = GPAProver::<FDomain, _, _, _, _>::new(
            EvaluationOrder::LowToHigh,
            multilins,
            Some(vec![prod_multilin]),
            composite_claims,
            domain_factory,
            &gpa_challenges,
            &backend,
        )
        .unwrap();
        let _ = sumcheck::batch_prove(vec![prover], &mut prove_transcript).unwrap();

        // GPASumcheckClaim is not Clone, so build a fresh one for each consumer.
        let gpa_claim = || GPASumcheckClaim::new(N_VARS, sum).unwrap();
        let sumcheck_claims = [reduce_to_sumcheck(&[gpa_claim()]).unwrap()];

        // Verifier side: replay the same sampling sequence as the prover.
        let mut verifier_transcript = prove_transcript.into_verifier();
        let _: Vec<FE> = verifier_transcript.sample_vec(N_VARS);
        let batch_output = sumcheck::batch_verify(
            EvaluationOrder::LowToHigh,
            &sumcheck_claims,
            &mut verifier_transcript,
        )
        .unwrap();
        verifier_transcript.finalize().unwrap();

        verify_sumcheck_outputs(&[gpa_claim()], &gpa_challenges, batch_output).unwrap();
    }
}