1use std::{
4 iter::{self, repeat_n},
5 ops::{Mul, MulAssign},
6};
7
8use binius_field::{ExtensionField, Field, PackedFieldIndexable, TowerField};
9use binius_hal::{make_portable_backend, ComputationBackendExt};
10use binius_math::{BinarySubspace, EvaluationDomain, MultilinearExtension};
11use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
12use bytemuck::zeroed_vec;
13
14use crate::{
15 composition::{BivariateProduct, IndexComposition},
16 polynomial::Error as PolynomialError,
17 protocols::sumcheck::{
18 BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, VerificationError,
19 },
20};
21
/// A vector of field elements stored as an implicit run of zeros followed by explicit values.
///
/// The logical length is `zeros_prefix_len + evals.len()`. The first `zeros_prefix_len`
/// entries are zero and are not materialized in memory. Judging by the name, these are
/// presumably evaluations of a sumcheck round polynomial over a Lagrange evaluation
/// domain — TODO confirm.
#[derive(Clone, Debug)]
pub struct LagrangeRoundEvals<F: Field> {
    /// Number of leading zero entries that are not stored in `evals`.
    pub zeros_prefix_len: usize,
    /// The explicitly stored entries that follow the zero prefix.
    pub evals: Vec<F>,
}
33
34impl<F: Field> LagrangeRoundEvals<F> {
35 pub const fn zeros(zeros_prefix_len: usize) -> Self {
37 Self {
38 zeros_prefix_len,
39 evals: Vec::new(),
40 }
41 }
42
43 pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
46 let lhs_len = self.zeros_prefix_len + self.evals.len();
47 let rhs_len = rhs.zeros_prefix_len + rhs.evals.len();
48
49 if lhs_len != rhs_len {
50 bail!(Error::LagrangeRoundEvalsSizeMismatch);
51 }
52
53 let start_idx = if rhs.zeros_prefix_len < self.zeros_prefix_len {
54 self.evals
55 .splice(0..0, repeat_n(F::ZERO, self.zeros_prefix_len - rhs.zeros_prefix_len));
56 self.zeros_prefix_len = rhs.zeros_prefix_len;
57 0
58 } else {
59 rhs.zeros_prefix_len - self.zeros_prefix_len
60 };
61
62 for (lhs, rhs) in self.evals[start_idx..].iter_mut().zip(&rhs.evals) {
63 *lhs += rhs;
64 }
65
66 Ok(())
67 }
68}
69
70impl<F: Field> Mul<F> for LagrangeRoundEvals<F> {
71 type Output = Self;
72
73 fn mul(mut self, rhs: F) -> Self::Output {
74 self *= rhs;
75 self
76 }
77}
78
79impl<F: Field> MulAssign<F> for LagrangeRoundEvals<F> {
80 fn mul_assign(&mut self, rhs: F) {
81 for eval in &mut self.evals {
82 *eval *= rhs;
83 }
84 }
85}
86pub fn univariatizing_reduction_claim<F: Field>(
94 skip_rounds: usize,
95 univariatized_multilinear_evals: &[F],
96) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
97 let composite_sums =
98 univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
99 SumcheckClaim::new(skip_rounds, univariatized_multilinear_evals.len() + 1, composite_sums)
100}
101
102pub fn verify_sumcheck_outputs<F>(
110 claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
111 univariate_challenge: F,
112 unskipped_sumcheck_challenges: &[F],
113 sumcheck_output: BatchSumcheckOutput<F>,
114) -> Result<BatchSumcheckOutput<F>, Error>
115where
116 F: TowerField,
117{
118 let BatchSumcheckOutput {
119 challenges: reduction_sumcheck_challenges,
120 mut multilinear_evals,
121 } = sumcheck_output;
122
123 assert_eq!(multilinear_evals.len(), claims.len());
124
125 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
127 bail!(Error::ClaimsOutOfOrder);
128 }
129
130 let max_n_vars = claims
131 .first()
132 .map(|claim| claim.n_vars())
133 .unwrap_or_default();
134
135 assert_eq!(reduction_sumcheck_challenges.len(), max_n_vars);
136
137 for (claim, multilinear_evals) in iter::zip(claims, multilinear_evals.iter_mut()) {
138 let skip_rounds = claim.n_vars();
139
140 let subspace = BinarySubspace::<F::Canonical>::with_dim(skip_rounds)?.isomorphic::<F>();
141 let evaluation_domain =
142 EvaluationDomain::from_points(subspace.iter().collect::<Vec<_>>(), false)?;
143
144 let lagrange_mle = lagrange_evals_multilinear_extension::<F, F, F>(
145 &evaluation_domain,
146 univariate_challenge,
147 )?;
148
149 let query = make_portable_backend()
150 .multilinear_query::<F>(&reduction_sumcheck_challenges[max_n_vars - skip_rounds..])?;
151 let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
152
153 let multilinear_evals_last = multilinear_evals
154 .pop()
155 .ok_or(VerificationError::NumberOfFinalEvaluations)?;
156
157 if multilinear_evals_last != expected_last_eval {
158 bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
159 }
160 }
161
162 let mut challenges = Vec::new();
163 challenges.extend(reduction_sumcheck_challenges);
164 challenges.extend(unskipped_sumcheck_challenges);
165
166 let output = BatchSumcheckOutput {
167 challenges,
168 multilinear_evals,
169 };
170
171 Ok(output)
172}
173
174pub(super) fn univariatizing_reduction_composite_sum_claims<F: Field>(
178 univariatized_multilinear_evals: &[F],
179) -> Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>> {
180 let n_multilinears = univariatized_multilinear_evals.len();
181 univariatized_multilinear_evals
182 .iter()
183 .enumerate()
184 .map(|(i, &univariatized_multilinear_eval)| {
185 let composition =
186 IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
187 .expect("index composition indice correct by construction");
188
189 CompositeSumClaim {
190 composition,
191 sum: univariatized_multilinear_eval,
192 }
193 })
194 .collect()
195}
196
197pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
200 evaluation_domain: &EvaluationDomain<FDomain>,
201 univariate_challenge: F,
202) -> Result<MultilinearExtension<P>, PolynomialError>
203where
204 FDomain: Field,
205 F: Field + ExtensionField<FDomain>,
206 P: PackedFieldIndexable<Scalar = F>,
207{
208 let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
209
210 let n_vars = log2_strict_usize(lagrange_evals.len());
211 let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
212 let scalars = P::unpack_scalars_mut(packed.as_mut_slice());
213 scalars[..lagrange_evals.len()].copy_from_slice(lagrange_evals.as_slice());
214
215 Ok(MultilinearExtension::new(n_vars, packed)?)
216}
217
#[cfg(test)]
mod tests {
    use std::{iter, sync::Arc};

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        underlier::UnderlierType,
        AESTowerField128b, AESTowerField16b, AESTowerField8b, BinaryField128b, BinaryField16b,
        Field, PackedBinaryField1x128b, PackedBinaryField4x32b, PackedFieldIndexable, TowerField,
    };
    use binius_hal::ComputationBackend;
    use binius_hash::groestl::Groestl256;
    use binius_math::{
        BinarySubspace, CompositionPoly, EvaluationOrder, IsomorphicEvaluationDomainFactory,
        MultilinearPoly,
    };
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::{
        composition::{IndexComposition, ProductComposition},
        fiat_shamir::{CanSample, HasherChallenger},
        polynomial::CompositionScalarAdapter,
        protocols::{
            sumcheck::{
                batch_verify, batch_verify_with_start, batch_verify_zerocheck_univariate_round,
                eq_ind::reduce_to_regular_sumchecks,
                prove::{
                    batch_prove, batch_prove_with_start, batch_prove_zerocheck_univariate_round,
                    univariate::{reduce_to_skipped_projection, univariatizing_reduction_prover},
                    SumcheckProver, UnivariateZerocheck,
                },
                standard_switchover_heuristic,
                zerocheck::reduce_to_eq_ind_sumchecks,
                ZerocheckClaim,
            },
            test_utils::generate_zero_product_multilinears,
        },
        transcript::ProverTranscript,
    };

    /// End-to-end test of the univariatizing reduction.
    ///
    /// One prover is built for every `skip_rounds` value in `0..=max_skip_rounds`, and all of
    /// them are batch-proven together. The test then checks:
    /// - the prover's final multilinear evaluations, against direct evaluation;
    /// - batch verification plus `verify_sumcheck_outputs`, followed by the same evaluation
    ///   check on the post-processed output.
    #[test]
    fn test_univariatizing_reduction_end_to_end() {
        type F = BinaryField128b;
        type FDomain = BinaryField16b;
        type P = PackedBinaryField4x32b;
        type PE = PackedBinaryField1x128b;

        let backend = make_portable_backend();
        let mut rng = StdRng::seed_from_u64(0);

        let regular_vars = 3;
        let max_skip_rounds = 3;
        let n_multilinears = 2;

        let univariate_challenge = <F as Field>::random(&mut rng);

        // Challenges for the `regular_vars` non-skipped variables, shared by every instance.
        let sumcheck_challenges = (0..regular_vars)
            .map(|_| <F as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        let mut provers = Vec::new();
        let mut all_multilinears = Vec::new();
        let mut all_univariatized_multilinear_evals = Vec::new();
        // Iterate in descending `skip_rounds`, so claims end up ordered by non-increasing
        // `n_vars`. `verify_sumcheck_outputs` rejects any other order.
        for skip_rounds in (0..=max_skip_rounds).rev() {
            let n_vars = skip_rounds + regular_vars;

            let multilinears =
                generate_zero_product_multilinears::<P, PE>(&mut rng, n_vars, n_multilinears);
            all_multilinears.push((skip_rounds, multilinears.clone()));

            // Evaluation domain over the binary subspace of dimension `skip_rounds`.
            let subspace = BinarySubspace::with_dim(skip_rounds).unwrap();
            let ntt_domain =
                EvaluationDomain::from_points(subspace.iter().collect::<Vec<FDomain>>(), false)
                    .unwrap();

            // Compute the expected univariatized evaluations as follows:
            // 1. Partially evaluate each multilinear at `sumcheck_challenges` using
            //    `evaluate_partial_high`.
            // 2. Extrapolate the remaining values over the subspace domain to
            //    `univariate_challenge`.
            let query = backend.multilinear_query(&sumcheck_challenges).unwrap();
            let univariatized_multilinear_evals = multilinears
                .iter()
                .map(|multilinear| {
                    let partial_eval = backend
                        .evaluate_partial_high(multilinear, query.to_ref())
                        .unwrap();
                    ntt_domain
                        .extrapolate(PE::unpack_scalars(partial_eval.evals()), univariate_challenge)
                        .unwrap()
                })
                .collect::<Vec<_>>();

            all_univariatized_multilinear_evals.push(univariatized_multilinear_evals.clone());

            let reduced_multilinears =
                reduce_to_skipped_projection(multilinears, &sumcheck_challenges, &backend).unwrap();

            let prover = univariatizing_reduction_prover::<_, FDomain, _, _>(
                reduced_multilinears,
                &univariatized_multilinear_evals,
                univariate_challenge,
                &backend,
            )
            .unwrap();

            provers.push(prover);
        }

        let mut prove_challenger = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let batch_sumcheck_output_prove = batch_prove(provers, &mut prove_challenger).unwrap();

        // Prover-side check. Each claim has one extra multilinear (the Lagrange one) beyond
        // the originals. The original multilinears must evaluate to the reported values at
        // this claim's trailing `skip_rounds` reduction challenges, followed by
        // `sumcheck_challenges`.
        for ((skip_rounds, multilinears), multilinear_evals) in
            iter::zip(&all_multilinears, batch_sumcheck_output_prove.multilinear_evals)
        {
            assert_eq!(multilinears.len() + 1, multilinear_evals.len());

            let mut query =
                batch_sumcheck_output_prove.challenges[max_skip_rounds - skip_rounds..].to_vec();
            query.extend(sumcheck_challenges.as_slice());

            let query = backend.multilinear_query(&query).unwrap();

            // `zip` stops at the original multilinears, skipping the trailing Lagrange eval.
            for (multilinear, eval) in iter::zip(multilinears, multilinear_evals) {
                assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
            }
        }

        // Verifier side: rebuild the claims from the univariatized evaluations alone.
        let claims = iter::zip(&all_multilinears, &all_univariatized_multilinear_evals)
            .map(|((skip_rounds, _q), univariatized_multilinear_evals)| {
                univariatizing_reduction_claim(*skip_rounds, univariatized_multilinear_evals)
                    .unwrap()
            })
            .collect::<Vec<_>>();

        let mut verify_challenger = prove_challenger.into_verifier();
        let batch_sumcheck_output_verify =
            batch_verify(EvaluationOrder::LowToHigh, claims.as_slice(), &mut verify_challenger)
                .unwrap();
        // Checks the Lagrange evaluations, strips them, and appends `sumcheck_challenges`.
        let batch_sumcheck_output_post = verify_sumcheck_outputs(
            claims.as_slice(),
            univariate_challenge,
            &sumcheck_challenges,
            batch_sumcheck_output_verify,
        )
        .unwrap();

        // After post-processing, the first `max_skip_rounds` challenges are the reduction
        // challenges. This claim uses their trailing `skip_rounds`, followed by the regular
        // challenges.
        for ((skip_rounds, multilinears), evals) in
            iter::zip(all_multilinears, batch_sumcheck_output_post.multilinear_evals)
        {
            let mut query = batch_sumcheck_output_post.challenges
                [max_skip_rounds - skip_rounds..max_skip_rounds]
                .to_vec();
            query.extend(sumcheck_challenges.as_slice());

            let query = backend.multilinear_query(&query).unwrap();

            for (multilinear, eval) in iter::zip(multilinears, evals) {
                assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
            }
        }
    }

    /// Univariate-skip zerocheck with 128-bit underliers.
    #[test]
    fn test_univariatized_zerocheck_end_to_end_basic() {
        test_univariatized_zerocheck_end_to_end_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            AESTowerField128b,
            AESTowerField16b,
            AESTowerField16b,
            AESTowerField8b,
        >()
    }

    /// Same as the basic test, but with 512-bit underliers, so packed types hold more scalars.
    #[test]
    fn test_univariatized_zerocheck_end_to_end_with_nontrivial_packing() {
        test_univariatized_zerocheck_end_to_end_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            AESTowerField128b,
            AESTowerField16b,
            AESTowerField16b,
            AESTowerField8b,
        >()
    }

    /// Runs a batched zerocheck with a univariate first round for every `skip_rounds` in
    /// `0..=max_n_vars`.
    ///
    /// Claims range over `n_vars` from `max_n_vars` down to 1. Claims with more than
    /// `max_n_vars - skip_rounds` variables go through the univariate round. The rest are
    /// converted to regular zerocheck provers and joined in the subsequent batch.
    ///
    /// The verifier mirrors this in the field `F`, which prover-side `FI` values are mapped
    /// into via `From`. At the end the transcript must be fully consumed.
    fn test_univariatized_zerocheck_end_to_end_helper<U, F, FI, FDomain, FBase, FWitness>()
    where
        U: UnderlierType
            + PackScalar<FI>
            + PackScalar<FBase>
            + PackScalar<FDomain>
            + PackScalar<FWitness>,
        F: TowerField + From<FI>,
        FI: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
        FBase: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField,
        FWitness: Field,
        PackedType<U, FBase>: PackedFieldIndexable,
        PackedType<U, FDomain>: PackedFieldIndexable,
        PackedType<U, FI>: PackedFieldIndexable,
    {
        let max_n_vars = 6;
        let n_multilinears = 9;

        let backend = make_portable_backend();
        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let switchover_fn = standard_switchover_heuristic(-2);
        let mut rng = StdRng::seed_from_u64(0);

        // Three product constraints over disjoint index sets of the 9 multilinears:
        // the pair covers 0..2, the triple covers 2..5, and the quad covers 5..9.
        let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
        let triple =
            Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
        let quad =
            Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());

        // The prover takes each composition over both the base-field and the
        // extension-field packed types.
        let prover_compositions = [
            (
                "pair".into(),
                pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
            ),
            (
                "triple".into(),
                triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
            ),
            (
                "quad".into(),
                quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
            ),
        ];

        // Scalar adapters used to build the prover-side `ZerocheckClaim`s.
        let prover_adapter_compositions = [
            CompositionScalarAdapter::new(pair.clone() as Arc<dyn CompositionPoly<FI>>),
            CompositionScalarAdapter::new(triple.clone() as Arc<dyn CompositionPoly<FI>>),
            CompositionScalarAdapter::new(quad.clone() as Arc<dyn CompositionPoly<FI>>),
        ];

        let verifier_compositions = [
            pair as Arc<dyn CompositionPoly<F>>,
            triple as Arc<dyn CompositionPoly<F>>,
            quad as Arc<dyn CompositionPoly<F>>,
        ];

        for skip_rounds in 0..=max_n_vars {
            let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();

            // Zerocheck challenges for the non-skipped variables of the largest claim.
            let prover_zerocheck_challenges: Vec<FI> = proof.sample_vec(max_n_vars - skip_rounds);

            let mut prover_zerocheck_claims = Vec::new();
            let mut univariate_provers = Vec::new();
            // Claims are generated in descending `n_vars` order.
            for n_vars in (1..=max_n_vars).rev() {
                // Group sizes 2, 3 and 4 match the pair/triple/quad compositions. The helper
                // presumably generates multilinears whose group product is zero — TODO confirm.
                let mut multilinears = generate_zero_product_multilinears::<
                    PackedType<U, FWitness>,
                    PackedType<U, FI>,
                >(&mut rng, n_vars, 2);
                multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
                multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

                let claim = ZerocheckClaim::<FI, _>::new(
                    n_vars,
                    n_multilinears,
                    prover_adapter_compositions.to_vec(),
                )
                .unwrap();

                // Smaller claims take a suffix of the shared zerocheck challenges.
                let prover = UnivariateZerocheck::<
                    FDomain,
                    FBase,
                    PackedType<U, FI>,
                    _,
                    _,
                    _,
                    _,
                    _,
                    _,
                >::new(
                    multilinears,
                    prover_compositions.to_vec(),
                    &prover_zerocheck_challenges
                        [(max_n_vars - n_vars).saturating_sub(skip_rounds)..],
                    domain_factory.clone(),
                    switchover_fn,
                    &backend,
                )
                .unwrap();

                prover_zerocheck_claims.push(claim);
                univariate_provers.push(prover);
            }

            // Claims are in descending `n_vars` order. Those with more than
            // `max_n_vars - skip_rounds` variables form the prefix handled by the univariate round.
            let univariate_cnt = prover_zerocheck_claims
                .partition_point(|claim| claim.n_vars() > max_n_vars - skip_rounds);
            let tail_provers = univariate_provers.split_off(univariate_cnt);

            // The remaining (smaller) claims run as regular zerochecks in the follow-up batch.
            let tail_zerocheck_provers = tail_provers
                .into_iter()
                .map(|prover| {
                    let regular_zerocheck = prover.into_regular_zerocheck().unwrap();
                    Box::new(regular_zerocheck) as Box<dyn SumcheckProver<_>>
                })
                .collect::<Vec<_>>();

            let prover_univariate_output =
                batch_prove_zerocheck_univariate_round(univariate_provers, skip_rounds, &mut proof)
                    .unwrap();

            let _ = batch_prove_with_start(
                prover_univariate_output.batch_prove_start,
                tail_zerocheck_provers,
                &mut proof,
            )
            .unwrap();

            let mut verifier_proof = proof.into_verifier();

            // The verifier must sample the same zerocheck challenges as the prover.
            let verifier_zerocheck_challenges: Vec<F> =
                verifier_proof.sample_vec(max_n_vars - skip_rounds);
            assert_eq!(
                prover_zerocheck_challenges
                    .into_iter()
                    .map(F::from)
                    .collect::<Vec<_>>(),
                verifier_zerocheck_challenges
            );

            let mut verifier_zerocheck_claims = Vec::new();
            for n_vars in (1..=max_n_vars).rev() {
                let claim = ZerocheckClaim::<F, _>::new(
                    n_vars,
                    n_multilinears,
                    verifier_compositions.to_vec(),
                )
                .unwrap();

                verifier_zerocheck_claims.push(claim);
            }
            let verifier_univariate_output = batch_verify_zerocheck_univariate_round(
                &verifier_zerocheck_claims[..univariate_cnt],
                skip_rounds,
                &mut verifier_proof,
            )
            .unwrap();

            // Reduce all zerocheck claims (zerocheck -> eq-ind sumcheck -> regular sumcheck)
            // and verify them as a single batch that starts from the univariate round's state.
            let verifier_eq_ind_sumcheck_claims =
                reduce_to_eq_ind_sumchecks(&verifier_zerocheck_claims).unwrap();
            let verifier_regular_sumcheck_claims =
                reduce_to_regular_sumchecks(&verifier_eq_ind_sumcheck_claims).unwrap();
            let _verifier_sumcheck_output = batch_verify_with_start(
                EvaluationOrder::LowToHigh,
                verifier_univariate_output.batch_verify_start,
                &verifier_regular_sumcheck_claims,
                &mut verifier_proof,
            )
            .unwrap();

            verifier_proof.finalize().unwrap()
        }
    }
}