binius_core/protocols/gkr_gpa/gpa_sumcheck/
verify.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_field::{util::eq, Field};
6use binius_utils::{bail, sorting::is_sorted_ascending};
7use getset::CopyGetters;
8
9use crate::{
10	composition::{IndexComposition, TrivariateProduct},
11	protocols::sumcheck::{
12		BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, VerificationError,
13	},
14};
15
16#[derive(Debug, CopyGetters)]
17pub struct GPASumcheckClaim<F: Field> {
18	#[getset(get_copy = "pub")]
19	n_vars: usize,
20	sum: F,
21}
22
23impl<F: Field> GPASumcheckClaim<F> {
24	pub const fn new(n_vars: usize, sum: F) -> Result<Self, Error> {
25		Ok(Self { n_vars, sum })
26	}
27}
28
29pub fn reduce_to_sumcheck<F: Field>(
30	claims: &[GPASumcheckClaim<F>],
31) -> Result<SumcheckClaim<F, IndexComposition<TrivariateProduct, 3>>, Error> {
32	// Check that the claims are in descending order by n_vars
33	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
34		bail!(Error::ClaimsOutOfOrder);
35	}
36
37	let n_vars = claims.first().map_or(0, |claim| claim.n_vars);
38
39	if claims.iter().any(|claim| claim.n_vars != n_vars) {
40		bail!(Error::NumberOfVariablesMismatch);
41	}
42
43	let n_claims = claims.len();
44	let n_multilinears = 2 * n_claims + 1;
45
46	let composite_sums = claims
47		.iter()
48		.enumerate()
49		.map(|(i, claim)| {
50			let composition = IndexComposition::new(
51				n_multilinears,
52				[2 * i, 2 * i + 1, n_multilinears - 1],
53				TrivariateProduct {},
54			)?;
55			let composite_sum_claim = CompositeSumClaim {
56				composition,
57				sum: claim.sum,
58			};
59			Ok(composite_sum_claim)
60		})
61		.collect::<Result<Vec<_>, Error>>()?;
62
63	let sumcheck_claim = SumcheckClaim::new(n_vars, n_multilinears, composite_sums)?;
64
65	Ok(sumcheck_claim)
66}
67
68/// Verify the validity of the sumcheck outputs for a reduced GPA sumcheck.
69///
70/// This takes in the output of the reduced sumcheck protocol and returns the output for the
71/// GPA sumcheck instance. This simply strips off the multilinear evaluation of the eq indicator
72/// polynomial and verifies that the value is correct.
73pub fn verify_sumcheck_outputs<F: Field>(
74	claims: &[GPASumcheckClaim<F>],
75	gpa_challenges: &[F],
76	sumcheck_output: BatchSumcheckOutput<F>,
77) -> Result<BatchSumcheckOutput<F>, Error> {
78	let BatchSumcheckOutput {
79		challenges: sumcheck_challenges,
80		mut multilinear_evals,
81	} = sumcheck_output;
82
83	// Check that the claims are in descending order by n_vars
84	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
85		bail!(Error::ClaimsOutOfOrder);
86	}
87
88	if multilinear_evals.len() != 1 || multilinear_evals[0].len() != 2 * claims.len() + 1 {
89		return Err(VerificationError::NumberOfFinalEvaluations.into());
90	}
91
92	let max_n_vars = claims
93		.first()
94		.map(|claim| claim.n_vars())
95		.unwrap_or_default();
96
97	assert_eq!(gpa_challenges.len(), max_n_vars);
98	assert_eq!(sumcheck_challenges.len(), max_n_vars);
99
100	let eq_ind_eval = iter::zip(gpa_challenges, &sumcheck_challenges)
101		.map(|(&gpa_challenge, &sumcheck_challenge)| eq(gpa_challenge, sumcheck_challenge))
102		.product::<F>();
103
104	let multilinear_evals_last = multilinear_evals[0]
105		.pop()
106		.expect("checked above that multilinear_evals length is at least 1");
107
108	if eq_ind_eval != multilinear_evals_last {
109		return Err(VerificationError::IncorrectEqIndEvaluation.into());
110	}
111
112	Ok(BatchSumcheckOutput {
113		challenges: sumcheck_challenges,
114		multilinear_evals,
115	})
116}
117
#[cfg(test)]
mod tests {
	use std::iter;

