binius_core/protocols/sumcheck/univariate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::{self, repeat_n},
5	ops::{Mul, MulAssign},
6};
7
8use binius_field::{ExtensionField, Field, PackedFieldIndexable, TowerField};
9use binius_hal::{make_portable_backend, ComputationBackendExt};
10use binius_math::{BinarySubspace, EvaluationDomain, MultilinearExtension};
11use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
12use bytemuck::zeroed_vec;
13
14use crate::{
15	composition::{BivariateProduct, IndexComposition},
16	polynomial::Error as PolynomialError,
17	protocols::sumcheck::{
18		BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, VerificationError,
19	},
20};
21
22/// A univariate polynomial in Lagrange basis.
23///
24/// The coefficient at position `i` in the `lagrange_coeffs` corresponds to evaluation
25/// at `i+zeros_prefix_len`-th field element of some agreed upon domain. Coefficients
26/// at positions `0..zeros_prefix_len` are zero. Addition of Lagrange basis representations
27/// only makes sense for the polynomials in the same domain.
28#[derive(Clone, Debug)]
29pub struct LagrangeRoundEvals<F: Field> {
30	pub zeros_prefix_len: usize,
31	pub evals: Vec<F>,
32}
33
34impl<F: Field> LagrangeRoundEvals<F> {
35	/// A Lagrange representation of a zero polynomial, on a given domain.
36	pub const fn zeros(zeros_prefix_len: usize) -> Self {
37		Self {
38			zeros_prefix_len,
39			evals: Vec::new(),
40		}
41	}
42
43	/// An assigning addition of two polynomials in Lagrange basis. May fail,
44	/// thus it's not simply an `AddAssign` overload due to signature mismatch.
45	pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
46		let lhs_len = self.zeros_prefix_len + self.evals.len();
47		let rhs_len = rhs.zeros_prefix_len + rhs.evals.len();
48
49		if lhs_len != rhs_len {
50			bail!(Error::LagrangeRoundEvalsSizeMismatch);
51		}
52
53		let start_idx = if rhs.zeros_prefix_len < self.zeros_prefix_len {
54			self.evals
55				.splice(0..0, repeat_n(F::ZERO, self.zeros_prefix_len - rhs.zeros_prefix_len));
56			self.zeros_prefix_len = rhs.zeros_prefix_len;
57			0
58		} else {
59			rhs.zeros_prefix_len - self.zeros_prefix_len
60		};
61
62		for (lhs, rhs) in self.evals[start_idx..].iter_mut().zip(&rhs.evals) {
63			*lhs += rhs;
64		}
65
66		Ok(())
67	}
68}
69
70impl<F: Field> Mul<F> for LagrangeRoundEvals<F> {
71	type Output = Self;
72
73	fn mul(mut self, rhs: F) -> Self::Output {
74		self *= rhs;
75		self
76	}
77}
78
79impl<F: Field> MulAssign<F> for LagrangeRoundEvals<F> {
80	fn mul_assign(&mut self, rhs: F) {
81		for eval in &mut self.evals {
82			*eval *= rhs;
83		}
84	}
85}
86/// Creates sumcheck claims for the reduction from evaluations of univariatized virtual multilinear oracles to
87/// "regular" multilinear evaluations.
88///
89/// Univariatized virtual multilinear oracles are given by:
90/// $$\hat{M}(\hat{u}_1,x_1,\ldots,x_n) = \sum M(u_1,\ldots, u_k, x_1, \ldots, x_n) \cdot L_u(\hat{u}_1)$$
91/// It is assumed that `univariatized_multilinear_evals` came directly from a previous sumcheck with a univariate
92/// round batching `skip_rounds` variables.
93pub fn univariatizing_reduction_claim<F: Field>(
94	skip_rounds: usize,
95	univariatized_multilinear_evals: &[F],
96) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
97	let composite_sums =
98		univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
99	SumcheckClaim::new(skip_rounds, univariatized_multilinear_evals.len() + 1, composite_sums)
100}
101
102/// Verify the validity of sumcheck outputs for the reduction zerocheck.
103///
104/// This takes in the output of the batched univariatizing reduction sumcheck and returns the output
105/// that can be used to create multilinear evaluation claims. This simply strips off the evaluation of
106/// the multilinear extension of Lagrange polynomials evaluations at `univariate_challenge` (denoted by
107/// $\hat{u}_1$) and verifies that this value is correct. The argument `unskipped_sumcheck_challenges`
108/// holds the challenges of the sumcheck following the univariate round.
109pub fn verify_sumcheck_outputs<F>(
110	claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
111	univariate_challenge: F,
112	unskipped_sumcheck_challenges: &[F],
113	sumcheck_output: BatchSumcheckOutput<F>,
114) -> Result<BatchSumcheckOutput<F>, Error>
115where
116	F: TowerField,
117{
118	let BatchSumcheckOutput {
119		challenges: reduction_sumcheck_challenges,
120		mut multilinear_evals,
121	} = sumcheck_output;
122
123	assert_eq!(multilinear_evals.len(), claims.len());
124
125	// Check that the claims are in descending order by n_vars
126	if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
127		bail!(Error::ClaimsOutOfOrder);
128	}
129
130	let max_n_vars = claims
131		.first()
132		.map(|claim| claim.n_vars())
133		.unwrap_or_default();
134
135	assert_eq!(reduction_sumcheck_challenges.len(), max_n_vars);
136
137	for (claim, multilinear_evals) in iter::zip(claims, multilinear_evals.iter_mut()) {
138		let skip_rounds = claim.n_vars();
139
140		let subspace = BinarySubspace::<F::Canonical>::with_dim(skip_rounds)?.isomorphic::<F>();
141		let evaluation_domain =
142			EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
143
144		let lagrange_mle = lagrange_evals_multilinear_extension::<F, F, F>(
145			&evaluation_domain,
146			univariate_challenge,
147		)?;
148
149		let query = make_portable_backend()
150			.multilinear_query::<F>(&reduction_sumcheck_challenges[max_n_vars - skip_rounds..])?;
151		let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
152
153		let multilinear_evals_last = multilinear_evals
154			.pop()
155			.ok_or(VerificationError::NumberOfFinalEvaluations)?;
156
157		if multilinear_evals_last != expected_last_eval {
158			bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
159		}
160	}
161
162	let mut challenges = Vec::new();
163	challenges.extend(reduction_sumcheck_challenges);
164	challenges.extend(unskipped_sumcheck_challenges);
165
166	let output = BatchSumcheckOutput {
167		challenges,
168		multilinear_evals,
169	};
170
171	Ok(output)
172}
173
174// Helper method to create univariatized multilinear oracle evaluation claims.
175// Assumes that multilinear extension of Lagrange evaluations is the last multilinear,
176// uses IndexComposition to multiply each multilinear with it (using BivariateProduct).
177pub(super) fn univariatizing_reduction_composite_sum_claims<F: Field>(
178	univariatized_multilinear_evals: &[F],
179) -> Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>> {
180	let n_multilinears = univariatized_multilinear_evals.len();
181	univariatized_multilinear_evals
182		.iter()
183		.enumerate()
184		.map(|(i, &univariatized_multilinear_eval)| {
185			let composition =
186				IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
187					.expect("index composition indice correct by construction");
188
189			CompositeSumClaim {
190				composition,
191				sum: univariatized_multilinear_eval,
192			}
193		})
194		.collect()
195}
196
197// Given EvaluationDomain, evaluates Lagrange coefficients at a challenge point
198// and creates a multilinear extension of said evaluations.
199pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
200	evaluation_domain: &EvaluationDomain<FDomain>,
201	univariate_challenge: F,
202) -> Result<MultilinearExtension<P>, PolynomialError>
203where
204	FDomain: Field,
205	F: Field + ExtensionField<FDomain>,
206	P: PackedFieldIndexable<Scalar = F>,
207{
208	let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
209
210	let n_vars = log2_strict_usize(lagrange_evals.len());
211	let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
212	let scalars = P::unpack_scalars_mut(packed.as_mut_slice());
213	scalars[..lagrange_evals.len()].copy_from_slice(lagrange_evals.as_slice());
214
215	Ok(MultilinearExtension::new(n_vars, packed)?)
216}
217
#[cfg(test)]
mod tests {
	// End-to-end tests for the univariatizing reduction and for batched zerocheck with a
	// univariate skip round.
	use std::{iter, sync::Arc};

