binius_core/protocols/sumcheck/
univariate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::{self, repeat_n},
5	ops::{Mul, MulAssign},
6};
7
8use binius_field::{ExtensionField, Field, PackedFieldIndexable, TowerField};
9use binius_hal::{make_portable_backend, ComputationBackendExt};
10use binius_math::{
11	EvaluationDomain, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory,
12	MultilinearExtension,
13};
14use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
15use bytemuck::zeroed_vec;
16
17use crate::{
18	composition::{BivariateProduct, IndexComposition},
19	polynomial::Error as PolynomialError,
20	protocols::sumcheck::{
21		BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, VerificationError,
22	},
23};
24
25/// A univariate polynomial in Lagrange basis.
26///
27/// The coefficient at position `i` in the `lagrange_coeffs` corresponds to evaluation
28/// at `i+zeros_prefix_len`-th field element of some agreed upon domain. Coefficients
29/// at positions `0..zeros_prefix_len` are zero. Addition of Lagrange basis representations
30/// only makes sense for the polynomials in the same domain.
31#[derive(Clone, Debug)]
32pub struct LagrangeRoundEvals<F: Field> {
33	pub zeros_prefix_len: usize,
34	pub evals: Vec<F>,
35}
36
37impl<F: Field> LagrangeRoundEvals<F> {
38	/// A Lagrange representation of a zero polynomial, on a given domain.
39	pub const fn zeros(zeros_prefix_len: usize) -> Self {
40		Self {
41			zeros_prefix_len,
42			evals: Vec::new(),
43		}
44	}
45
46	/// An assigning addition of two polynomials in Lagrange basis. May fail,
47	/// thus it's not simply an `AddAssign` overload due to signature mismatch.
48	pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
49		let lhs_len = self.zeros_prefix_len + self.evals.len();
50		let rhs_len = rhs.zeros_prefix_len + rhs.evals.len();
51
52		if lhs_len != rhs_len {
53			bail!(Error::LagrangeRoundEvalsSizeMismatch);
54		}
55
56		let start_idx = if rhs.zeros_prefix_len < self.zeros_prefix_len {
57			self.evals
58				.splice(0..0, repeat_n(F::ZERO, self.zeros_prefix_len - rhs.zeros_prefix_len));
59			self.zeros_prefix_len = rhs.zeros_prefix_len;
60			0
61		} else {
62			rhs.zeros_prefix_len - self.zeros_prefix_len
63		};
64
65		for (lhs, rhs) in self.evals[start_idx..].iter_mut().zip(&rhs.evals) {
66			*lhs += rhs;
67		}
68
69		Ok(())
70	}
71}
72
73impl<F: Field> Mul<F> for LagrangeRoundEvals<F> {
74	type Output = Self;
75
76	fn mul(mut self, rhs: F) -> Self::Output {
77		self *= rhs;
78		self
79	}
80}
81
82impl<F: Field> MulAssign<F> for LagrangeRoundEvals<F> {
83	fn mul_assign(&mut self, rhs: F) {
84		for eval in &mut self.evals {
85			*eval *= rhs;
86		}
87	}
88}
89/// Creates sumcheck claims for the reduction from evaluations of univariatized virtual multilinear oracles to
90/// "regular" multilinear evaluations.
91///
92/// Univariatized virtual multilinear oracles are given by:
93/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot L_u(\hat{u}_1)$$
94/// It is assumed that `univariatized_multilinear_evals` came directly from a previous sumcheck with a univariate
95/// round batching `skip_rounds` variables.
96pub fn univariatizing_reduction_claim<F: Field>(
97	skip_rounds: usize,
98	univariatized_multilinear_evals: &[F],
99) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
100	let composite_sums =
101		univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
102	SumcheckClaim::new(skip_rounds, univariatized_multilinear_evals.len() + 1, composite_sums)
103}
104
105/// Verify the validity of sumcheck outputs for the reduction zerocheck.
106///
107/// This takes in the output of the batched univariatizing reduction sumcheck and returns the output
108/// that can be used to create multilinear evaluation claims. This simply strips off the evaluation of
109/// the multilinear extension of Lagrange polynomials evaluations at `univariate_challenge` (denoted by
110/// $\hat{u}_1$) and verifies that this value is correct. The argument `unskipped_sumcheck_challenges`
111/// holds the challenges of the sumcheck following the univariate round.
112pub fn verify_sumcheck_outputs<F>(
113	claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
114	univariate_challenge: F,
115	unskipped_sumcheck_challenges: &[F],
116	sumcheck_output: BatchSumcheckOutput<F>,
117) -> Result<BatchSumcheckOutput<F>, Error>
118where
119	F: TowerField,
120{
121	let BatchSumcheckOutput {
122		challenges: reduction_sumcheck_challenges,
123		mut multilinear_evals,
124	} = sumcheck_output;
125
126	assert_eq!(multilinear_evals.len(), claims.len());
127
128	// Check that the claims are in descending order by n_vars
129	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
130		bail!(Error::ClaimsOutOfOrder);
131	}
132
133	let max_n_vars = claims
134		.first()
135		.map(|claim| claim.n_vars())
136		.unwrap_or_default();
137
138	assert_eq!(reduction_sumcheck_challenges.len(), max_n_vars);
139
140	for (claim, multilinear_evals) in iter::zip(claims, multilinear_evals.iter_mut()) {
141		let skip_rounds = claim.n_vars();
142
143		let evaluation_domain = IsomorphicEvaluationDomainFactory::<F::Canonical>::default()
144			.create(1 << skip_rounds)?;
145
146		let lagrange_mle = lagrange_evals_multilinear_extension::<F, F, F>(
147			&evaluation_domain,
148			univariate_challenge,
149		)?;
150
151		let query = make_portable_backend()
152			.multilinear_query::<F>(&reduction_sumcheck_challenges[max_n_vars - skip_rounds..])?;
153		let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
154
155		let multilinear_evals_last = multilinear_evals
156			.pop()
157			.ok_or(VerificationError::NumberOfFinalEvaluations)?;
158
159		if multilinear_evals_last != expected_last_eval {
160			bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
161		}
162	}
163
164	let mut challenges = Vec::new();
165	challenges.extend(reduction_sumcheck_challenges);
166	challenges.extend(unskipped_sumcheck_challenges);
167
168	let output = BatchSumcheckOutput {
169		challenges,
170		multilinear_evals,
171	};
172
173	Ok(output)
174}
175
176// Helper method to create univariatized multilinear oracle evaluation claims.
177// Assumes that multilinear extension of Lagrange evaluations is the last multilinear,
178// uses IndexComposition to multiply each multilinear with it (using BivariateProduct).
179pub(super) fn univariatizing_reduction_composite_sum_claims<F: Field>(
180	univariatized_multilinear_evals: &[F],
181) -> Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>> {
182	let n_multilinears = univariatized_multilinear_evals.len();
183	univariatized_multilinear_evals
184		.iter()
185		.enumerate()
186		.map(|(i, &univariatized_multilinear_eval)| {
187			let composition =
188				IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
189					.expect("index composition indice correct by construction");
190
191			CompositeSumClaim {
192				composition,
193				sum: univariatized_multilinear_eval,
194			}
195		})
196		.collect()
197}
198
199// Given EvaluationDomain, evaluates Lagrange coefficients at a challenge point
200// and creates a multilinear extension of said evaluations.
201pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
202	evaluation_domain: &EvaluationDomain<FDomain>,
203	univariate_challenge: F,
204) -> Result<MultilinearExtension<P>, PolynomialError>
205where
206	FDomain: Field,
207	F: Field + ExtensionField<FDomain>,
208	P: PackedFieldIndexable<Scalar = F>,
209{
210	let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
211
212	let n_vars = log2_strict_usize(lagrange_evals.len());
213	let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
214	let scalars = P::unpack_scalars_mut(packed.as_mut_slice());
215	scalars[..lagrange_evals.len()].copy_from_slice(lagrange_evals.as_slice());
216
217	Ok(MultilinearExtension::new(n_vars, packed)?)
218}
219
220#[cfg(test)]
221mod tests {
222	use std::{iter, sync::Arc};
223
224	use binius_field::{
225		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
226		as_packed_field::{PackScalar, PackedType},
227		underlier::UnderlierType,
228		AESTowerField128b, AESTowerField16b, AESTowerField8b, BinaryField128b, BinaryField16b,
229		Field, PackedBinaryField1x128b, PackedBinaryField4x32b, PackedFieldIndexable, TowerField,
230	};
231	use binius_hal::ComputationBackend;
232	use binius_math::{
233		CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
234		IsomorphicEvaluationDomainFactory, MultilinearPoly,
235	};
236	use groestl_crypto::Groestl256;
237	use rand::{prelude::StdRng, SeedableRng};
238
239	use super::*;
240	use crate::{
241		composition::{IndexComposition, ProductComposition},
242		fiat_shamir::{CanSample, HasherChallenger},
243		polynomial::CompositionScalarAdapter,
244		protocols::{
245			sumcheck::{
246				batch_verify, batch_verify_with_start, batch_verify_zerocheck_univariate_round,
247				prove::{
248					batch_prove, batch_prove_with_start, batch_prove_zerocheck_univariate_round,
249					univariate::{reduce_to_skipped_projection, univariatizing_reduction_prover},
250					SumcheckProver, UnivariateZerocheck,
251				},
252				standard_switchover_heuristic,
253				zerocheck::reduce_to_sumchecks,
254				ZerocheckClaim,
255			},
256			test_utils::generate_zero_product_multilinears,
257		},
258		transcript::ProverTranscript,
259	};
260
261	#[test]
262	fn test_univariatizing_reduction_end_to_end() {
263		type F = BinaryField128b;
264		type FDomain = BinaryField16b;
265		type P = PackedBinaryField4x32b;
266		type PE = PackedBinaryField1x128b;
267
268		let backend = make_portable_backend();
269		let mut rng = StdRng::seed_from_u64(0);
270
271		let regular_vars = 3;
272		let max_skip_rounds = 3;
273		let n_multilinears = 2;
274
275		let evaluation_domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();
276
277		let univariate_challenge = <F as Field>::random(&mut rng);
278
279		let sumcheck_challenges = (0..regular_vars)
280			.map(|_| <F as Field>::random(&mut rng))
281			.collect::<Vec<_>>();
282
283		let mut provers = Vec::new();
284		let mut all_multilinears = Vec::new();
285		let mut all_univariatized_multilinear_evals = Vec::new();
286		for skip_rounds in (0..=max_skip_rounds).rev() {
287			let n_vars = skip_rounds + regular_vars;
288
289			let multilinears =
290				generate_zero_product_multilinears::<P, PE>(&mut rng, n_vars, n_multilinears);
291			all_multilinears.push((skip_rounds, multilinears.clone()));
292
293			let domain = evaluation_domain_factory
294				.clone()
295				.create(1 << skip_rounds)
296				.unwrap();
297
298			let query = backend.multilinear_query(&sumcheck_challenges).unwrap();
299			let univariatized_multilinear_evals = multilinears
300				.iter()
301				.map(|multilinear| {
302					let partial_eval = backend
303						.evaluate_partial_high(multilinear, query.to_ref())
304						.unwrap();
305					domain
306						.extrapolate(PE::unpack_scalars(partial_eval.evals()), univariate_challenge)
307						.unwrap()
308				})
309				.collect::<Vec<_>>();
310
311			all_univariatized_multilinear_evals.push(univariatized_multilinear_evals.clone());
312
313			let reduced_multilinears =
314				reduce_to_skipped_projection(multilinears, &sumcheck_challenges, &backend).unwrap();
315
316			let prover = univariatizing_reduction_prover(
317				reduced_multilinears,
318				&univariatized_multilinear_evals,
319				univariate_challenge,
320				evaluation_domain_factory.clone(),
321				&backend,
322			)
323			.unwrap();
324
325			provers.push(prover);
326		}
327
328		let mut prove_challenger = ProverTranscript::<HasherChallenger<Groestl256>>::new();
329		let batch_sumcheck_output_prove = batch_prove(provers, &mut prove_challenger).unwrap();
330
331		for ((skip_rounds, multilinears), multilinear_evals) in
332			iter::zip(&all_multilinears, batch_sumcheck_output_prove.multilinear_evals)
333		{
334			assert_eq!(multilinears.len() + 1, multilinear_evals.len());
335
336			let mut query =
337				batch_sumcheck_output_prove.challenges[max_skip_rounds - skip_rounds..].to_vec();
338			query.extend(sumcheck_challenges.as_slice());
339
340			let query = backend.multilinear_query(&query).unwrap();
341
342			for (multilinear, eval) in iter::zip(multilinears, multilinear_evals) {
343				assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
344			}
345		}
346
347		let claims = iter::zip(&all_multilinears, &all_univariatized_multilinear_evals)
348			.map(|((skip_rounds, _q), univariatized_multilinear_evals)| {
349				univariatizing_reduction_claim(*skip_rounds, univariatized_multilinear_evals)
350					.unwrap()
351			})
352			.collect::<Vec<_>>();
353
354		let mut verify_challenger = prove_challenger.into_verifier();
355		let batch_sumcheck_output_verify =
356			batch_verify(EvaluationOrder::LowToHigh, claims.as_slice(), &mut verify_challenger)
357				.unwrap();
358		let batch_sumcheck_output_post = verify_sumcheck_outputs(
359			claims.as_slice(),
360			univariate_challenge,
361			&sumcheck_challenges,
362			batch_sumcheck_output_verify,
363		)
364		.unwrap();
365
366		for ((skip_rounds, multilinears), evals) in
367			iter::zip(all_multilinears, batch_sumcheck_output_post.multilinear_evals)
368		{
369			let mut query = batch_sumcheck_output_post.challenges
370				[max_skip_rounds - skip_rounds..max_skip_rounds]
371				.to_vec();
372			query.extend(sumcheck_challenges.as_slice());
373
374			let query = backend.multilinear_query(&query).unwrap();
375
376			for (multilinear, eval) in iter::zip(multilinears, evals) {
377				assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
378			}
379		}
380	}
381
382	#[test]
383	fn test_univariatized_zerocheck_end_to_end_basic() {
384		test_univariatized_zerocheck_end_to_end_helper::<
385			OptimalUnderlier128b,
386			BinaryField128b,
387			AESTowerField128b,
388			AESTowerField16b,
389			AESTowerField16b,
390			AESTowerField8b,
391		>()
392	}
393
394	#[test]
395	fn test_univariatized_zerocheck_end_to_end_with_nontrivial_packing() {
396		// Using a 512-bit underlier with a 128-bit extension field means the packed field will have a
397		// non-trivial packing width of 4.
398		test_univariatized_zerocheck_end_to_end_helper::<
399			OptimalUnderlier512b,
400			BinaryField128b,
401			AESTowerField128b,
402			AESTowerField16b,
403			AESTowerField16b,
404			AESTowerField8b,
405		>()
406	}
407
408	fn test_univariatized_zerocheck_end_to_end_helper<U, F, FI, FDomain, FBase, FWitness>()
409	where
410		U: UnderlierType
411			+ PackScalar<FI>
412			+ PackScalar<FBase>
413			+ PackScalar<FDomain>
414			+ PackScalar<FWitness>,
415		F: TowerField + From<FI>,
416		FI: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
417		FBase: TowerField + ExtensionField<FDomain>,
418		FDomain: TowerField,
419		FWitness: Field,
420		PackedType<U, FBase>: PackedFieldIndexable,
421		PackedType<U, FDomain>: PackedFieldIndexable,
422		PackedType<U, FI>: PackedFieldIndexable,
423	{
424		let max_n_vars = 6;
425		let n_multilinears = 9;
426
427		let backend = make_portable_backend();
428		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
429		let switchover_fn = standard_switchover_heuristic(-2);
430		let mut rng = StdRng::seed_from_u64(0);
431
432		let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
433		let triple =
434			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
435		let quad =
436			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());
437
438		let prover_compositions = [
439			(
440				"pair".into(),
441				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
442				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
443			),
444			(
445				"triple".into(),
446				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
447				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
448			),
449			(
450				"quad".into(),
451				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
452				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
453			),
454		];
455
456		let prover_adapter_compositions = [
457			CompositionScalarAdapter::new(pair.clone() as Arc<dyn CompositionPoly<FI>>),
458			CompositionScalarAdapter::new(triple.clone() as Arc<dyn CompositionPoly<FI>>),
459			CompositionScalarAdapter::new(quad.clone() as Arc<dyn CompositionPoly<FI>>),
460		];
461
462		let verifier_compositions = [
463			pair as Arc<dyn CompositionPoly<F>>,
464			triple as Arc<dyn CompositionPoly<F>>,
465			quad as Arc<dyn CompositionPoly<F>>,
466		];
467
468		for skip_rounds in 0..=max_n_vars {
469			let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();
470
471			let prover_zerocheck_challenges: Vec<FI> = proof.sample_vec(max_n_vars - skip_rounds);
472
473			let mut prover_zerocheck_claims = Vec::new();
474			let mut univariate_provers = Vec::new();
475			for n_vars in (1..=max_n_vars).rev() {
476				let mut multilinears = generate_zero_product_multilinears::<
477					PackedType<U, FWitness>,
478					PackedType<U, FI>,
479				>(&mut rng, n_vars, 2);
480				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
481				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));
482
483				let claim = ZerocheckClaim::<FI, _>::new(
484					n_vars,
485					n_multilinears,
486					prover_adapter_compositions.to_vec(),
487				)
488				.unwrap();
489
490				let prover =
491					UnivariateZerocheck::<FDomain, FBase, PackedType<U, FI>, _, _, _, _>::new(
492						multilinears,
493						prover_compositions.to_vec(),
494						&prover_zerocheck_challenges
495							[(max_n_vars - n_vars).saturating_sub(skip_rounds)..],
496						domain_factory.clone(),
497						switchover_fn,
498						&backend,
499					)
500					.unwrap();
501
502				prover_zerocheck_claims.push(claim);
503				univariate_provers.push(prover);
504			}
505
506			let univariate_cnt = prover_zerocheck_claims
507				.partition_point(|claim| claim.n_vars() > max_n_vars - skip_rounds);
508			let tail_provers = univariate_provers.split_off(univariate_cnt);
509
510			let tail_zerocheck_provers = tail_provers
511				.into_iter()
512				.map(|prover| {
513					let regular_zerocheck = prover.into_regular_zerocheck().unwrap();
514					Box::new(regular_zerocheck) as Box<dyn SumcheckProver<_>>
515				})
516				.collect::<Vec<_>>();
517
518			let prover_univariate_output =
519				batch_prove_zerocheck_univariate_round(univariate_provers, skip_rounds, &mut proof)
520					.unwrap();
521
522			let _ = batch_prove_with_start(
523				prover_univariate_output.batch_prove_start,
524				tail_zerocheck_provers,
525				&mut proof,
526			)
527			.unwrap();
528
529			let mut verifier_proof = proof.into_verifier();
530
531			let verifier_zerocheck_challenges: Vec<F> =
532				verifier_proof.sample_vec(max_n_vars - skip_rounds);
533			assert_eq!(
534				prover_zerocheck_challenges
535					.into_iter()
536					.map(F::from)
537					.collect::<Vec<_>>(),
538				verifier_zerocheck_challenges
539			);
540
541			let mut verifier_zerocheck_claims = Vec::new();
542			for n_vars in (1..=max_n_vars).rev() {
543				let claim = ZerocheckClaim::<F, _>::new(
544					n_vars,
545					n_multilinears,
546					verifier_compositions.to_vec(),
547				)
548				.unwrap();
549
550				verifier_zerocheck_claims.push(claim);
551			}
552			let verifier_univariate_output = batch_verify_zerocheck_univariate_round(
553				&verifier_zerocheck_claims[..univariate_cnt],
554				skip_rounds,
555				&mut verifier_proof,
556			)
557			.unwrap();
558
559			let verifier_sumcheck_claims = reduce_to_sumchecks(&verifier_zerocheck_claims).unwrap();
560			let _verifier_sumcheck_output = batch_verify_with_start(
561				EvaluationOrder::LowToHigh,
562				verifier_univariate_output.batch_verify_start,
563				&verifier_sumcheck_claims,
564				&mut verifier_proof,
565			)
566			.unwrap();
567
568			verifier_proof.finalize().unwrap()
569		}
570	}
571}