	use binius_field::{
		arch::OptimalUnderlier128b, as_packed_field::PackedType, BinaryField128b, BinaryField32b,
		BinaryField8b, PackedField,
	};
	use binius_hal::{make_portable_backend, ComputationBackendExt};
	use binius_math::{EvaluationOrder, IsomorphicEvaluationDomainFactory, MultilinearExtension};
	use groestl_crypto::Groestl256;
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use crate::{
		composition::BivariateProduct,
		fiat_shamir::{CanSample, HasherChallenger},
		protocols::{
			gkr_gpa::gpa_sumcheck::{
				prove::GPAProver,
				verify::{reduce_to_sumcheck, verify_sumcheck_outputs, GPASumcheckClaim},
			},
			sumcheck,
		},
		transcript::ProverTranscript,
	};

	/// Samples `n_multilinears` random multilinear extensions over `n_vars` variables, drawing
	/// the packed values of each polynomial in order from `rng`.
	fn generate_poly_helper<P>(
		mut rng: impl Rng,
		n_vars: usize,
		n_multilinears: usize,
	) -> Vec<MultilinearExtension<P>>
	where
		P: PackedField,
	{
		// Number of packed elements backing one multilinear.
		let n_packed: usize = 1 << (n_vars - P::LOG_WIDTH);

		let mut polys = Vec::with_capacity(n_multilinears);
		for _ in 0..n_multilinears {
			let mut packed_values = Vec::with_capacity(n_packed);
			for _ in 0..n_packed {
				packed_values.push(PackedField::random(&mut rng));
			}
			polys.push(MultilinearExtension::from_values(packed_values).unwrap());
		}
		polys
	}

	#[test]
	fn test_prove_verify_gpa_sumcheck() {
		type U = OptimalUnderlier128b;
		type F = BinaryField32b;
		type FDomain = BinaryField8b;
		type FE = BinaryField128b;
		let mut rng = StdRng::seed_from_u64(0);
		let backend = make_portable_backend();
		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let n_vars = 4;

		// Two random factors and their pointwise product.
		let factors = generate_poly_helper::<PackedType<U, F>>(&mut rng, n_vars, 2);
		let product_evals = iter::zip(factors[0].evals(), factors[1].evals())
			.map(|(&lhs, &rhs)| lhs * rhs)
			.collect::<Vec<_>>();
		let product = MultilinearExtension::from_values(product_evals).unwrap();

		// Specialize every multilinear to packed `FE` elements.
		let multilins: Vec<_> = factors
			.into_iter()
			.map(|factor| factor.specialize_arc_dyn::<PackedType<U, FE>>())
			.collect();
		let prod_multilin = product.specialize_arc_dyn::<PackedType<U, FE>>();

		let mut prove_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let challenges: Vec<FE> = prove_transcript.sample_vec(n_vars);

		// The claimed sum is the product multilinear evaluated at the sampled challenges.
		let sum = prod_multilin
			.evaluate(backend.multilinear_query(&challenges).unwrap().to_ref())
			.unwrap();

		let gpa_prover = GPAProver::<FDomain, _, _, _, _>::new(
			EvaluationOrder::LowToHigh,
			multilins,
			Some(vec![prod_multilin]),
			[sumcheck::CompositeSumClaim {
				composition: BivariateProduct {},
				sum,
			}],
			domain_factory,
			&challenges,
			&backend,
		)
		.unwrap();

		let _ = sumcheck::batch_prove(vec![gpa_prover], &mut prove_transcript).unwrap();

		// Verifier side: reduce the GPA claim to a plain sumcheck claim.
		let reduced_claim =
			reduce_to_sumcheck(&[GPASumcheckClaim::new(n_vars, sum).unwrap()]).unwrap();
		let sumcheck_claims = [reduced_claim];

		let mut verifier_transcript = prove_transcript.into_verifier();
		// Sample the same number of challenges the prover sampled before proving, keeping the
		// verifier transcript aligned with the prover's.
		let _: Vec<FE> = verifier_transcript.sample_vec(n_vars);
		let batch_output = sumcheck::batch_verify(
			EvaluationOrder::LowToHigh,
			&sumcheck_claims,
			&mut verifier_transcript,
		)
		.unwrap();
		verifier_transcript.finalize().unwrap();

		let gpa_claims = [GPASumcheckClaim::new(n_vars, sum).unwrap()];
		verify_sumcheck_outputs(&gpa_claims, &challenges, batch_output).unwrap();
	}
}
232}