	use binius_field::{
		arch::{OptimalUnderlier128b, OptimalUnderlier512b},
		as_packed_field::{PackScalar, PackedType},
		underlier::UnderlierType,
		AESTowerField128b, AESTowerField16b, AESTowerField8b, BinaryField128b, BinaryField16b,
		Field, PackedBinaryField1x128b, PackedBinaryField4x32b, PackedFieldIndexable, TowerField,
	};
	use binius_hal::ComputationBackend;
	use binius_hash::groestl::Groestl256;
	use binius_math::{
		BinarySubspace, CompositionPoly, EvaluationOrder, IsomorphicEvaluationDomainFactory,
		MultilinearPoly,
	};
	use rand::{prelude::StdRng, SeedableRng};

	use super::*;
	use crate::{
		composition::{IndexComposition, ProductComposition},
		fiat_shamir::{CanSample, HasherChallenger},
		polynomial::CompositionScalarAdapter,
		protocols::{
			sumcheck::{
				batch_verify, batch_verify_with_start, batch_verify_zerocheck_univariate_round,
				eq_ind::reduce_to_regular_sumchecks,
				prove::{
					batch_prove, batch_prove_with_start, batch_prove_zerocheck_univariate_round,
					univariate::{reduce_to_skipped_projection, univariatizing_reduction_prover},
					SumcheckProver, UnivariateZerocheck,
				},
				standard_switchover_heuristic,
				zerocheck::reduce_to_eq_ind_sumchecks,
				ZerocheckClaim,
			},
			test_utils::generate_zero_product_multilinears,
		},
		transcript::ProverTranscript,
	};

