binius_core/protocols/sumcheck/
eq_ind.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_field::{util::eq, Field, PackedField};
4use binius_math::{ArithExpr, CompositionPoly};
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use getset::CopyGetters;
7
8use super::{
9	common::{CompositeSumClaim, SumcheckClaim},
10	error::{Error, VerificationError},
11};
12use crate::protocols::sumcheck::BatchSumcheckOutput;
13
14/// A group of claims about the sum of the values of multilinear composite polynomials over the
15/// boolean hypercube multiplied by the value of equality indicator.
16///
17/// Reductions transform this struct to a [SumcheckClaim] with an explicit equality indicator in
18/// the last position.
19#[derive(Debug, Clone, CopyGetters)]
20pub struct EqIndSumcheckClaim<F: Field, Composition> {
21	#[getset(get_copy = "pub")]
22	n_vars: usize,
23	#[getset(get_copy = "pub")]
24	n_multilinears: usize,
25	eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
26}
27
28impl<F: Field, Composition> EqIndSumcheckClaim<F, Composition>
29where
30	Composition: CompositionPoly<F>,
31{
32	/// Constructs a new equality indicator sumcheck claim.
33	///
34	/// ## Throws
35	///
36	/// * [`Error::InvalidComposition`] if any of the composition polynomials in the composite
37	///   claims vector do not have their number of variables equal to `n_multilinears`
38	pub fn new(
39		n_vars: usize,
40		n_multilinears: usize,
41		eq_ind_composite_sums: Vec<CompositeSumClaim<F, Composition>>,
42	) -> Result<Self, Error> {
43		for CompositeSumClaim {
44			ref composition, ..
45		} in &eq_ind_composite_sums
46		{
47			if composition.n_vars() != n_multilinears {
48				bail!(Error::InvalidComposition {
49					actual: composition.n_vars(),
50					expected: n_multilinears,
51				});
52			}
53		}
54		Ok(Self {
55			n_vars,
56			n_multilinears,
57			eq_ind_composite_sums,
58		})
59	}
60
61	/// Returns the maximum individual degree of all composite polynomials.
62	pub fn max_individual_degree(&self) -> usize {
63		self.eq_ind_composite_sums
64			.iter()
65			.map(|composite_sum| composite_sum.composition.degree())
66			.max()
67			.unwrap_or(0)
68	}
69
70	pub fn eq_ind_composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
71		&self.eq_ind_composite_sums
72	}
73}
74
75/// Requirement: eq-ind sumcheck challenges have been sampled before this is called
76pub fn reduce_to_regular_sumchecks<F: Field, Composition: CompositionPoly<F>>(
77	claims: &[EqIndSumcheckClaim<F, Composition>],
78) -> Result<Vec<SumcheckClaim<F, ExtraProduct<&Composition>>>, Error> {
79	// Check that the claims are in descending order by n_vars
80	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
81		bail!(Error::ClaimsOutOfOrder);
82	}
83
84	claims
85		.iter()
86		.map(|eq_ind_sumcheck_claim| {
87			let EqIndSumcheckClaim {
88				n_vars,
89				n_multilinears,
90				eq_ind_composite_sums,
91				..
92			} = eq_ind_sumcheck_claim;
93			SumcheckClaim::new(
94				*n_vars,
95				*n_multilinears + 1,
96				eq_ind_composite_sums
97					.iter()
98					.map(|composite_sum| CompositeSumClaim {
99						composition: ExtraProduct {
100							inner: &composite_sum.composition,
101						},
102						sum: composite_sum.sum,
103					})
104					.collect(),
105			)
106		})
107		.collect()
108}
109
110/// Verify the validity of the sumcheck outputs for a reduced eq-ind sumcheck.
111///
112/// This takes in the output of the reduced sumcheck protocol and returns the output for the
113/// eq-ind sumcheck instance. This simply strips off the multilinear evaluation of the eq indicator
114/// polynomial and verifies that the value is correct.
115///
116/// Note that due to univariatization of some rounds the number of challenges may be less than
117/// the maximum number of variables among claims.
118pub fn verify_sumcheck_outputs<F: Field, Composition: CompositionPoly<F>>(
119	claims: &[EqIndSumcheckClaim<F, Composition>],
120	eq_ind_challenges: &[F],
121	sumcheck_output: BatchSumcheckOutput<F>,
122) -> Result<BatchSumcheckOutput<F>, Error> {
123	let BatchSumcheckOutput {
124		challenges: sumcheck_challenges,
125		mut multilinear_evals,
126	} = sumcheck_output;
127
128	if multilinear_evals.len() != claims.len() {
129		bail!(VerificationError::NumberOfFinalEvaluations);
130	}
131
132	// Check that the claims are in descending order by n_vars
133	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
134		bail!(Error::ClaimsOutOfOrder);
135	}
136
137	let max_n_vars = claims
138		.first()
139		.map(|claim| claim.n_vars())
140		.unwrap_or_default();
141
142	if sumcheck_challenges.len() > max_n_vars
143		|| eq_ind_challenges.len() != sumcheck_challenges.len()
144	{
145		bail!(VerificationError::NumberOfRounds);
146	}
147
148	let mut eq_ind_eval = F::ONE;
149	let mut last_n_vars = 0;
150	for (claim, multilinear_evals) in claims.iter().zip(multilinear_evals.iter_mut()).rev() {
151		if claim.n_multilinears() + 1 != multilinear_evals.len() {
152			bail!(VerificationError::NumberOfMultilinearEvals);
153		}
154
155		while last_n_vars < claim.n_vars() && last_n_vars < sumcheck_challenges.len() {
156			let sumcheck_challenge =
157				sumcheck_challenges[sumcheck_challenges.len() - 1 - last_n_vars];
158			let eq_ind_challenge = eq_ind_challenges[eq_ind_challenges.len() - 1 - last_n_vars];
159			eq_ind_eval *= eq(sumcheck_challenge, eq_ind_challenge);
160			last_n_vars += 1;
161		}
162
163		let multilinear_evals_last = multilinear_evals
164			.pop()
165			.expect("checked above that multilinear_evals length is at least 1");
166		if eq_ind_eval != multilinear_evals_last {
167			return Err(VerificationError::IncorrectEqIndEvaluation.into());
168		}
169	}
170
171	Ok(BatchSumcheckOutput {
172		challenges: sumcheck_challenges,
173		multilinear_evals,
174	})
175}
176
177#[derive(Debug)]
178pub struct ExtraProduct<Composition> {
179	pub inner: Composition,
180}
181
182impl<P, Composition> CompositionPoly<P> for ExtraProduct<Composition>
183where
184	P: PackedField,
185	Composition: CompositionPoly<P>,
186{
187	fn n_vars(&self) -> usize {
188		self.inner.n_vars() + 1
189	}
190
191	fn degree(&self) -> usize {
192		self.inner.degree() + 1
193	}
194
195	fn expression(&self) -> ArithExpr<P::Scalar> {
196		self.inner.expression() * ArithExpr::Var(self.inner.n_vars())
197	}
198
199	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
200		let n_vars = self.n_vars();
201		if query.len() != n_vars {
202			bail!(binius_math::Error::IncorrectQuerySize { expected: n_vars });
203		}
204
205		let inner_eval = self.inner.evaluate(&query[..n_vars - 1])?;
206		Ok(inner_eval * query[n_vars - 1])
207	}
208
209	fn binary_tower_level(&self) -> usize {
210		self.inner.binary_tower_level()
211	}
212}
213
#[cfg(test)]
mod tests {
	use std::iter;

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier256b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		packed::set_packed_slice,
		underlier::UnderlierType,
		BinaryField128b, BinaryField8b, ExtensionField, PackedField, PackedFieldIndexable,
		TowerField,
	};
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::Groestl256;
	use binius_math::{
		DefaultEvaluationDomainFactory, EvaluationOrder, MLEDirectAdapter, MultilinearExtension,
		MultilinearPoly, MultilinearQuery,
	};
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use crate::{
		composition::BivariateProduct,
		fiat_shamir::HasherChallenger,
		protocols::{
			sumcheck::{
				self, immediate_switchover_heuristic,
				prove::eq_ind::{ConstEvalSuffix, EqIndSumcheckProverBuilder},
				CompositeSumClaim, EqIndSumcheckClaim,
			},
			test_utils::AddOneComposition,
		},
		transcript::ProverTranscript,
	};

	/// Runs the prove/verify round trip for a range of nonzero prefix lengths and for both
	/// evaluation orders.
	///
	/// The prefix lengths are `0`, every power of two up to `2^n_vars`, and `n_vars + 5`
	/// pseudo-random values drawn from `1..2^n_vars` with a fixed seed.
	fn test_prove_verify_bivariate_product_helper<U, F, FDomain>(n_vars: usize)
	where
		U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
		F: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let hypercube_size: usize = 1 << n_vars;

		let mut nonzero_prefixes: Vec<usize> = iter::once(0)
			.chain((1..=n_vars).map(|log_prefix| 1 << log_prefix))
			.collect();

		let mut rng = StdRng::seed_from_u64(0);
		nonzero_prefixes
			.extend(iter::repeat_with(|| rng.gen_range(1..hypercube_size)).take(n_vars + 5));

		for nonzero_prefix in nonzero_prefixes {
			for evaluation_order in [EvaluationOrder::LowToHigh, EvaluationOrder::HighToLow] {
				test_prove_verify_bivariate_product_helper_under_evaluation_order::<U, F, FDomain>(
					evaluation_order,
					n_vars,
					nonzero_prefix,
				);
			}
		}
	}

	/// Proves and verifies a single eq-ind sumcheck of `a * b + 1` where the `a` column is zero
	/// (and hence the composite is one) on every index from `nonzero_prefix` onwards.
	fn test_prove_verify_bivariate_product_helper_under_evaluation_order<U, F, FDomain>(
		evaluation_order: EvaluationOrder,
		n_vars: usize,
		nonzero_prefix: usize,
	) where
		U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
		F: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		PackedType<U, F>: PackedFieldIndexable,
	{
		let mut rng = StdRng::seed_from_u64(0);

		let hypercube_size: usize = 1 << n_vars;
		let packed_len: usize = 1 << n_vars.saturating_sub(PackedType::<U, F>::LOG_WIDTH);

		// The whole `a` column is drawn before the `b` column so the random stream is consumed
		// in a fixed order.
		let mut a_column = iter::repeat_with(|| PackedType::<U, F>::random(&mut rng))
			.take(packed_len)
			.collect::<Vec<_>>();
		let b_column = iter::repeat_with(|| PackedType::<U, F>::random(&mut rng))
			.take(packed_len)
			.collect::<Vec<_>>();
		let mut ab1_column = a_column
			.iter()
			.zip(&b_column)
			.map(|(&a, &b)| a * b + PackedType::<U, F>::one())
			.collect::<Vec<_>>();

		// Beyond the nonzero prefix, `a` is zero and therefore `a * b + 1` is one.
		for scalar_index in nonzero_prefix..hypercube_size {
			set_packed_slice(&mut a_column, scalar_index, F::ZERO);
			set_packed_slice(&mut ab1_column, scalar_index, F::ONE);
		}

		let a_extension = MultilinearExtension::from_values_slice(&a_column).unwrap();
		let a_mle = MLEDirectAdapter::from(a_extension);
		let b_extension = MultilinearExtension::from_values_slice(&b_column).unwrap();
		let b_mle = MLEDirectAdapter::from(b_extension);
		let ab1_extension = MultilinearExtension::from_values_slice(&ab1_column).unwrap();
		let ab1_mle = MLEDirectAdapter::from(ab1_extension);

		let eq_ind_challenges = iter::repeat_with(|| F::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<_>>();

		// The claimed sum is the composite column's multilinear extension evaluated at the
		// eq-ind challenges.
		let eq_ind_query = MultilinearQuery::expand(&eq_ind_challenges);
		let sum = ab1_mle.evaluate(eq_ind_query.to_ref()).unwrap();

		let composite_claim = CompositeSumClaim {
			sum,
			composition: AddOneComposition::new(BivariateProduct {}),
		};

		let backend = make_portable_backend();
		let evaluation_domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();
		let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();

		let prover = EqIndSumcheckProverBuilder::new(&backend)
			.with_nonzero_scalars_prefixes(&[nonzero_prefix, hypercube_size])
			.build(
				evaluation_order,
				vec![a_mle, b_mle],
				&eq_ind_challenges,
				[composite_claim.clone()],
				evaluation_domain_factory,
				immediate_switchover_heuristic,
			)
			.unwrap();

		// The prover is expected to detect the constant suffix of the composite.
		let expected_suffix = ConstEvalSuffix {
			suffix: hypercube_size - nonzero_prefix,
			value: F::ONE,
			value_at_inf: F::ZERO,
		};
		let (_, detected_suffix) = prover.compositions().first().unwrap();
		assert_eq!(*detected_suffix, expected_suffix);

		let _prove_output = sumcheck::batch_prove(vec![prover], &mut transcript).unwrap();

		let mut verifier_transcript = transcript.into_verifier();

		let eq_ind_verifier_claims =
			[EqIndSumcheckClaim::new(n_vars, 2, vec![composite_claim]).unwrap()];
		let regular_verifier_claims =
			sumcheck::eq_ind::reduce_to_regular_sumchecks(&eq_ind_verifier_claims).unwrap();

		let _verify_output = sumcheck::batch_verify(
			evaluation_order,
			&regular_verifier_claims,
			&mut verifier_transcript,
		)
		.unwrap();
	}

	#[test]
	fn test_eq_ind_sumcheck_prove_verify_128b() {
		test_prove_verify_bivariate_product_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField8b,
		>(8);
	}

	#[test]
	fn test_eq_ind_sumcheck_prove_verify_256b() {
		// Using a 256-bit underlier with a 128-bit extension field means the packed field will have a
		// non-trivial packing width of 2.
		test_prove_verify_bivariate_product_helper::<
			OptimalUnderlier256b,
			BinaryField128b,
			BinaryField8b,
		>(8);
	}

	#[test]
	fn test_eq_ind_sumcheck_prove_verify_512b() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field will have a
		// non-trivial packing width of 4.
		test_prove_verify_bivariate_product_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField8b,
		>(8);
	}
}