binius_core/protocols/sumcheck/
zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	marker::PhantomData,
5	ops::{Mul, MulAssign},
6};
7
8use binius_field::{packed::set_packed_slice, ExtensionField, Field, PackedField, TowerField};
9use binius_hal::{make_portable_backend, ComputationBackendExt};
10use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearExtension};
11use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
12use bytemuck::zeroed_vec;
13use getset::CopyGetters;
14use itertools::izip;
15
16use super::error::Error;
17use crate::{
18	composition::{BivariateProduct, IndexComposition},
19	polynomial::Error as PolynomialError,
20	protocols::sumcheck::{
21		eq_ind::EqIndSumcheckClaim, BatchSumcheckOutput, CompositeSumClaim, SumcheckClaim,
22		VerificationError,
23	},
24};
25
26#[derive(Debug, CopyGetters)]
27pub struct ZerocheckClaim<F: Field, Composition> {
28	#[getset(get_copy = "pub")]
29	n_vars: usize,
30	#[getset(get_copy = "pub")]
31	n_multilinears: usize,
32	composite_zeros: Vec<Composition>,
33	_marker: PhantomData<F>,
34}
35
36impl<F: Field, Composition> ZerocheckClaim<F, Composition>
37where
38	Composition: CompositionPoly<F>,
39{
40	pub fn new(
41		n_vars: usize,
42		n_multilinears: usize,
43		composite_zeros: Vec<Composition>,
44	) -> Result<Self, Error> {
45		for composition in &composite_zeros {
46			if composition.n_vars() != n_multilinears {
47				bail!(Error::InvalidComposition {
48					actual: composition.n_vars(),
49					expected: n_multilinears,
50				});
51			}
52		}
53		Ok(Self {
54			n_vars,
55			n_multilinears,
56			composite_zeros,
57			_marker: PhantomData,
58		})
59	}
60
61	/// Returns the maximum individual degree of all composite polynomials.
62	pub fn max_individual_degree(&self) -> usize {
63		self.composite_zeros
64			.iter()
65			.map(|composite_zero| composite_zero.degree())
66			.max()
67			.unwrap_or(0)
68	}
69
70	pub fn composite_zeros(&self) -> &[Composition] {
71		&self.composite_zeros
72	}
73}
74
75/// Zerocheck round polynomial in Lagrange basis
76///
77/// Has `(composition_max_degree - 1) * 2^skip_rounds` length, where first `2^skip_rounds`
78/// evaluations are assumed to be zero. Addition of Lagrange basis representations only
79/// makes sense for the polynomials in the same domain.
80#[derive(Clone, Debug)]
81pub struct ZerocheckRoundEvals<F: Field> {
82	pub evals: Vec<F>,
83}
84
85impl<F: Field> ZerocheckRoundEvals<F> {
86	/// A Lagrange representation of a zero polynomial, on a given domain.
87	pub fn zeros(len: usize) -> Self {
88		Self {
89			evals: vec![F::ZERO; len],
90		}
91	}
92
93	/// An assigning addition of two polynomials in Lagrange basis. May fail,
94	/// thus it's not simply an `AddAssign` overload due to signature mismatch.
95	pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
96		if self.evals.len() != rhs.evals.len() {
97			bail!(Error::LagrangeRoundEvalsSizeMismatch);
98		}
99
100		for (lhs, rhs) in izip!(&mut self.evals, &rhs.evals) {
101			*lhs += rhs;
102		}
103
104		Ok(())
105	}
106}
107
108impl<F: Field> Mul<F> for ZerocheckRoundEvals<F> {
109	type Output = Self;
110
111	fn mul(mut self, rhs: F) -> Self::Output {
112		self *= rhs;
113		self
114	}
115}
116
117impl<F: Field> MulAssign<F> for ZerocheckRoundEvals<F> {
118	fn mul_assign(&mut self, rhs: F) {
119		for eval in &mut self.evals {
120			*eval *= rhs;
121		}
122	}
123}
124
/// Size of the univariatized evaluation domain.
///
/// A composition of degree $d$ over multilinears univariatized in $n$ skipped variables has
/// degree $d (2^n - 1)$, so $d (2^n - 1) + 1$ evaluations would determine it uniquely. The
/// domain is deliberately rounded up to $d 2^n$ points so that round evaluations can be
/// interpolated with additive NTT techniques.
pub const fn domain_size(composition_degree: usize, skip_rounds: usize) -> usize {
	let points_per_degree_unit = composition_degree;
	points_per_degree_unit << skip_rounds
}
134
/// Number of round evaluations that must be extrapolated beyond the skipped domain.
///
/// For zerocheck, an honest prover's composition evaluates to zero on the skipped domain (the
/// first `2^skip_rounds` points), so only `(composition_degree - 1) * 2^skip_rounds`
/// evaluations remain. A degree of zero yields zero.
pub const fn extrapolated_scalars_count(composition_degree: usize, skip_rounds: usize) -> usize {
	let remaining_blocks = if composition_degree == 0 {
		0
	} else {
		composition_degree - 1
	};
	remaining_blocks << skip_rounds
}
139
140/// Output of the batched zerocheck reduction
141pub struct BatchZerocheckOutput<F: Field> {
142	/// Sumcheck challenges corresponding to low indexed variables "skipped" by the univariate round.
143	/// Assigned by the univariatizing reduction sumcheck.
144	pub skipped_challenges: Vec<F>,
145	/// Sumcheck challenges correspending to high indexed variables that are not "skipped" and are
146	/// reduced via follow up multilinear eq-ind sumcheck.
147	pub unskipped_challenges: Vec<F>,
148	/// Multilinear evals of all batched claims, concatenated in the non-descending `n_vars` order.
149	pub concat_multilinear_evals: Vec<F>,
150}
151
152/// A reduction from a set of multilinear zerocheck claims to the set of univariatized eq-ind sumcheck claims.
153///
154/// Zerocheck claims should be in non-descending `n_vars` order. The resulting claims assume that a univariate
155/// round of `skip_rounds` has taken place before the eq-ind sumchecks.
156pub fn reduce_to_eq_ind_sumchecks<F: Field, Composition: CompositionPoly<F>>(
157	skip_rounds: usize,
158	claims: &[ZerocheckClaim<F, Composition>],
159) -> Result<Vec<EqIndSumcheckClaim<F, &Composition>>, Error> {
160	// Check that the claims are in non-descending order by n_vars
161	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
162		bail!(Error::ClaimsOutOfOrder);
163	}
164
165	claims
166		.iter()
167		.map(|zerocheck_claim| {
168			let &ZerocheckClaim {
169				n_vars,
170				n_multilinears,
171				ref composite_zeros,
172				..
173			} = zerocheck_claim;
174			EqIndSumcheckClaim::new(
175				n_vars.saturating_sub(skip_rounds),
176				n_multilinears,
177				composite_zeros
178					.iter()
179					.map(|composition| CompositeSumClaim {
180						composition,
181						sum: F::ZERO,
182					})
183					.collect(),
184			)
185		})
186		.collect()
187}
188
189/// Creates a "combined" sumcheck claim for the reduction from evaluations of univariatized virtual multilinear
190/// oracles to "regular" multilinear evaluations.
191///
192/// Univariatized virtual multilinear oracles are given by:
193/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot L_u(\hat{u}_1)$$
194/// It is assumed that `univariatized_multilinear_evals` came directly from a previous sumcheck with a univariate
195/// round batching `skip_rounds` variables. Multilinear evals of the reduction sumcheck are concatenated together
196/// in order to create the Lagrange coefficient MLE (in the last position) only once.
197pub fn univariatizing_reduction_claim<F: Field>(
198	skip_rounds: usize,
199	univariatized_multilinear_evals: &[impl AsRef<[F]>],
200) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
201	let n_multilinears = univariatized_multilinear_evals
202		.iter()
203		.map(|claim_evals| claim_evals.as_ref().len())
204		.sum();
205
206	// Assume that multilinear extension of Lagrange evaluations is the last multilinear,
207	// use IndexComposition to multiply each multilinear with it (using BivariateProduct).
208	let composite_sums = univariatized_multilinear_evals
209		.iter()
210		.flat_map(|claim_evals| claim_evals.as_ref())
211		.enumerate()
212		.map(|(i, &univariatized_multilinear_eval)| {
213			let composition =
214				IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
215					.expect("index composition indice correct by construction");
216
217			CompositeSumClaim {
218				composition,
219				sum: univariatized_multilinear_eval,
220			}
221		})
222		.collect();
223
224	SumcheckClaim::new(skip_rounds, n_multilinears + 1, composite_sums)
225}
226
227/// Verify the validity of sumcheck outputs for the reduction zerocheck.
228///
229/// This takes in the output of the univariatizing reduction sumcheck and returns the output that
230/// can be used to create multilinear evaluation claims. This simply strips off the evaluation of the
231/// Lagrange basis MLE at `univariate_challenge` (denoted by \hat{u}_1$) and verifies its correctness.
232pub fn verify_reduction_sumcheck_output<F>(
233	claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
234	skip_rounds: usize,
235	univariate_challenge: F,
236	reduction_sumcheck_output: BatchSumcheckOutput<F>,
237) -> Result<BatchSumcheckOutput<F>, Error>
238where
239	F: TowerField,
240{
241	let BatchSumcheckOutput {
242		challenges: reduction_sumcheck_challenges,
243		mut multilinear_evals,
244	} = reduction_sumcheck_output;
245
246	// Reduction sumcheck size equals number of skipped rounds.
247	if claim.n_vars() != skip_rounds {
248		bail!(Error::IncorrectUnivariatizingReductionClaims);
249	}
250
251	// Exactly one claim in the reduction sumcheck.
252	if reduction_sumcheck_challenges.len() != skip_rounds || multilinear_evals.len() != 1 {
253		bail!(Error::IncorrectUnivariatizingReductionSumcheck);
254	}
255
256	// Evaluate Lagrange MLE at `univariate_challenge`
257	let subspace = BinarySubspace::<F::Canonical>::with_dim(skip_rounds)?.isomorphic::<F>();
258	let evaluation_domain =
259		EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
260
261	let lagrange_mle =
262		lagrange_evals_multilinear_extension::<F, F, F>(&evaluation_domain, univariate_challenge)?;
263
264	let query = make_portable_backend().multilinear_query::<F>(&reduction_sumcheck_challenges)?;
265	let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
266
267	let first_claim_multilinear_evals = multilinear_evals
268		.first_mut()
269		.expect("exactly one claim in reduction sumcheck");
270
271	// Pop off the last multilinear eval (which is Lagrange MLE) and validate.
272	let multilinear_evals_last_eval = first_claim_multilinear_evals
273		.pop()
274		.ok_or(VerificationError::NumberOfFinalEvaluations)?;
275
276	if multilinear_evals_last_eval != expected_last_eval {
277		bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
278	}
279
280	let output = BatchSumcheckOutput {
281		challenges: reduction_sumcheck_challenges,
282		multilinear_evals,
283	};
284
285	Ok(output)
286}
287
288// Evaluate Lagrange coefficients at a challenge point and create a
289// multilinear extension of those.
290pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
291	evaluation_domain: &EvaluationDomain<FDomain>,
292	univariate_challenge: F,
293) -> Result<MultilinearExtension<P>, PolynomialError>
294where
295	FDomain: Field,
296	F: Field + ExtensionField<FDomain>,
297	P: PackedField<Scalar = F>,
298{
299	let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
300
301	let n_vars = log2_strict_usize(lagrange_evals.len());
302	let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
303
304	for (i, &lagrange_eval) in lagrange_evals.iter().enumerate() {
305		set_packed_slice(&mut packed, i, lagrange_eval);
306	}
307
308	Ok(MultilinearExtension::new(n_vars, packed)?)
309}
310
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::{UnderlierType, WithUnderlier},
		AESTowerField128b, AESTowerField16b, AESTowerField8b, BinaryField128b, BinaryField16b,
		BinaryField8b, ByteSlicedAES64x128b,
	};
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::Groestl256;
	use binius_math::IsomorphicEvaluationDomainFactory;
	use rand::{prelude::StdRng, SeedableRng};

	use super::*;
	use crate::{
		composition::ProductComposition,
		fiat_shamir::{CanSample, HasherChallenger},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::{self, prove::ZerocheckProverImpl},
			test_utils::generate_zero_product_multilinears,
		},
		transcript::ProverTranscript,
	};

	/// Proves and verifies a batch of zerocheck claims, one for every `n_vars` in
	/// `1..=max_n_vars`, for every number of skipped rounds in `0..=max_n_vars`. It then checks
	/// that prover and verifier agree on the skipped challenges, the unskipped challenges and the
	/// concatenated multilinear evaluations.
	///
	/// Every claim uses the same three product compositions ("pair", "triple", "quad") over
	/// disjoint index sets of its nine multilinears.
	fn test_zerocheck_end_to_end_helper<U, F, FDomain, FBase, FWitness>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<FWitness>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		FWitness: Field,
	{
		let max_n_vars = 6;
		let n_multilinears = 9;

		let backend = make_portable_backend();
		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let mut rng = StdRng::seed_from_u64(0);

		// Products over indices {0, 1}, {2, 3, 4} and {5, 6, 7, 8} of the nine multilinears.
		let pair = Arc::new(
			IndexComposition::new(n_multilinears, [0, 1], ProductComposition::<2> {}).unwrap(),
		);
		let triple = Arc::new(
			IndexComposition::new(n_multilinears, [2, 3, 4], ProductComposition::<3> {}).unwrap(),
		);
		let quad = Arc::new(
			IndexComposition::new(n_multilinears, [5, 6, 7, 8], ProductComposition::<4> {})
				.unwrap(),
		);

		// The prover needs each composition over both the base and the extension packed fields.
		let prover_compositions = [
			(
				"pair".into(),
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
			(
				"triple".into(),
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
			(
				"quad".into(),
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
		];

		// Scalar-field views of the same compositions, used to state the verifier's claims.
		let claim_compositions = [
			CompositionScalarAdapter::new(pair as Arc<dyn CompositionPoly<F>>),
			CompositionScalarAdapter::new(triple as Arc<dyn CompositionPoly<F>>),
			CompositionScalarAdapter::new(quad as Arc<dyn CompositionPoly<F>>),
		];

		for skip_rounds in 0..=max_n_vars {
			let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();

			// Challenges sized for the largest claim; smaller claims take a suffix of them.
			let zerocheck_challenges: Vec<F> =
				prover_transcript.sample_vec(max_n_vars - skip_rounds);

			let mut claims = Vec::new();
			let mut provers = Vec::new();
			for n_vars in 1..=max_n_vars {
				let mut multilinears = generate_zero_product_multilinears::<
					PackedType<U, FWitness>,
					PackedType<U, F>,
				>(&mut rng, n_vars, 2);
				for degree in [3, 4] {
					multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, degree));
				}

				let claim = ZerocheckClaim::<F, _>::new(
					n_vars,
					n_multilinears,
					claim_compositions.to_vec(),
				)
				.unwrap();

				let challenges_offset = max_n_vars - n_vars.max(skip_rounds);
				let prover =
					ZerocheckProverImpl::<FDomain, FBase, PackedType<U, F>, _, _, _, _, _>::new(
						multilinears,
						prover_compositions.to_vec(),
						&zerocheck_challenges[challenges_offset..],
						domain_factory.clone(),
						&backend,
					)
					.unwrap();

				claims.push(claim);
				provers.push(prover);
			}

			let prover_output =
				sumcheck::prove::batch_prove_zerocheck::<F, FDomain, PackedType<U, F>, _, _>(
					provers,
					skip_rounds,
					&mut prover_transcript,
				)
				.unwrap();

			let mut verifier_transcript = prover_transcript.into_verifier();
			let verifier_output =
				sumcheck::batch_verify_zerocheck(&claims, skip_rounds, &mut verifier_transcript)
					.unwrap();
			verifier_transcript.finalize().unwrap();

			assert_eq!(prover_output.skipped_challenges, verifier_output.skipped_challenges);
			assert_eq!(prover_output.unskipped_challenges, verifier_output.unskipped_challenges);
			assert_eq!(
				prover_output.concat_multilinear_evals,
				verifier_output.concat_multilinear_evals,
			);
		}
	}

	#[test]
	fn test_zerocheck_end_to_end_basic() {
		test_zerocheck_end_to_end_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField16b,
			BinaryField16b,
			BinaryField8b,
		>()
	}

	#[test]
	fn test_zerocheck_end_to_end_with_nontrivial_packing() {
		// A 512-bit underlier with a 128-bit extension field gives a packed field of width 4.
		test_zerocheck_end_to_end_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField16b,
			BinaryField16b,
			BinaryField8b,
		>()
	}

	#[test]
	fn test_zerocheck_end_to_end_bytesliced() {
		test_zerocheck_end_to_end_helper::<
			<ByteSlicedAES64x128b as WithUnderlier>::Underlier,
			AESTowerField128b,
			AESTowerField16b,
			AESTowerField16b,
			AESTowerField8b,
		>()
	}
}