	/// Proves and verifies a batch of univariatizing reduction sumchecks, one per `skip_rounds`
	/// value from `max_skip_rounds` down to 0. On both the prover and the verifier side, it
	/// checks that the returned evaluations match direct evaluation of the original
	/// multilinears.
	#[test]
	fn test_univariatizing_reduction_end_to_end() {
		type F = BinaryField128b;
		type FDomain = BinaryField16b;
		type P = PackedBinaryField4x32b;
		type PE = PackedBinaryField1x128b;

		let backend = make_portable_backend();
		let mut rng = StdRng::seed_from_u64(0);

		let regular_vars = 3;
		let max_skip_rounds = 3;
		let n_multilinears = 2;

		let univariate_challenge = <F as Field>::random(&mut rng);

		// Challenges binding the `regular_vars` non-skipped variables.
		let sumcheck_challenges = (0..regular_vars)
			.map(|_| <F as Field>::random(&mut rng))
			.collect::<Vec<_>>();

		let mut provers = Vec::new();
		let mut all_multilinears = Vec::new();
		let mut all_univariatized_multilinear_evals = Vec::new();
		// Iterate in descending `skip_rounds` order, so the claims built later are sorted by
		// descending `n_vars` as `verify_sumcheck_outputs` requires.
		for skip_rounds in (0..=max_skip_rounds).rev() {
			let n_vars = skip_rounds + regular_vars;

			let multilinears =
				generate_zero_product_multilinears::<P, PE>(&mut rng, n_vars, n_multilinears);
			all_multilinears.push((skip_rounds, multilinears.clone()));

			let subspace = BinarySubspace::with_dim(skip_rounds).unwrap();
			let ntt_domain =
				EvaluationDomain::from_points(subspace.iter().collect::<Vec<FDomain>>(), false)
					.unwrap();

			// Compute the univariatized evaluations directly. Partially evaluate each
			// multilinear at `sumcheck_challenges` (via `evaluate_partial_high`), then
			// extrapolate the remaining values over the skip domain to `univariate_challenge`.
			let query = backend.multilinear_query(&sumcheck_challenges).unwrap();
			let univariatized_multilinear_evals = multilinears
				.iter()
				.map(|multilinear| {
					let partial_eval = backend
						.evaluate_partial_high(multilinear, query.to_ref())
						.unwrap();
					ntt_domain
						.extrapolate(PE::unpack_scalars(partial_eval.evals()), univariate_challenge)
						.unwrap()
				})
				.collect::<Vec<_>>();

			all_univariatized_multilinear_evals.push(univariatized_multilinear_evals.clone());

			let reduced_multilinears =
				reduce_to_skipped_projection(multilinears, &sumcheck_challenges, &backend).unwrap();

			let prover = univariatizing_reduction_prover::<_, FDomain, _, _>(
				reduced_multilinears,
				&univariatized_multilinear_evals,
				univariate_challenge,
				&backend,
			)
			.unwrap();

			provers.push(prover);
		}

		let mut prove_challenger = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let batch_sumcheck_output_prove = batch_prove(provers, &mut prove_challenger).unwrap();

		// Prover-side check. Each evaluation vector holds one entry per multilinear plus a
		// trailing extra entry (the Lagrange multilinear). `zip` stops before that extra entry.
		for ((skip_rounds, multilinears), multilinear_evals) in
			iter::zip(&all_multilinears, batch_sumcheck_output_prove.multilinear_evals)
		{
			assert_eq!(multilinears.len() + 1, multilinear_evals.len());

			// Skipped variables are bound by the last `skip_rounds` reduction challenges and
			// are followed by the regular variables bound by `sumcheck_challenges`.
			let mut query =
				batch_sumcheck_output_prove.challenges[max_skip_rounds - skip_rounds..].to_vec();
			query.extend(sumcheck_challenges.as_slice());

			let query = backend.multilinear_query(&query).unwrap();

			for (multilinear, eval) in iter::zip(multilinears, multilinear_evals) {
				assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
			}
		}

		let claims = iter::zip(&all_multilinears, &all_univariatized_multilinear_evals)
			.map(|((skip_rounds, _q), univariatized_multilinear_evals)| {
				univariatizing_reduction_claim(*skip_rounds, univariatized_multilinear_evals)
					.unwrap()
			})
			.collect::<Vec<_>>();

		let mut verify_challenger = prove_challenger.into_verifier();
		let batch_sumcheck_output_verify =
			batch_verify(EvaluationOrder::LowToHigh, claims.as_slice(), &mut verify_challenger)
				.unwrap();
		let batch_sumcheck_output_post = verify_sumcheck_outputs(
			claims.as_slice(),
			univariate_challenge,
			&sumcheck_challenges,
			batch_sumcheck_output_verify,
		)
		.unwrap();

		// Verifier-side check. `verify_sumcheck_outputs` has popped the trailing Lagrange
		// evaluation. Only the reduction challenges (the first `max_skip_rounds`) are taken
		// from its output; the regular-variable challenges are appended from
		// `sumcheck_challenges`.
		for ((skip_rounds, multilinears), evals) in
			iter::zip(all_multilinears, batch_sumcheck_output_post.multilinear_evals)
		{
			let mut query = batch_sumcheck_output_post.challenges
				[max_skip_rounds - skip_rounds..max_skip_rounds]
				.to_vec();
			query.extend(sumcheck_challenges.as_slice());

			let query = backend.multilinear_query(&query).unwrap();

			for (multilinear, eval) in iter::zip(multilinears, evals) {
				assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
			}
		}
	}

