1use std::{
4 iter::{self, repeat_n},
5 ops::{Mul, MulAssign},
6};
7
8use binius_field::{ExtensionField, Field, PackedFieldIndexable, TowerField};
9use binius_hal::{make_portable_backend, ComputationBackendExt};
10use binius_math::{
11 EvaluationDomain, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory,
12 MultilinearExtension,
13};
14use binius_utils::{bail, checked_arithmetics::log2_strict_usize, sorting::is_sorted_ascending};
15use bytemuck::zeroed_vec;
16
17use crate::{
18 composition::{BivariateProduct, IndexComposition},
19 polynomial::Error as PolynomialError,
20 protocols::sumcheck::{
21 BatchSumcheckOutput, CompositeSumClaim, Error, SumcheckClaim, VerificationError,
22 },
23};
24
25#[derive(Clone, Debug)]
32pub struct LagrangeRoundEvals<F: Field> {
33 pub zeros_prefix_len: usize,
34 pub evals: Vec<F>,
35}
36
37impl<F: Field> LagrangeRoundEvals<F> {
38 pub const fn zeros(zeros_prefix_len: usize) -> Self {
40 Self {
41 zeros_prefix_len,
42 evals: Vec::new(),
43 }
44 }
45
46 pub fn add_assign_lagrange(&mut self, rhs: &Self) -> Result<(), Error> {
49 let lhs_len = self.zeros_prefix_len + self.evals.len();
50 let rhs_len = rhs.zeros_prefix_len + rhs.evals.len();
51
52 if lhs_len != rhs_len {
53 bail!(Error::LagrangeRoundEvalsSizeMismatch);
54 }
55
56 let start_idx = if rhs.zeros_prefix_len < self.zeros_prefix_len {
57 self.evals
58 .splice(0..0, repeat_n(F::ZERO, self.zeros_prefix_len - rhs.zeros_prefix_len));
59 self.zeros_prefix_len = rhs.zeros_prefix_len;
60 0
61 } else {
62 rhs.zeros_prefix_len - self.zeros_prefix_len
63 };
64
65 for (lhs, rhs) in self.evals[start_idx..].iter_mut().zip(&rhs.evals) {
66 *lhs += rhs;
67 }
68
69 Ok(())
70 }
71}
72
73impl<F: Field> Mul<F> for LagrangeRoundEvals<F> {
74 type Output = Self;
75
76 fn mul(mut self, rhs: F) -> Self::Output {
77 self *= rhs;
78 self
79 }
80}
81
82impl<F: Field> MulAssign<F> for LagrangeRoundEvals<F> {
83 fn mul_assign(&mut self, rhs: F) {
84 for eval in &mut self.evals {
85 *eval *= rhs;
86 }
87 }
88}
89pub fn univariatizing_reduction_claim<F: Field>(
97 skip_rounds: usize,
98 univariatized_multilinear_evals: &[F],
99) -> Result<SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>, Error> {
100 let composite_sums =
101 univariatizing_reduction_composite_sum_claims(univariatized_multilinear_evals);
102 SumcheckClaim::new(skip_rounds, univariatized_multilinear_evals.len() + 1, composite_sums)
103}
104
105pub fn verify_sumcheck_outputs<F>(
113 claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
114 univariate_challenge: F,
115 unskipped_sumcheck_challenges: &[F],
116 sumcheck_output: BatchSumcheckOutput<F>,
117) -> Result<BatchSumcheckOutput<F>, Error>
118where
119 F: TowerField,
120{
121 let BatchSumcheckOutput {
122 challenges: reduction_sumcheck_challenges,
123 mut multilinear_evals,
124 } = sumcheck_output;
125
126 assert_eq!(multilinear_evals.len(), claims.len());
127
128 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars()).rev()) {
130 bail!(Error::ClaimsOutOfOrder);
131 }
132
133 let max_n_vars = claims
134 .first()
135 .map(|claim| claim.n_vars())
136 .unwrap_or_default();
137
138 assert_eq!(reduction_sumcheck_challenges.len(), max_n_vars);
139
140 for (claim, multilinear_evals) in iter::zip(claims, multilinear_evals.iter_mut()) {
141 let skip_rounds = claim.n_vars();
142
143 let evaluation_domain = IsomorphicEvaluationDomainFactory::<F::Canonical>::default()
144 .create(1 << skip_rounds)?;
145
146 let lagrange_mle = lagrange_evals_multilinear_extension::<F, F, F>(
147 &evaluation_domain,
148 univariate_challenge,
149 )?;
150
151 let query = make_portable_backend()
152 .multilinear_query::<F>(&reduction_sumcheck_challenges[max_n_vars - skip_rounds..])?;
153 let expected_last_eval = lagrange_mle.evaluate(query.to_ref())?;
154
155 let multilinear_evals_last = multilinear_evals
156 .pop()
157 .ok_or(VerificationError::NumberOfFinalEvaluations)?;
158
159 if multilinear_evals_last != expected_last_eval {
160 bail!(VerificationError::IncorrectLagrangeMultilinearEvaluation);
161 }
162 }
163
164 let mut challenges = Vec::new();
165 challenges.extend(reduction_sumcheck_challenges);
166 challenges.extend(unskipped_sumcheck_challenges);
167
168 let output = BatchSumcheckOutput {
169 challenges,
170 multilinear_evals,
171 };
172
173 Ok(output)
174}
175
176pub(super) fn univariatizing_reduction_composite_sum_claims<F: Field>(
180 univariatized_multilinear_evals: &[F],
181) -> Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>> {
182 let n_multilinears = univariatized_multilinear_evals.len();
183 univariatized_multilinear_evals
184 .iter()
185 .enumerate()
186 .map(|(i, &univariatized_multilinear_eval)| {
187 let composition =
188 IndexComposition::new(n_multilinears + 1, [i, n_multilinears], BivariateProduct {})
189 .expect("index composition indice correct by construction");
190
191 CompositeSumClaim {
192 composition,
193 sum: univariatized_multilinear_eval,
194 }
195 })
196 .collect()
197}
198
199pub(super) fn lagrange_evals_multilinear_extension<FDomain, F, P>(
202 evaluation_domain: &EvaluationDomain<FDomain>,
203 univariate_challenge: F,
204) -> Result<MultilinearExtension<P>, PolynomialError>
205where
206 FDomain: Field,
207 F: Field + ExtensionField<FDomain>,
208 P: PackedFieldIndexable<Scalar = F>,
209{
210 let lagrange_evals = evaluation_domain.lagrange_evals(univariate_challenge);
211
212 let n_vars = log2_strict_usize(lagrange_evals.len());
213 let mut packed = zeroed_vec(lagrange_evals.len().div_ceil(P::WIDTH));
214 let scalars = P::unpack_scalars_mut(packed.as_mut_slice());
215 scalars[..lagrange_evals.len()].copy_from_slice(lagrange_evals.as_slice());
216
217 Ok(MultilinearExtension::new(n_vars, packed)?)
218}
219
// End-to-end tests for the univariatizing reduction and for zerocheck with skipped rounds.
#[cfg(test)]
mod tests {
    use std::{iter, sync::Arc};

