1use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range};
4
5use binius_field::{BinaryField, ExtensionField, Field, TowerField};
6use binius_math::evaluate_piecewise_multilinear;
7use binius_utils::{DeserializeBytes, bail, checked_arithmetics::log2_ceil_usize};
8use getset::CopyGetters;
9use tracing::instrument;
10
11use super::error::{Error, VerificationError};
12use crate::{
13 composition::{BivariateProduct, IndexComposition},
14 fiat_shamir::{CanSample, Challenger},
15 merkle_tree::MerkleTreeScheme,
16 piop::util::ResizeableIndex,
17 polynomial::MultivariatePoly,
18 protocols::{
19 fri::{self, FRIParams, FRIVerifier, estimate_optimal_arity},
20 sumcheck::{
21 CompositeSumClaim, SumcheckClaim, front_loaded::BatchVerifier as SumcheckBatchVerifier,
22 },
23 },
24 reed_solomon::reed_solomon::ReedSolomonCode,
25 transcript::VerifierTranscript,
26};
27
28#[derive(Debug, CopyGetters)]
36pub struct CommitMeta {
37 n_multilins_by_vars: Vec<usize>,
38 offsets_by_vars: Vec<usize>,
39 #[getset(get_copy = "pub")]
41 total_vars: usize,
42 #[getset(get_copy = "pub")]
44 total_multilins: usize,
45}
46
47impl CommitMeta {
48 pub fn new(n_multilins_by_vars: Vec<usize>) -> Self {
55 let (offsets_by_vars, total_multilins, total_elems) =
56 n_multilins_by_vars.iter().enumerate().fold(
57 (Vec::with_capacity(n_multilins_by_vars.len()), 0, 0),
58 |(mut offsets, total_multilins, total_elems), (n_vars, &count)| {
59 offsets.push(total_multilins);
60 (offsets, total_multilins + count, total_elems + (count << n_vars))
61 },
62 );
63
64 Self {
65 offsets_by_vars,
66 n_multilins_by_vars,
67 total_vars: total_elems.next_power_of_two().ilog2() as usize,
68 total_multilins,
69 }
70 }
71
72 pub fn with_vars(n_varss: impl IntoIterator<Item = usize>) -> Self {
75 let mut n_multilins_by_vars = ResizeableIndex::new();
76 for n_vars in n_varss {
77 *n_multilins_by_vars.get_mut(n_vars) += 1;
78 }
79 Self::new(n_multilins_by_vars.into_vec())
80 }
81
82 pub fn max_n_vars(&self) -> usize {
84 self.n_multilins_by_vars.len().saturating_sub(1)
85 }
86
87 pub fn n_multilins_by_vars(&self) -> &[usize] {
90 &self.n_multilins_by_vars
91 }
92
93 pub fn range_by_vars(&self, n_vars: usize) -> Range<usize> {
95 let start = self.offsets_by_vars[n_vars];
96 start..start + self.n_multilins_by_vars[n_vars]
97 }
98}
99
/// A claimed sum of the product of one committed multilinear and one transparent multilinear,
/// both with `n_vars` variables.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PIOPSumcheckClaim<F: Field> {
    /// Number of variables of both multilinears in the product.
    pub n_vars: usize,
    /// Global index of the committed multilinear; it must have `n_vars` variables.
    pub committed: usize,
    /// Index of the transparent multilinear in the transparents list; it must have `n_vars`
    /// variables.
    pub transparent: usize,
    /// Claimed sum of the product.
    pub sum: F,
}
116
117fn make_commit_params_with_constant_arity<F, FEncode>(
118 commit_meta: &CommitMeta,
119 security_bits: usize,
120 log_inv_rate: usize,
121 arity: usize,
122) -> Result<FRIParams<F, FEncode>, Error>
123where
124 F: BinaryField + ExtensionField<FEncode>,
125 FEncode: BinaryField,
126{
127 assert!(arity > 0);
128
129 let log_dim = commit_meta.total_vars.saturating_sub(arity);
130 let log_batch_size = commit_meta.total_vars.min(arity);
131 let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate)?;
132 let n_test_queries = fri::calculate_n_test_queries::<F, _>(security_bits, &rs_code)?;
133
134 let cap_height = log2_ceil_usize(n_test_queries);
135 let fold_arities = std::iter::repeat_n(
136 arity,
137 (commit_meta.total_vars).saturating_sub(cap_height.saturating_sub(log_inv_rate)) / arity,
138 )
139 .collect::<Vec<_>>();
140 let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?;
186 Ok(fri_params)
187}
188
189pub fn make_commit_params_with_optimal_arity<F, FEncode, MTScheme>(
190 commit_meta: &CommitMeta,
191 _merkle_scheme: &MTScheme,
192 security_bits: usize,
193 log_inv_rate: usize,
194) -> Result<FRIParams<F, FEncode>, Error>
195where
196 F: BinaryField + ExtensionField<FEncode>,
197 FEncode: BinaryField,
198 MTScheme: MerkleTreeScheme<F>,
199{
200 let arity = estimate_optimal_arity(
201 commit_meta.total_vars + log_inv_rate,
202 size_of::<MTScheme::Digest>(),
203 size_of::<F>(),
204 );
205 make_commit_params_with_constant_arity(commit_meta, security_bits, log_inv_rate, arity)
206}
207
208#[derive(Debug, Clone)]
214pub struct SumcheckClaimDesc<F: Field> {
215 pub committed_indices: Range<usize>,
216 pub transparent_indices: Range<usize>,
217 pub composite_sums: Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>>,
218}
219
220impl<F: Field> SumcheckClaimDesc<F> {
221 pub fn n_committed(&self) -> usize {
222 self.committed_indices.len()
223 }
224
225 pub fn n_transparent(&self) -> usize {
226 self.transparent_indices.len()
227 }
228}
229
230pub fn make_sumcheck_claim_descs<F: Field>(
231 commit_meta: &CommitMeta,
232 transparent_n_vars_iter: impl Iterator<Item = usize>,
233 claims: &[PIOPSumcheckClaim<F>],
234) -> Result<Vec<SumcheckClaimDesc<F>>, Error> {
235 let mut sumcheck_claim_descs = vec![
237 SumcheckClaimDesc {
238 committed_indices: 0..0,
239 transparent_indices: 0..0,
240 composite_sums: vec![],
241 };
242 commit_meta.max_n_vars() + 1
243 ];
244
245 let mut last_offset = 0;
247 for (&n_multilins, claim_desc) in
248 iter::zip(commit_meta.n_multilins_by_vars(), &mut sumcheck_claim_descs)
249 {
250 claim_desc.committed_indices.start = last_offset;
251 last_offset += n_multilins;
252 claim_desc.committed_indices.end = last_offset;
253 }
254
255 let mut current_n_vars = 0;
258 for transparent_n_vars in transparent_n_vars_iter {
259 match transparent_n_vars.cmp(¤t_n_vars) {
260 Ordering::Less => return Err(Error::TransparentsNotSorted),
261 Ordering::Greater => {
262 let current_desc = &sumcheck_claim_descs[current_n_vars];
263 let offset = current_desc.transparent_indices.end;
264
265 current_n_vars = transparent_n_vars;
266 let next_desc = &mut sumcheck_claim_descs[current_n_vars];
267 next_desc.transparent_indices = offset..offset;
268 }
269 _ => {}
270 }
271
272 sumcheck_claim_descs[current_n_vars].transparent_indices.end += 1;
273 }
274
275 for (i, claim) in claims.iter().enumerate() {
279 let claim_desc = &mut sumcheck_claim_descs[claim.n_vars];
280
281 if !claim_desc.committed_indices.contains(&claim.committed) {
284 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
285 }
286 if !claim_desc.transparent_indices.contains(&claim.transparent) {
287 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
288 }
289
290 let composition = IndexComposition::new(
291 claim_desc.committed_indices.len() + claim_desc.transparent_indices.len(),
292 [
293 claim.committed - claim_desc.committed_indices.start,
294 claim_desc.committed_indices.len() + claim.transparent
295 - claim_desc.transparent_indices.start,
296 ],
297 BivariateProduct::default(),
298 )
299 .expect(
300 "claim.committed and claim.transparent are checked to be in the correct ranges above",
301 );
302 claim_desc.composite_sums.push(CompositeSumClaim {
303 sum: claim.sum,
304 composition,
305 });
306 }
307
308 Ok(sumcheck_claim_descs)
309}
310
311#[instrument("piop::verify", skip_all)]
324pub fn verify<'a, F, FEncode, Challenger_, MTScheme>(
325 commit_meta: &CommitMeta,
326 merkle_scheme: &MTScheme,
327 fri_params: &FRIParams<F, FEncode>,
328 commitment: &MTScheme::Digest,
329 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'a>],
330 claims: &[PIOPSumcheckClaim<F>],
331 transcript: &mut VerifierTranscript<Challenger_>,
332) -> Result<(), Error>
333where
334 F: TowerField + ExtensionField<FEncode>,
335 FEncode: BinaryField,
336 Challenger_: Challenger,
337 MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
338{
339 let sumcheck_claim_descs = make_sumcheck_claim_descs(
341 commit_meta,
342 transparents.iter().map(|poly| poly.borrow().n_vars()),
343 claims,
344 )?;
345
346 let non_empty_sumcheck_descs = sumcheck_claim_descs
347 .iter()
348 .enumerate()
349 .filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
353 let sumcheck_claims = non_empty_sumcheck_descs
354 .clone()
355 .map(|(n_vars, desc)| {
356 SumcheckClaim::new(
359 n_vars,
360 desc.committed_indices.len() + desc.transparent_indices.len(),
361 desc.composite_sums.clone(),
362 )
363 })
364 .collect::<Result<Vec<_>, _>>()?;
365
366 let BatchInterleavedSumcheckFRIOutput {
368 challenges,
369 multilinear_evals,
370 fri_final,
371 } = verify_interleaved_fri_sumcheck(
372 commit_meta.total_vars(),
373 fri_params,
374 merkle_scheme,
375 &sumcheck_claims,
376 commitment,
377 transcript,
378 )?;
379
380 let mut piecewise_evals = verify_transparent_evals(
381 commit_meta,
382 non_empty_sumcheck_descs,
383 multilinear_evals,
384 transparents,
385 &challenges,
386 )?;
387
388 piecewise_evals.reverse();
390 let n_pieces_by_vars = sumcheck_claim_descs
391 .iter()
392 .map(|desc| desc.n_committed())
393 .collect::<Vec<_>>();
394 let piecewise_eval =
395 evaluate_piecewise_multilinear(&challenges, &n_pieces_by_vars, &mut piecewise_evals)?;
396 if piecewise_eval != fri_final {
397 return Err(VerificationError::IncorrectSumcheckEvaluation.into());
398 }
399
400 Ok(())
401}
402
403#[instrument(skip_all, level = "debug")]
405fn verify_transparent_evals<'a, 'b, F: Field>(
406 commit_meta: &CommitMeta,
407 sumcheck_descs: impl Iterator<Item = (usize, &'a SumcheckClaimDesc<F>)>,
408 multilinear_evals: Vec<Vec<F>>,
409 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'b>],
410 challenges: &[F],
411) -> Result<Vec<F>, Error> {
412 let mut challenges_rev = challenges.to_vec();
415 challenges_rev.reverse();
416 let n_challenges = challenges.len();
417
418 let mut piecewise_evals = Vec::with_capacity(commit_meta.total_multilins());
419 for ((n_vars, desc), multilinear_evals) in iter::zip(sumcheck_descs, multilinear_evals) {
420 let (committed_evals, transparent_evals) = multilinear_evals.split_at(desc.n_committed());
421 piecewise_evals.extend_from_slice(committed_evals);
422
423 assert_eq!(transparent_evals.len(), desc.n_transparent());
424 for (i, (&claimed_eval, transparent)) in
425 iter::zip(transparent_evals, &transparents[desc.transparent_indices.clone()])
426 .enumerate()
427 {
428 let computed_eval = transparent
429 .borrow()
430 .evaluate(&challenges_rev[n_challenges - n_vars..])?;
431 if claimed_eval != computed_eval {
432 return Err(VerificationError::IncorrectTransparentEvaluation {
433 index: desc.transparent_indices.start + i,
434 }
435 .into());
436 }
437 }
438 }
439 Ok(piecewise_evals)
440}
441
/// Result of `verify_interleaved_fri_sumcheck`.
#[derive(Debug)]
struct BatchInterleavedSumcheckFRIOutput<F> {
    /// Challenges sampled from the transcript, one per sumcheck round, in sampling order.
    challenges: Vec<F>,
    /// Multilinear evaluations of each finished sumcheck claim, in the order the claims finished.
    multilinear_evals: Vec<Vec<F>>,
    /// Value returned by the FRI verifier's `verify`.
    fri_final: F,
}
448
/// Verifies a batched sumcheck of `n_rounds` rounds, interleaved with reading the FRI round
/// commitments, and then runs the FRI verifier.
///
/// After each sumcheck round, one FRI commitment is read from the transcript if the number of
/// completed rounds equals the next scheduled commit round. The schedule starts at the first entry
/// of `fri_params.fold_arities()`, and each later commitment comes after a further number of
/// rounds equal to the next arity. The sampled sumcheck challenges are also passed to the FRI
/// verifier.
///
/// Returns the challenges (one per round), the multilinear evaluations of every claim finished by
/// the batch sumcheck verifier (in the order they finished), and the value from
/// [`FRIVerifier::verify`].
///
/// # Errors
///
/// Propagates errors from the sumcheck batch verifier, transcript reads, and the FRI verifier.
#[instrument(skip_all)]
fn verify_interleaved_fri_sumcheck<F, FEncode, Challenger_, MTScheme>(
    n_rounds: usize,
    fri_params: &FRIParams<F, FEncode>,
    merkle_scheme: &MTScheme,
    claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
    codeword_commitment: &MTScheme::Digest,
    proof: &mut VerifierTranscript<Challenger_>,
) -> Result<BatchInterleavedSumcheckFRIOutput<F>, Error>
where
    F: TowerField + ExtensionField<FEncode>,
    FEncode: BinaryField,
    Challenger_: Challenger,
    MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
{
    let mut arities_iter = fri_params.fold_arities().iter();
    let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles());
    // Number of completed sumcheck rounds after which the next FRI commitment is read.
    let mut next_commit_round = arities_iter.next().copied();

    let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, proof)?;
    let mut multilinear_evals = Vec::with_capacity(claims.len());
    let mut challenges = Vec::with_capacity(n_rounds);
    for round_no in 0..n_rounds {
        let mut reader = proof.message();
        // Collect evaluations of every claim the verifier can finish before this round's proof.
        while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
            multilinear_evals.push(claim_multilinear_evals);
        }
        sumcheck_verifier.receive_round_proof(&mut reader)?;

        let challenge = proof.sample();
        challenges.push(challenge);

        sumcheck_verifier.finish_round(challenge)?;

        // `round_no + 1` rounds have now completed; read a FRI commitment if one is scheduled.
        let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1);
        if observe_fri_comm {
            let comm = proof
                .message()
                .read()
                .map_err(VerificationError::Transcript)?;
            fri_commitments.push(comm);
            next_commit_round = arities_iter.next().map(|arity| round_no + 1 + arity);
        }
    }

    // Finish the claims that remain after the last round.
    let mut reader = proof.message();
    while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
        multilinear_evals.push(claim_multilinear_evals);
    }
    sumcheck_verifier.finish()?;

    let verifier = FRIVerifier::new(
        fri_params,
        merkle_scheme,
        codeword_commitment,
        &fri_commitments,
        &challenges,
    )?;
    let fri_final = verifier.verify(proof)?;

    Ok(BatchInterleavedSumcheckFRIOutput {
        challenges,
        multilinear_evals,
        fri_final,
    })
}
522
#[cfg(test)]
mod tests {
    //! Unit tests for `CommitMeta` construction and its accessors.

