binius_core/protocols/sumcheck/
eq_ind.rs1use binius_field::{util::eq, Field, PackedField};
4use binius_math::{ArithExpr, CompositionPoly};
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use getset::CopyGetters;
7
8use super::{
9 common::{CompositeSumClaim, SumcheckClaim},
10 error::{Error, VerificationError},
11};
12use crate::protocols::sumcheck::BatchSumcheckOutput;
13
14#[derive(Debug, Clone, CopyGetters)]
20pub struct EqIndSumcheckClaim<F: Field, Composition> {
21 #[getset(get_copy = "pub")]
22 n_vars: usize,
23 #[getset(get_copy = "pub")]
24 n_multilinears: usize,
25 eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
26}
27
28impl<F: Field, Composition> EqIndSumcheckClaim<F, Composition>
29where
30 Composition: CompositionPoly<F>,
31{
32 pub fn new(
39 n_vars: usize,
40 n_multilinears: usize,
41 eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
42 ) -> Result<Self, Error> {
43 for CompositeSumClaim {
44 ref composition, ..
45 } in &eq_ind_composite_sums
46 {
47 if composition.n_vars() != n_multilinears {
48 bail!(Error::InvalidComposition {
49 actual: composition.n_vars(),
50 expected: n_multilinears,
51 });
52 }
53 }
54 Ok(Self {
55 n_vars,
56 n_multilinears,
57 eq_ind_composite_sums,
58 })
59 }
60
61 pub fn max_individual_degree(&self) -> usize {
63 self.eq_ind_composite_sums
64 .iter()
65 .map(|composite_sum| composite_sum.composition.degree())
66 .max()
67 .unwrap_or(0)
68 }
69
70 pub fn eq_ind_composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
71 &self.eq_ind_composite_sums
72 }
73}
74
75pub fn reduce_to_regular_sumchecks<F: Field, Composition: CompositionPoly<F>>(
77 claims: &[EqIndSumcheckClaim<F, Composition>],
78) -> Result<Vec<SumcheckClaim<F, ExtraProduct<&Composition>>>, Error> {
79 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
81 bail!(Error::ClaimsOutOfOrder);
82 }
83
84 claims
85 .iter()
86 .map(|eq_ind_sumcheck_claim| {
87 let EqIndSumcheckClaim {
88 n_vars,
89 n_multilinears,
90 eq_ind_composite_sums,
91 ..
92 } = eq_ind_sumcheck_claim;
93 SumcheckClaim::new(
94 *n_vars,
95 *n_multilinears + 1,
96 eq_ind_composite_sums
97 .iter()
98 .map(|composite_sum| CompositeSumClaim {
99 composition: ExtraProduct {
100 inner: &composite_sum.composition,
101 },
102 sum: composite_sum.sum,
103 })
104 .collect(),
105 )
106 })
107 .collect()
108}
109
110pub fn verify_sumcheck_outputs<F: Field, Composition: CompositionPoly<F>>(
119 claims: &[EqIndSumcheckClaim<F, Composition>],
120 eq_ind_challenges: &[F],
121 sumcheck_output: BatchSumcheckOutput<F>,
122) -> Result<BatchSumcheckOutput<F>, Error> {
123 let BatchSumcheckOutput {
124 challenges: sumcheck_challenges,
125 mut multilinear_evals,
126 } = sumcheck_output;
127
128 if multilinear_evals.len() != claims.len() {
129 bail!(VerificationError::NumberOfFinalEvaluations);
130 }
131
132 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
134 bail!(Error::ClaimsOutOfOrder);
135 }
136
137 let max_n_vars = claims
138 .first()
139 .map(|claim| claim.n_vars())
140 .unwrap_or_default();
141
142 if sumcheck_challenges.len() > max_n_vars
143 || eq_ind_challenges.len() != sumcheck_challenges.len()
144 {
145 bail!(VerificationError::NumberOfRounds);
146 }
147
148 let mut eq_ind_eval = F::ONE;
149 let mut last_n_vars = 0;
150 for (claim, multilinear_evals) in claims.iter().zip(multilinear_evals.iter_mut()).rev() {
151 if claim.n_multilinears() + 1 != multilinear_evals.len() {
152 bail!(VerificationError::NumberOfMultilinearEvals);
153 }
154
155 while last_n_vars < claim.n_vars() && last_n_vars < sumcheck_challenges.len() {
156 let sumcheck_challenge =
157 sumcheck_challenges[sumcheck_challenges.len() - 1 - last_n_vars];
158 let eq_ind_challenge = eq_ind_challenges[eq_ind_challenges.len() - 1 - last_n_vars];
159 eq_ind_eval *= eq(sumcheck_challenge, eq_ind_challenge);
160 last_n_vars += 1;
161 }
162
163 let multilinear_evals_last = multilinear_evals
164 .pop()
165 .expect("checked above that multilinear_evals length is at least 1");
166 if eq_ind_eval != multilinear_evals_last {
167 return Err(VerificationError::IncorrectEqIndEvaluation.into());
168 }
169 }
170
171 Ok(BatchSumcheckOutput {
172 challenges: sumcheck_challenges,
173 multilinear_evals,
174 })
175}
176
177#[derive(Debug)]
178pub struct ExtraProduct<Composition> {
179 pub inner: Composition,
180}
181
182impl<P, Composition> CompositionPoly<P> for ExtraProduct<Composition>
183where
184 P: PackedField,
185 Composition: CompositionPoly<P>,
186{
187 fn n_vars(&self) -> usize {
188 self.inner.n_vars() + 1
189 }
190
191 fn degree(&self) -> usize {
192 self.inner.degree() + 1
193 }
194
195 fn expression(&self) -> ArithExpr<P::Scalar> {
196 self.inner.expression() * ArithExpr::Var(self.inner.n_vars())
197 }
198
199 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
200 let n_vars = self.n_vars();
201 if query.len() != n_vars {
202 bail!(binius_math::Error::IncorrectQuerySize { expected: n_vars });
203 }
204
205 let inner_eval = self.inner.evaluate(&query[..n_vars - 1])?;
206 Ok(inner_eval * query[n_vars - 1])
207 }
208
209 fn binary_tower_level(&self) -> usize {
210 self.inner.binary_tower_level()
211 }
212}
213
#[cfg(test)]
mod tests {
    use std::iter;

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier256b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        packed::set_packed_slice,
        underlier::UnderlierType,
        BinaryField128b, BinaryField8b, ExtensionField, PackedField, PackedFieldIndexable,
        TowerField,
    };
    use binius_hal::make_portable_backend;
    use binius_hash::groestl::Groestl256;
    use binius_math::{
        DefaultEvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
        MultilinearPoly, MultilinearQuery,
    };
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use crate::{
        composition::BivariateProduct,
        fiat_shamir::HasherChallenger,
        protocols::{
            sumcheck::{
                self, immediate_switchover_heuristic,
                prove::eq_ind::{ConstEvalSuffix, EqIndSumcheckProverBuilder},
                CompositeSumClaim, EqIndSumcheckClaim,
            },
            test_utils::AddOneComposition,
        },
        transcript::ProverTranscript,
    };

    /// Runs the prove/verify round trip for a collection of nonzero prefixes and for both
    /// evaluation orders.
    ///
    /// The prefixes are `0`, every power of two from `2` up to `1 << n_vars`, and
    /// `n_vars + 5` values drawn from a fixed-seed RNG in `1..(1 << n_vars)`.
    fn test_prove_verify_bivariate_product_helper<U, F, FDomain>(n_vars: usize)
    where
        U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
        F: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField,
        PackedType<U, F>: PackedFieldIndexable,
    {
        let max_nonzero_prefix = 1 << n_vars;
        let mut nonzero_prefixes = vec![0];

        for i in 1..=n_vars {
            nonzero_prefixes.push(1 << i);
        }

        let mut rng = StdRng::seed_from_u64(0);
        for _ in 0..n_vars + 5 {
            nonzero_prefixes.push(rng.gen_range(1..max_nonzero_prefix));
        }

        for nonzero_prefix in nonzero_prefixes {
            for evaluation_order in [EvaluationOrder::LowToHigh, EvaluationOrder::HighToLow] {
                test_prove_verify_bivariate_product_helper_under_evaluation_order::<U, F, FDomain>(
                    evaluation_order,
                    n_vars,
                    nonzero_prefix,
                );
            }
        }
    }

    /// Proves and verifies one eq-ind sumcheck over `a * b + 1`, where column `a` is zero
    /// from scalar index `nonzero_prefix` onwards.
    ///
    /// Also asserts that the prover reports the expected constant-evaluation suffix for the
    /// composition.
    fn test_prove_verify_bivariate_product_helper_under_evaluation_order<U, F, FDomain>(
        evaluation_order: EvaluationOrder,
        n_vars: usize,
        nonzero_prefix: usize,
    ) where
        U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
        F: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField,
        PackedType<U, F>: PackedFieldIndexable,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let packed_len = 1 << n_vars.saturating_sub(PackedType::<U, F>::LOG_WIDTH);
        let mut a_column = (0..packed_len)
            .map(|_| PackedType::<U, F>::random(&mut rng))
            .collect::<Vec<_>>();
        let b_column = (0..packed_len)
            .map(|_| PackedType::<U, F>::random(&mut rng))
            .collect::<Vec<_>>();
        let mut ab1_column = iter::zip(&a_column, &b_column)
            .map(|(&a, &b)| a * b + PackedType::<U, F>::one())
            .collect::<Vec<_>>();

        // Zero out `a` past the prefix; `a * b + 1` is then exactly one on that suffix.
        for i in nonzero_prefix..(1 << n_vars) {
            set_packed_slice(&mut a_column, i, F::ZERO);
            set_packed_slice(&mut ab1_column, i, F::ONE);
        }

        let a_mle =
            MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&a_column).unwrap());
        let b_mle =
            MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&b_column).unwrap());
        let ab1_mle =
            MLEDirectAdapter::from(MultilinearExtension::from_values_slice(&ab1_column).unwrap());

        let eq_ind_challenges = (0..n_vars).map(|_| F::random(&mut rng)).collect::<Vec<_>>();
        // The claimed sum is the evaluation of the `a * b + 1` column at the eq-ind point.
        let sum = ab1_mle
            .evaluate(MultilinearQuery::expand(&eq_ind_challenges).to_ref())
            .unwrap();

        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();

        let backend = make_portable_backend();
        let evaluation_domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();

        let composition = AddOneComposition::new(BivariateProduct {});

        let composite_claim = CompositeSumClaim { sum, composition };

        let prover = EqIndSumcheckProverBuilder::new(&backend)
            .with_nonzero_scalars_prefixes(&[nonzero_prefix, 1 << n_vars])
            .build(
                evaluation_order,
                vec![a_mle, b_mle],
                &eq_ind_challenges,
                [composite_claim.clone()],
                evaluation_domain_factory,
                immediate_switchover_heuristic,
            )
            .unwrap();

        let (_, const_eval_suffix) = prover.compositions().first().unwrap();
        assert_eq!(
            *const_eval_suffix,
            ConstEvalSuffix {
                suffix: (1 << n_vars) - nonzero_prefix,
                value: F::ONE,
                value_at_inf: F::ZERO
            }
        );

        let _sumcheck_proof_output = sumcheck::batch_prove(vec![prover], &mut transcript).unwrap();

        let mut verifier_transcript = transcript.into_verifier();

        let eq_ind_sumcheck_verifier_claim =
            EqIndSumcheckClaim::new(n_vars, 2, vec![composite_claim]).unwrap();
        let eq_ind_sumcheck_verifier_claims = [eq_ind_sumcheck_verifier_claim];
        let regular_sumcheck_verifier_claims =
            sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_sumcheck_verifier_claims)
                .unwrap();

        let _sumcheck_verify_output = sumcheck::batch_verify(
            evaluation_order,
            &regular_sumcheck_verifier_claims,
            &mut verifier_transcript,
        )
        .unwrap();
    }

    #[test]
    fn test_eq_ind_sumcheck_prove_verify_128b() {
        let n_vars = 8;

        test_prove_verify_bivariate_product_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            BinaryField8b,
        >(n_vars);
    }

    #[test]
    fn test_eq_ind_sumcheck_prove_verify_256b() {
        let n_vars = 8;

        test_prove_verify_bivariate_product_helper::<
            OptimalUnderlier256b,
            BinaryField128b,
            BinaryField8b,
        >(n_vars);
    }

    #[test]
    fn test_eq_ind_sumcheck_prove_verify_512b() {
        let n_vars = 8;

        test_prove_verify_bivariate_product_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            BinaryField8b,
        >(n_vars);
    }
}