    use binius_field::{
        arch::{OptimalUnderlier128b, OptimalUnderlier512b},
        as_packed_field::{PackScalar, PackedType},
        underlier::UnderlierType,
        AESTowerField128b, AESTowerField16b, AESTowerField8b, BinaryField128b, BinaryField16b,
        Field, PackedBinaryField1x128b, PackedBinaryField4x32b, PackedFieldIndexable, TowerField,
    };
    use binius_hal::ComputationBackend;
    use binius_math::{
        CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, EvaluationOrder,
        IsomorphicEvaluationDomainFactory, MultilinearPoly,
    };
    use groestl_crypto::Groestl256;
    use rand::{prelude::StdRng, SeedableRng};

    use super::*;
    use crate::{
        composition::{IndexComposition, ProductComposition},
        fiat_shamir::{CanSample, HasherChallenger},
        polynomial::CompositionScalarAdapter,
        protocols::{
            sumcheck::{
                batch_verify, batch_verify_with_start, batch_verify_zerocheck_univariate_round,
                prove::{
                    batch_prove, batch_prove_with_start, batch_prove_zerocheck_univariate_round,
                    univariate::{reduce_to_skipped_projection, univariatizing_reduction_prover},
                    SumcheckProver, UnivariateZerocheck,
                },
                standard_switchover_heuristic,
                zerocheck::reduce_to_sumchecks,
                ZerocheckClaim,
            },
            test_utils::generate_zero_product_multilinears,
        },
        transcript::ProverTranscript,
    };