	#[test]
	fn test_univariatized_zerocheck_end_to_end_basic() {
		test_univariatized_zerocheck_end_to_end_helper::<
			OptimalUnderlier128b,
			BinaryField128b,
			AESTowerField128b,
			AESTowerField16b,
			AESTowerField16b,
			AESTowerField8b,
		>()
	}

	#[test]
	fn test_univariatized_zerocheck_end_to_end_with_nontrivial_packing() {
		// Using a 512-bit underlier with a 128-bit extension field means the packed field will have a
		// non-trivial packing width of 4.
		test_univariatized_zerocheck_end_to_end_helper::<
			OptimalUnderlier512b,
			BinaryField128b,
			AESTowerField128b,
			AESTowerField16b,
			AESTowerField16b,
			AESTowerField8b,
		>()
	}

	/// Shared driver for the univariate-skip zerocheck tests.
	///
	/// For every `skip_rounds` in `0..=max_n_vars`, it proves a batch of zerocheck claims
	/// (with `n_vars` from `max_n_vars` down to 1) over field `FI`. It then replays the
	/// transcript with a verifier over the verifier field `F` and requires that every step,
	/// including `finalize`, succeeds.
	fn test_univariatized_zerocheck_end_to_end_helper<U, F, FI, FDomain, FBase, FWitness>()
	where
		U: UnderlierType
			+ PackScalar<FI>
			+ PackScalar<FBase>
			+ PackScalar<FDomain>
			+ PackScalar<FWitness>,
		F: TowerField + From<FI>,
		FI: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
		FBase: TowerField + ExtensionField<FDomain>,
		FDomain: TowerField,
		FWitness: Field,
		PackedType<U, FBase>: PackedFieldIndexable,
		PackedType<U, FDomain>: PackedFieldIndexable,
		PackedType<U, FI>: PackedFieldIndexable,
	{
		let max_n_vars = 6;
		let n_multilinears = 9;

		let backend = make_portable_backend();
		let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
		let switchover_fn = standard_switchover_heuristic(-2);
		let mut rng = StdRng::seed_from_u64(0);

		// Multilinears 0..2, 2..5 and 5..9 feed the pair, triple and quad product
		// compositions respectively.
		// NOTE(review): the arity `9` below duplicates `n_multilinears`.
		let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
		let triple =
			Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
		let quad =
			Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());

		let prover_compositions = [
			(
				"pair".into(),
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
			),
			(
				"triple".into(),
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
			),
			(
				"quad".into(),
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
				quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
			),
		];

		let prover_adapter_compositions = [
			CompositionScalarAdapter::new(pair.clone() as Arc<dyn CompositionPoly<FI>>),
			CompositionScalarAdapter::new(triple.clone() as Arc<dyn CompositionPoly<FI>>),
			CompositionScalarAdapter::new(quad.clone() as Arc<dyn CompositionPoly<FI>>),
		];