    use super::*;

    #[test]
    fn test_commit_meta_new_empty() {
        let n_multilins_by_vars = vec![];
        let commit_meta = CommitMeta::new(n_multilins_by_vars);

        assert_eq!(commit_meta.total_vars, 0);
        assert_eq!(commit_meta.total_multilins, 0);
        assert!(commit_meta.n_multilins_by_vars.is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_commit_meta_new_single_variable() {
        let n_multilins_by_vars = vec![4];
        let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

        assert_eq!(commit_meta.total_vars, 2);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    fn test_commit_meta_new_multiple_variables() {
        let n_multilins_by_vars = vec![3, 5, 2];

        let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

        // 3 * 1 + 5 * 2 + 2 * 4 = 21 elements, rounded up to 32 = 2^5.
        assert_eq!(commit_meta.total_vars, 5);
        assert_eq!(commit_meta.total_multilins, 10);
        assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 3, 8]);
    }

    #[test]
    // Terms like `x * (1 << 0)` spell out the `count << n_vars` formula one group at a time.
    #[allow(clippy::identity_op)]
    fn test_commit_meta_new_large_numbers() {
        let n_multilins_by_vars = vec![1_000_000, 2_000_000];
        let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

        let expected_total_elems = 1_000_000 * (1 << 0) + 2_000_000 * (1 << 1) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 3_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1_000_000]);
    }

    #[test]
    fn test_with_vars_empty() {
        let commit_meta = CommitMeta::with_vars(vec![]);

        assert_eq!(commit_meta.total_vars, 0);
        assert_eq!(commit_meta.total_multilins, 0);
        assert!(commit_meta.n_multilins_by_vars().is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_with_vars_single_variable() {
        let commit_meta = CommitMeta::with_vars(vec![0, 0, 0, 0]);

        assert_eq!(commit_meta.total_vars, 2);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[4]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    // Terms like `1 * (1 << 2)` spell out the `count << n_vars` formula one group at a time.
    #[allow(clippy::identity_op)]
    fn test_with_vars_multiple_variables() {
        let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

        let expected_total_elems = 1 * (1 << 2) + 2 * (1 << 3) + 1 * (1 << 4) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 0, 1, 2, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 0, 0, 1, 3]);
    }

    #[test]
    fn test_with_vars_large_numbers() {
        let vars = vec![0; 1_000_000];
        let commit_meta = CommitMeta::with_vars(vars);

        let expected_total_elems = 1_000_000 * (1 << 0) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 1_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1_000_000]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    // Terms like `1 * (1 << 0)` spell out the `count << n_vars` formula one group at a time.
    #[allow(clippy::identity_op)]
    fn test_with_vars_mixed_variables() {
        let vars = vec![0, 1, 1, 2, 2, 2, 3];
        let commit_meta = CommitMeta::with_vars(vars);

        let expected_total_elems =
            1 * (1 << 0) + 2 * (1 << 1) + 3 * (1 << 2) + 1 * (1 << 3) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 7);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1, 2, 3, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_max_n_vars() {
        assert_eq!(CommitMeta::new(vec![]).max_n_vars(), 0);
        assert_eq!(CommitMeta::new(vec![4]).max_n_vars(), 0);
        assert_eq!(CommitMeta::with_vars(vec![2, 3, 3, 4]).max_n_vars(), 4);
    }

    #[test]
    fn test_range_by_vars() {
        let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

        assert_eq!(commit_meta.range_by_vars(0), 0..0);
        assert_eq!(commit_meta.range_by_vars(1), 0..0);
        assert_eq!(commit_meta.range_by_vars(2), 0..1);
        assert_eq!(commit_meta.range_by_vars(3), 1..3);
        assert_eq!(commit_meta.range_by_vars(4), 3..4);
    }
}