    // Proves and verifies a batch of univariatizing reductions, one per `skip_rounds` in
    // `0..=max_skip_rounds`. It checks the prover's evaluations against direct multilinear
    // evaluation, and checks the post-processed verifier output the same way.
    #[test]
    fn test_univariatizing_reduction_end_to_end() {
        type F = BinaryField128b;
        type FDomain = BinaryField16b;
        type P = PackedBinaryField4x32b;
        type PE = PackedBinaryField1x128b;

        let backend = make_portable_backend();
        let mut rng = StdRng::seed_from_u64(0);

        let regular_vars = 3;
        let max_skip_rounds = 3;
        let n_multilinears = 2;

        let evaluation_domain_factory = DefaultEvaluationDomainFactory::<FDomain>::default();

        let univariate_challenge = <F as Field>::random(&mut rng);

        let sumcheck_challenges = (0..regular_vars)
            .map(|_| <F as Field>::random(&mut rng))
            .collect::<Vec<_>>();

        let mut provers = Vec::new();
        let mut all_multilinears = Vec::new();
        let mut all_univariatized_multilinear_evals = Vec::new();
        // Iterate in descending `skip_rounds`, so the claims end up ordered by decreasing n_vars.
        for skip_rounds in (0..=max_skip_rounds).rev() {
            let n_vars = skip_rounds + regular_vars;

            let multilinears =
                generate_zero_product_multilinears::<P, PE>(&mut rng, n_vars, n_multilinears);
            all_multilinears.push((skip_rounds, multilinears.clone()));

            let domain = evaluation_domain_factory
                .clone()
                .create(1 << skip_rounds)
                .unwrap();

            // Partially evaluate the high (regular) variables at the sumcheck challenges, then
            // extrapolate the remaining skipped-variable evaluations to the univariate challenge.
            let query = backend.multilinear_query(&sumcheck_challenges).unwrap();
            let univariatized_multilinear_evals = multilinears
                .iter()
                .map(|multilinear| {
                    let partial_eval = backend
                        .evaluate_partial_high(multilinear, query.to_ref())
                        .unwrap();
                    domain
                        .extrapolate(PE::unpack_scalars(partial_eval.evals()), univariate_challenge)
                        .unwrap()
                })
                .collect::<Vec<_>>();

            all_univariatized_multilinear_evals.push(univariatized_multilinear_evals.clone());

            let reduced_multilinears =
                reduce_to_skipped_projection(multilinears, &sumcheck_challenges, &backend).unwrap();

            let prover = univariatizing_reduction_prover(
                reduced_multilinears,
                &univariatized_multilinear_evals,
                univariate_challenge,
                evaluation_domain_factory.clone(),
                &backend,
            )
            .unwrap();

            provers.push(prover);
        }

        let mut prove_challenger = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let batch_sumcheck_output_prove = batch_prove(provers, &mut prove_challenger).unwrap();

        // Prover side: each claim reports one extra evaluation beyond its multilinears. The
        // others must match direct evaluation at (trailing reduction challenges ++ sumcheck
        // challenges).
        for ((skip_rounds, multilinears), multilinear_evals) in
            iter::zip(&all_multilinears, batch_sumcheck_output_prove.multilinear_evals)
        {
            assert_eq!(multilinears.len() + 1, multilinear_evals.len());

            let mut query =
                batch_sumcheck_output_prove.challenges[max_skip_rounds - skip_rounds..].to_vec();
            query.extend(sumcheck_challenges.as_slice());

            let query = backend.multilinear_query(&query).unwrap();

            for (multilinear, eval) in iter::zip(multilinears, multilinear_evals) {
                assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
            }
        }

        let claims = iter::zip(&all_multilinears, &all_univariatized_multilinear_evals)
            .map(|((skip_rounds, _q), univariatized_multilinear_evals)| {
                univariatizing_reduction_claim(*skip_rounds, univariatized_multilinear_evals)
                    .unwrap()
            })
            .collect::<Vec<_>>();

        // Verifier side replays the prover transcript, then strips the trailing evaluation
        // via `verify_sumcheck_outputs`.
        let mut verify_challenger = prove_challenger.into_verifier();
        let batch_sumcheck_output_verify =
            batch_verify(EvaluationOrder::LowToHigh, claims.as_slice(), &mut verify_challenger)
                .unwrap();
        let batch_sumcheck_output_post = verify_sumcheck_outputs(
            claims.as_slice(),
            univariate_challenge,
            &sumcheck_challenges,
            batch_sumcheck_output_verify,
        )
        .unwrap();

        // Post-processed challenges are (reduction challenges ++ unskipped challenges). Only the
        // reduction part up to `max_skip_rounds` is sliced here; the sumcheck challenges are
        // appended explicitly.
        for ((skip_rounds, multilinears), evals) in
            iter::zip(all_multilinears, batch_sumcheck_output_post.multilinear_evals)
        {
            let mut query = batch_sumcheck_output_post.challenges
                [max_skip_rounds - skip_rounds..max_skip_rounds]
                .to_vec();
            query.extend(sumcheck_challenges.as_slice());

            let query = backend.multilinear_query(&query).unwrap();

            for (multilinear, eval) in iter::zip(multilinears, evals) {
                assert_eq!(multilinear.evaluate(query.to_ref()).unwrap(), eval);
            }
        }
    }