		let verifier_compositions = [
			pair as Arc<dyn CompositionPoly<F>>,
			triple as Arc<dyn CompositionPoly<F>>,
			quad as Arc<dyn CompositionPoly<F>>,
		];

		for skip_rounds in 0..=max_n_vars {
			let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();

			let prover_zerocheck_challenges: Vec<FI> = proof.sample_vec(max_n_vars - skip_rounds);

			let mut prover_zerocheck_claims = Vec::new();
			let mut univariate_provers = Vec::new();
			// Claims are created in descending `n_vars` order.
			for n_vars in (1..=max_n_vars).rev() {
				let mut multilinears = generate_zero_product_multilinears::<
					PackedType<U, FWitness>,
					PackedType<U, FI>,
				>(&mut rng, n_vars, 2);
				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
				multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

				let claim = ZerocheckClaim::<FI, _>::new(
					n_vars,
					n_multilinears,
					prover_adapter_compositions.to_vec(),
				)
				.unwrap();

				// Each prover receives a suffix of the shared zerocheck challenges whose start
				// depends on its `n_vars` and on `skip_rounds`.
				let prover = UnivariateZerocheck::<
					FDomain,
					FBase,
					PackedType<U, FI>,
					_,
					_,
					_,
					_,
					_,
					_,
				>::new(
					multilinears,
					prover_compositions.to_vec(),
					&prover_zerocheck_challenges
						[(max_n_vars - n_vars).saturating_sub(skip_rounds)..],
					domain_factory.clone(),
					switchover_fn,
					&backend,
				)
				.unwrap();

				prover_zerocheck_claims.push(claim);
				univariate_provers.push(prover);
			}

			// Claims with more than `max_n_vars - skip_rounds` variables (a prefix, since the
			// claims are in descending order) go through the univariate round. The remaining
			// claims are proven as regular zerochecks.
			let univariate_cnt = prover_zerocheck_claims
				.partition_point(|claim| claim.n_vars() > max_n_vars - skip_rounds);
			let tail_provers = univariate_provers.split_off(univariate_cnt);

			let tail_zerocheck_provers = tail_provers
				.into_iter()
				.map(|prover| {
					let regular_zerocheck = prover.into_regular_zerocheck().unwrap();
					Box::new(regular_zerocheck) as Box<dyn SumcheckProver<_>>
				})
				.collect::<Vec<_>>();

			let prover_univariate_output =
				batch_prove_zerocheck_univariate_round(univariate_provers, skip_rounds, &mut proof)
					.unwrap();

			let _ = batch_prove_with_start(
				prover_univariate_output.batch_prove_start,
				tail_zerocheck_provers,
				&mut proof,
			)
			.unwrap();

			let mut verifier_proof = proof.into_verifier();

			// Sampling the same number of challenges from the verifier transcript must reproduce
			// the prover's challenges once they are mapped from `FI` into `F`.
			let verifier_zerocheck_challenges: Vec<F> =
				verifier_proof.sample_vec(max_n_vars - skip_rounds);
			assert_eq!(
				prover_zerocheck_challenges
					.into_iter()
					.map(F::from)
					.collect::<Vec<_>>(),
				verifier_zerocheck_challenges
			);

			let mut verifier_zerocheck_claims = Vec::new();
			for n_vars in (1..=max_n_vars).rev() {
				let claim = ZerocheckClaim::<F, _>::new(
					n_vars,
					n_multilinears,
					verifier_compositions.to_vec(),
				)
				.unwrap();

				verifier_zerocheck_claims.push(claim);
			}
			// Only the claims that went through the univariate round on the prover side are
			// passed to the univariate round verifier.
			let verifier_univariate_output = batch_verify_zerocheck_univariate_round(
				&verifier_zerocheck_claims[..univariate_cnt],
				skip_rounds,
				&mut verifier_proof,
			)
			.unwrap();

			let verifier_eq_ind_sumcheck_claims =
				reduce_to_eq_ind_sumchecks(&verifier_zerocheck_claims).unwrap();
			let verifier_regular_sumcheck_claims =
				reduce_to_regular_sumchecks(&verifier_eq_ind_sumcheck_claims).unwrap();
			let _verifier_sumcheck_output = batch_verify_with_start(
				EvaluationOrder::LowToHigh,
				verifier_univariate_output.batch_verify_start,
				&verifier_regular_sumcheck_claims,
				&mut verifier_proof,
			)
			.unwrap();

			verifier_proof.finalize().unwrap()
		}
	}
}