binius_core/protocols/sumcheck/zerocheck.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	marker::PhantomData,
5	ops::{Mul, MulAssign},
6};
7
8use binius_field::{ExtensionField, Field, PackedField, TowerField, packed::set_packed_slice};
9use binius_hal::{ComputationBackendExt, make_portable_backend};
10use binius_math::{BinarySubspace, CompositionPoly, EvaluationDomain, MultilinearExtension};
11use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
12use bytemuck::zeroed_vec;
13use getset::CopyGetters;
14use itertools::izip;
15
16use super::error::Error;
17use crate::{
18	composition::{BivariateProduct, IndexComposition},
19	polynomial::Error as PolynomialError,
20	protocols::sumcheck::{
21		BatchSumcheckOutput, CompositeSumClaim, SumcheckClaim, VerificationError,
22		eq_ind::EqIndSumcheckClaim,
23	},
24};
25
26#[derive(Debug, CopyGetters)]
27pub struct ZerocheckClaim<F: Field, Composition> {
28	#[getset(get_copy = "pub")]
29	n_vars: usize,
30	#[getset(get_copy = "pub")]
31	n_multilinears: usize,
32	composite_zeros: Vec<Composition>,
33	_marker: PhantomData<F>,
34}
35
36impl<F: Field, Composition> ZerocheckClaim<F, Composition>
37where
38	Composition: CompositionPoly<F>,
39{
40	pub fn new(
41		n_vars: usize,
42		n_multilinears: usize,
43		composite_zeros: Vec<Composition>,
44	) -> Result<Self, Error> {
45		for composition in &composite_zeros {
46			if composition.n_vars() != n_multilinears {
47				bail!(Error::InvalidComposition {
48					actual: composition.n_vars(),
49					expected: n_multilinears,
50				});
51			}
52		}
53		Ok(Self {
54			n_vars,
55			n_multilinears,
56			composite_zeros,
57			_marker: PhantomData,
58		})
59	}
60
61	/// Returns the maximum individual degree of all composite polynomials.
62	pub fn max_individual_degree(&self) -> usize {
63		self.composite_zeros
64			.iter()
65			.map(|composite_zero| composite_zero.degree())
66			.max()
67			.unwrap_or(0)
68	}
69
70	pub fn composite_zeros(&self) -> &[Composition] {
71		&self.composite_zeros
72	}
73}
74
75/// Zerocheck round polynomial in Lagrange basis
76///
77/// Has `(composition_max_degree - 1) * 2^skip_rounds` length, where first `2^skip_rounds`
78/// evaluations are assumed to be zero. Addition of Lagrange basis representations only
79/// makes sense for the polynomials in the same domain.
80#[derive(Clone, Debug)]
81pub struct ZerocheckRoundEvals<F: Field> {
82	pub evals: Vec<F>,
83}
84
85impl<F: Field> ZerocheckRoundEvals<F> {
86	/// A Lagrange representation of a zero polynomial, on a given domain.
87	pub fn zeros(len: usize) -> Self {
88		Self {
89			evals: vec![F::ZERO; len],
90		}
91	}
92
93	/// An assigning addition of two polynomials in Lagrange basis. May fail,
94	/// thus it's not simply an `AddAssign` overload due to signature mismatch.
95	pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
96		if self.evals.len() != rhs.evals.len() {
97			bail!(Error::LagrangeRoundEvalsSizeMismatch);
98		}
99
100		for (lhs, rhs) in izip!(&mut self.evals, &rhs.evals) {
101			*lhs += rhs;
102		}
103
104		Ok(())
105	}
106}
107
108impl<F: Field> Mul<F> for ZerocheckRoundEvals<F> {
109	type Output = Self;
110
111	fn mul(mut self, rhs: F) -> Self::Output {
112		self *= rhs;
113		self
114	}
115}
116
117impl<F: Field> MulAssign<F> for ZerocheckRoundEvals<F> {
118	fn mul_assign(&mut self, rhs: F) {
119		for eval in &mut self.evals {
120			*eval *= rhs;
121		}
122	}
123}
124
/// Size of the univariatized evaluation domain.
///
/// Let $d$ be `composition_degree` and $k$ be `skip_rounds`. A degree-$d$ composition over
/// multilinears univariatized in $k$ variables has degree $d (2^k - 1)$. It is therefore
/// uniquely determined by its evaluations on $d (2^k - 1) + 1$ points. That count is deliberately
/// rounded up to $d \cdot 2^k$ so that additive NTT interpolation can be used on the round
/// evaluations.
pub const fn domain_size(composition_degree: usize, skip_rounds: usize) -> usize {
	composition_degree << skip_rounds
}
134
/// Number of round evaluations that have to be extrapolated in zerocheck.
///
/// An honest prover's round polynomial is zero on the first `2^skip_rounds` points of the
/// univariatized domain, which has `composition_degree * 2^skip_rounds` points. Only the remaining
/// `(composition_degree - 1) * 2^skip_rounds` points need to be computed. A `composition_degree` of
/// zero yields zero.
pub const fn extrapolated_scalars_count(composition_degree: usize, skip_rounds: usize) -> usize {
	let non_skipped_blocks = composition_degree.saturating_sub(1);
	non_skipped_blocks << skip_rounds
}
139
140/// Output of the batched zerocheck reduction
141pub struct BatchZerocheckOutput<F: Field> {
142	/// Sumcheck challenges corresponding to low indexed variables "skipped" by the univariate
143	/// round. Assigned by the univariatizing reduction sumcheck.
144	pub skipped_challenges: Vec<F>,
145	/// Sumcheck challenges corresponding to high indexed variables that are not "skipped" and are
146	/// reduced via follow up multilinear eq-ind sumcheck.
147	pub unskipped_challenges: Vec<F>,
148	/// Multilinear evals of all batched claims, concatenated in the non-descending `n_vars` order.
149	pub concat_multilinear_evals: Vec<F>,
150}
151
152/// A reduction from a set of multilinear zerocheck claims to the set of univariatized eq-ind
153/// sumcheck claims.
154///
155/// Zerocheck claims should be in non-descending `n_vars` order. The resulting claims assume that a
156/// univariate round of `skip_rounds` has taken place before the eq-ind sumchecks.
157pub fn reduce_to_eq_ind_sumchecks<F: Field, Composition: CompositionPoly<F>>(
158	skip_rounds: usize,
159	claims: &[ZerocheckClaim<F, Composition>],
160) -> Result<Vec<EqIndSumcheckClaim<F, &Composition>>, Error> {
161	// Check that the claims are in non-descending order by n_vars
162	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars())) {
163		bail!(Error::ClaimsOutOfOrder);
164	}
165
166	claims
167		.iter()
168		.map(|zerocheck_claim| {
169			let &ZerocheckClaim {
170				n_vars,
171				n_multilinears,
172				ref composite_zeros,
173				..
174			} = zerocheck_claim;
175			EqIndSumcheckClaim::new(
176				n_vars.saturating_sub(skip_rounds),
177				n_multilinears,
178				composite_zeros
179					.iter()
180					.map(|composition| CompositeSumClaim {
181						composition,
182						sum: F::ZERO,
183					})
184					.collect(),
185			)
186		})
187		.collect()
188}
189
190/// Creates a "combined" sumcheck claim for the reduction from evaluations of univariatized virtual
191/// multilinear oracles to "regular" multilinear evaluations.
192///
193/// Univariatized virtual multilinear oracles are given by:
194/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot
195/// L_u(\hat{u}_1)$$ It is assumed that `univariatized_multilinear_evals` came directly from a
196/// previous sumcheck with a univariate round batching `skip_rounds` variables. Multilinear evals of
197/// the reduction sumcheck are concatenated together in order to create the Lagrange coefficient MLE
198/// (in the last position) only once.
199pub fn univariatizing_reduction_claim<F: Field>(
200	skip_rounds: usize,
201	univariatized_multilinear_evals: &[impl AsRef<[F]>],
202) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
203	let n_multilinears = univariatized_multilinear_evals
204		.iter()
205		.map(|claim_evals| claim_evals.as_ref().len())
206		.sum();
207
208	// Assume that multilinear extension of Lagrange evaluations is the last multilinear,
209	// use IndexComposition to multiply each multilinear with it (using BivariateProduct).
210	let composite_sums = univariatized_multilinear_evals
211		.iter()
212		.flat_map(|claim_evals| claim_evals.as_ref())
213		.enumerate()
214		.map(|(i, &univariatized_multilinear_eval)| {
215			let composition =
216				IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
217					.expect("index composition indice correct by construction");
218
219			CompositeSumClaim {
220				composition,
221				sum: univariatized_multilinear_eval,
222			}
223		})
224		.collect();
225
226	SumcheckClaim::new(skip_rounds, n_multilinears + 1, composite_sums)
227}
228
229/// Verify the validity of sumcheck outputs for the reduction zerocheck.
230///
231/// This takes in the output of the univariatizing reduction sumcheck and returns the output that
232/// can be used to create multilinear evaluation claims. This simply strips off the evaluation of
233/// the Lagrange basis MLE at `univariate_challenge` (denoted by \hat{u}_1$) and verifies its
234/// correctness.
235pub fn verify_reduction_sumcheck_output<F>(
236	claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
237	skip_rounds: usize,
238	univariate_challenge: F,
239	reduction_sumcheck_output: BatchSumcheckOutput<F>,
240) -> Result<BatchSumcheckOutput<F>, Error>
241where
242	F: TowerField,
243{
244	let BatchSumcheckOutput {
245		challenges: reduction_sumcheck_challenges,
246		mut multilinear_evals,
247	} = reduction_sumcheck_output;
248
249	// Reduction sumcheck size equals number of skipped rounds.
250	if claim.n_vars() != skip_rounds {
251		bail!(Error::IncorrectUnivariatizingReductionClaims);
252	}
253
254	// Exactly one claim in the reduction sumcheck.
255	if reduction_sumcheck_challenges.len() != skip_rounds || multilinear_evals.len() != 1 {
256		bail!(Error::IncorrectUnivariatizingReductionSumcheck);
257	}
258
259	// Evaluate Lagrange MLE at `univariate_challenge`
260	let subspace = BinarySubspace::<F::Canonical>::with_dim(skip_rounds)?.isomorphic::<F>();
261	let evaluation_domain =
262		EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
263
264	let lagrange_mle =
265		lagrange_evals_multilinear_extension::<F, F, F>(&evaluation_domain, univariate_challenge)?;
266
267	let query = make_portable_backend().multilinear_query::<F>(&reduction_sumcheck_challenges)?;
268	let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
269
270	let first_claim_multilinear_evals = multilinear_evals
271		.first_mut()
272		.expect("exactly one claim in reduction sumcheck");
273
274	// Pop off the last multilinear eval (which is Lagrange MLE) and validate.
275	let multilinear_evals_last_eval = first_claim_multilinear_evals
276		.pop()
277		.ok_or(VerificationError::NumberOfFinalEvaluations)?;
278
279	if multilinear_evals_last_eval != expected_last_eval {
280		bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
281	}
282
283	let output = BatchSumcheckOutput {
284		challenges: reduction_sumcheck_challenges,
285		multilinear_evals,
286	};
287
288	Ok(output)
289}
290
291// Evaluate Lagrange coefficients at a challenge point and create a
292// multilinear extension of those.
293pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
294	evaluation_domain: &EvaluationDomain<FDomain>,
295	univariate_challenge: F,
296) -> Result<MultilinearExtension<P>, PolynomialError>
297where
298	FDomain: Field,
299	F: Field + ExtensionField<FDomain>,
300	P: PackedField<Scalar = F>,
301{
302	let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
303
304	let n_vars = log2_strict_usize(lagrange_evals.len());
305	let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
306
307	for (i, &lagrange_eval) in lagrange_evals.iter().enumerate() {
308		set_packed_slice(&mut packed, i, lagrange_eval);
309	}
310
311	Ok(MultilinearExtension::new(n_vars, packed)?)
312}
313
// End-to-end tests: run the batched zerocheck prover and verifier against each other and check
// that both sides agree on the challenges and the final multilinear evaluations.
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use binius_field::{
		AESTowerField8b, AESTowerField16b, AESTowerField128b, BinaryField8b, BinaryField16b,
		BinaryField128b, ByteSlicedAES64x128b,
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::{UnderlierType, WithUnderlier},
	};
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::Groestl256;
	use binius_math::IsomorphicEvaluationDomainFactory;
	use rand::{SeedableRng, prelude::StdRng};

	use super::*;
	use crate::{
		composition::ProductComposition,
		fiat_shamir::{CanSample, HasherChallenger},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::{self, prove::ZerocheckProverImpl},
			test_utils::generate_zero_product_multilinears,
		},
		transcript::ProverTranscript,
	};

	/// Runs a batched zerocheck prove/verify round trip for every `skip_rounds` in
	/// `0..=max_n_vars`.
	///
	/// Each round trip batches one claim per `n_vars` in `1..=max_n_vars`. Each claim uses 9
	/// multilinears checked by three product compositions over the index groups `[0, 1]`,
	/// `[2, 3, 4]` and `[5, 6, 7, 8]`. The test then asserts that prover and verifier outputs are
	/// equal.
	fn test_zerocheck_end_to_end_helper<U, F, FDomain, FBase, FWitness>()
	where
		U: UnderlierType
			+ PackScalar<F>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<FWitness>,
		F: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		FWitness: Field,
	{
		let max_n_vars = 6;
		// 2 + 3 + 4 multilinears, one group per composition below.
		let n_multilinears = 9;

		let backend = make_portable_backend();
		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		// Fixed seed keeps the test deterministic.
		let mut rng = StdRng::seed_from_u64(0);

		// Product compositions over disjoint index groups of the 9 multilinears.
		let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
		let triple =
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
		let quad =
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());

		// The prover takes each composition both over the base-field and the extension-field
		// packed types, together with a name.
		let prover_compositions = [
			(
				"pair".into(),
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
			(
				"triple".into(),
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
			(
				"quad".into(),
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, F>>>,
			),
		];

		// Scalar-field views of the same compositions, used to build the verifier-side claims.
		let prover_adapter_compositions = [
			CompositionScalarAdapter::new(pair as Arc<dyn CompositionPoly<F>>),
			CompositionScalarAdapter::new(triple as Arc<dyn CompositionPoly<F>>),
			CompositionScalarAdapter::new(quad as Arc<dyn CompositionPoly<F>>),
		];

		for skip_rounds in 0..=max_n_vars {
			let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();

			// One zerocheck challenge per non-skipped variable of the largest claim.
			let prover_zerocheck_challenges: Vec<F> = proof.sample_vec(max_n_vars - skip_rounds);

			let mut zerocheck_claims = Vec::new();
			let mut zerocheck_provers = Vec::new();
			// Claims are created in ascending `n_vars` order, as the batched protocol requires.
			for n_vars in 1..=max_n_vars {
				// Presumably each call yields a group of multilinears whose pointwise product
				// vanishes, matching the corresponding product composition (see `test_utils`).
				let mut multilinears = generate_zero_product_multilinears::<
					PackedType<U, FWitness>,
					PackedType<U, F>,
				>(&mut rng, n_vars, 2);
				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

				let claim = ZerocheckClaim::<F, _>::new(
					n_vars,
					n_multilinears,
					prover_adapter_compositions.to_vec(),
				)
				.unwrap();

				// The slice below is the suffix of length `n_vars.saturating_sub(skip_rounds)`,
				// i.e. the challenges for this claim's non-skipped variables.
				let prover =
					ZerocheckProverImpl::<FDomain, FBase, PackedType<U, F>, _, _, _, _, _>::new(
						multilinears,
						prover_compositions.to_vec(),
						&prover_zerocheck_challenges[max_n_vars - n_vars.max(skip_rounds)..],
						domain_factory.clone(),
						&backend,
					)
					.unwrap();

				zerocheck_claims.push(claim);
				zerocheck_provers.push(prover);
			}

			let prover_zerocheck_output =
				sumcheck::prove::batch_prove_zerocheck::<F, FDomain, PackedType<U, F>, _, _>(
					zerocheck_provers,
					skip_rounds,
					&mut proof,
				)
				.unwrap();

			// Replay the same transcript on the verifier side.
			let mut verifier_proof = proof.into_verifier();

			let verifier_zerocheck_output = sumcheck::batch_verify_zerocheck(
				&zerocheck_claims,
				skip_rounds,
				&mut verifier_proof,
			)
			.unwrap();

			verifier_proof.finalize().unwrap();

			// Prover and verifier must derive identical challenges and evaluations.
			assert_eq!(
				prover_zerocheck_output.skipped_challenges,
				verifier_zerocheck_output.skipped_challenges
			);
			assert_eq!(
				prover_zerocheck_output.unskipped_challenges,
				verifier_zerocheck_output.unskipped_challenges
			);
			assert_eq!(
				prover_zerocheck_output.concat_multilinear_evals,
				verifier_zerocheck_output.concat_multilinear_evals,
			);
		}
	}

	// Canonical binary tower fields with a 128-bit underlier.
	#[test]
	fn test_zerocheck_end_to_end_basic() {
		test_zerocheck_end_to_end_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			BinaryField16b,
			BinaryField16b,
			BinaryField8b,
		>()
	}

	#[test]
	fn test_zerocheck_end_to_end_with_nontrivial_packing() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field will have
		// a non-trivial packing width of 4.
		test_zerocheck_end_to_end_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			BinaryField16b,
			BinaryField16b,
			BinaryField8b,
		>()
	}

	// AES tower fields using the underlier of the byte-sliced `ByteSlicedAES64x128b` packed type.
	#[test]
	fn test_zerocheck_end_to_end_bytesliced() {
		test_zerocheck_end_to_end_helper::<
			<ByteSlicedAES64x128b as WithUnderlier>::Underlier,
			AESTowerField128b,
			AESTowerField16b,
			AESTowerField16b,
			AESTowerField8b,
		>()
	}
}