    // Univariatized zerocheck end to end using a 128-bit underlier.
    #[test]
    fn test_univariatized_zerocheck_end_to_end_basic() {
        test_univariatized_zerocheck_end_to_end_helper::<
            OptimalUnderlier128b,
            BinaryField128b,
            AESTowerField128b,
            AESTowerField16b,
            AESTowerField16b,
            AESTowerField8b,
        >()
    }

    // Same as the basic test but with a 512-bit underlier. Presumably this makes the packed
    // types hold several scalars per element — NOTE(review): confirm intent.
    #[test]
    fn test_univariatized_zerocheck_end_to_end_with_nontrivial_packing() {
        test_univariatized_zerocheck_end_to_end_helper::<
            OptimalUnderlier512b,
            BinaryField128b,
            AESTowerField128b,
            AESTowerField16b,
            AESTowerField16b,
            AESTowerField8b,
        >()
    }

    // For every `skip_rounds` in `0..=max_n_vars`, proves and verifies a batch of zerocheck
    // claims (n_vars from `max_n_vars` down to 1) over pair/triple/quad product compositions.
    // Claims with more than `max_n_vars - skip_rounds` variables go through the univariate
    // round. The remaining ones are converted to regular zerocheck provers.
    fn test_univariatized_zerocheck_end_to_end_helper<U, F, FI, FDomain, FBase, FWitness>()
    where
        U: UnderlierType
            + PackScalar<FI>
            + PackScalar<FBase>
            + PackScalar<FDomain>
            + PackScalar<FWitness>,
        F: TowerField + From<FI>,
        FI: TowerField + ExtensionField<FDomain> + ExtensionField<FBase> + ExtensionField<FWitness>,
        FBase: TowerField + ExtensionField<FDomain>,
        FDomain: TowerField,
        FWitness: Field,
        PackedType<U, FBase>: PackedFieldIndexable,
        PackedType<U, FDomain>: PackedFieldIndexable,
        PackedType<U, FI>: PackedFieldIndexable,
    {
        let max_n_vars = 6;
        let n_multilinears = 9;

        let backend = make_portable_backend();
        let domain_factory = IsomorphicEvaluationDomainFactory::<FDomain>::default();
        let switchover_fn = standard_switchover_heuristic(-2);
        let mut rng = StdRng::seed_from_u64(0);

        // Three product compositions over disjoint index sets covering all 9 multilinears.
        let pair = Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap());
        let triple =
            Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap());
        let quad =
            Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap());

        let prover_compositions = [
            (
                "pair".into(),
                pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                pair.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
            ),
            (
                "triple".into(),
                triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                triple.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
            ),
            (
                "quad".into(),
                quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FBase>>>,
                quad.clone() as Arc<dyn CompositionPoly<PackedType<U, FI>>>,
            ),
        ];

        let prover_adapter_compositions = [
            CompositionScalarAdapter::new(pair.clone() as Arc<dyn CompositionPoly<FI>>),
            CompositionScalarAdapter::new(triple.clone() as Arc<dyn CompositionPoly<FI>>),
            CompositionScalarAdapter::new(quad.clone() as Arc<dyn CompositionPoly<FI>>),
        ];

        let verifier_compositions = [
            pair as Arc<dyn CompositionPoly<F>>,
            triple as Arc<dyn CompositionPoly<F>>,
            quad as Arc<dyn CompositionPoly<F>>,
        ];

        for skip_rounds in 0..=max_n_vars {
            let mut proof = ProverTranscript::<HasherChallenger<Groestl256>>::new();

            // Challenges are sampled before anything else is written. The verifier samples the
            // same count at the same point, and the assert_eq below checks they match.
            let prover_zerocheck_challenges: Vec<FI> = proof.sample_vec(max_n_vars - skip_rounds);

            let mut prover_zerocheck_claims = Vec::new();
            let mut univariate_provers = Vec::new();
            for n_vars in (1..=max_n_vars).rev() {
                let mut multilinears = generate_zero_product_multilinears::<
                    PackedType<U, FWitness>,
                    PackedType<U, FI>,
                >(&mut rng, n_vars, 2);
                multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 3));
                multilinears.extend(generate_zero_product_multilinears(&mut rng, n_vars, 4));

                let claim = ZerocheckClaim::<FI, _>::new(
                    n_vars,
                    n_multilinears,
                    prover_adapter_compositions.to_vec(),
                )
                .unwrap();

                let prover =
                    UnivariateZerocheck::<FDomain, FBase, PackedType<U, FI>, _, _, _, _>::new(
                        multilinears,
                        prover_compositions.to_vec(),
                        &prover_zerocheck_challenges
                            [(max_n_vars - n_vars).saturating_sub(skip_rounds)..],
                        domain_factory.clone(),
                        switchover_fn,
                        &backend,
                    )
                    .unwrap();

                prover_zerocheck_claims.push(claim);
                univariate_provers.push(prover);
            }

            // Claims are in descending n_vars. The first `univariate_cnt` have more than
            // `max_n_vars - skip_rounds` variables and keep the univariate prover; the tail is
            // converted to regular zerocheck provers.
            let univariate_cnt = prover_zerocheck_claims
                .partition_point(|claim| claim.n_vars() > max_n_vars - skip_rounds);
            let tail_provers = univariate_provers.split_off(univariate_cnt);

            let tail_zerocheck_provers = tail_provers
                .into_iter()
                .map(|prover| {
                    let regular_zerocheck = prover.into_regular_zerocheck().unwrap();
                    Box::new(regular_zerocheck) as Box<dyn SumcheckProver<_>>
                })
                .collect::<Vec<_>>();

            let prover_univariate_output =
                batch_prove_zerocheck_univariate_round(univariate_provers, skip_rounds, &mut proof)
                    .unwrap();

            let _ = batch_prove_with_start(
                prover_univariate_output.batch_prove_start,
                tail_zerocheck_provers,
                &mut proof,
            )
            .unwrap();

            let mut verifier_proof = proof.into_verifier();

            let verifier_zerocheck_challenges: Vec<F> =
                verifier_proof.sample_vec(max_n_vars - skip_rounds);
            assert_eq!(
                prover_zerocheck_challenges
                    .into_iter()
                    .map(F::from)
                    .collect::<Vec<_>>(),
                verifier_zerocheck_challenges
            );

            let mut verifier_zerocheck_claims = Vec::new();
            for n_vars in (1..=max_n_vars).rev() {
                let claim = ZerocheckClaim::<F, _>::new(
                    n_vars,
                    n_multilinears,
                    verifier_compositions.to_vec(),
                )
                .unwrap();

                verifier_zerocheck_claims.push(claim);
            }
            // Only the claims that went through the univariate round are passed here.
            let verifier_univariate_output = batch_verify_zerocheck_univariate_round(
                &verifier_zerocheck_claims[..univariate_cnt],
                skip_rounds,
                &mut verifier_proof,
            )
            .unwrap();

            let verifier_sumcheck_claims = reduce_to_sumchecks(&verifier_zerocheck_claims).unwrap();
            let _verifier_sumcheck_output = batch_verify_with_start(
                EvaluationOrder::LowToHigh,
                verifier_univariate_output.batch_verify_start,
                &verifier_sumcheck_claims,
                &mut verifier_proof,
            )
            .unwrap();

            // Check that the verifier consumed the transcript as expected.
            verifier_proof.finalize().unwrap()
        }
    }
}