1use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range};
4
5use binius_field::{BinaryField, ExtensionField, Field, TowerField};
6use binius_math::evaluate_piecewise_multilinear;
7use binius_utils::{bail, checked_arithmetics::log2_ceil_usize, DeserializeBytes};
8use getset::CopyGetters;
9use tracing::instrument;
10
11use super::error::{Error, VerificationError};
12use crate::{
13 composition::{BivariateProduct, IndexComposition},
14 fiat_shamir::{CanSample, Challenger},
15 merkle_tree::MerkleTreeScheme,
16 piop::util::ResizeableIndex,
17 polynomial::MultivariatePoly,
18 protocols::{
19 fri::{self, estimate_optimal_arity, FRIParams, FRIVerifier},
20 sumcheck::{
21 front_loaded::BatchVerifier as SumcheckBatchVerifier, CompositeSumClaim, SumcheckClaim,
22 },
23 },
24 reed_solomon::reed_solomon::ReedSolomonCode,
25 transcript::VerifierTranscript,
26};
27
/// Layout metadata for a batch of committed multilinears, grouped by number of variables.
///
/// Groups are laid out in ascending order of variable count: the multilinears with `n`
/// variables occupy the contiguous index range
/// `offsets_by_vars[n]..offsets_by_vars[n] + n_multilins_by_vars[n]`.
#[derive(Debug, CopyGetters)]
pub struct CommitMeta {
    /// `n_multilins_by_vars[n]` is the number of committed multilinears with `n` variables.
    n_multilins_by_vars: Vec<usize>,
    /// `offsets_by_vars[n]` is the index of the first committed multilinear with `n`
    /// variables, i.e. the total count of multilinears with fewer variables.
    offsets_by_vars: Vec<usize>,
    /// Base-2 logarithm of the total element count (`count << n_vars` summed over all
    /// groups), rounded up to the next power of two; 0 for an empty layout.
    #[getset(get_copy = "pub")]
    total_vars: usize,
    /// Total number of committed multilinears across all variable counts.
    #[getset(get_copy = "pub")]
    total_multilins: usize,
}
46
47impl CommitMeta {
48 pub fn new(n_multilins_by_vars: Vec<usize>) -> Self {
55 let (offsets_by_vars, total_multilins, total_elems) =
56 n_multilins_by_vars.iter().enumerate().fold(
57 (Vec::with_capacity(n_multilins_by_vars.len()), 0, 0),
58 |(mut offsets, total_multilins, total_elems), (n_vars, &count)| {
59 offsets.push(total_multilins);
60 (offsets, total_multilins + count, total_elems + (count << n_vars))
61 },
62 );
63
64 Self {
65 offsets_by_vars,
66 n_multilins_by_vars,
67 total_vars: total_elems.next_power_of_two().ilog2() as usize,
68 total_multilins,
69 }
70 }
71
72 pub fn with_vars(n_varss: impl IntoIterator<Item = usize>) -> Self {
75 let mut n_multilins_by_vars = ResizeableIndex::new();
76 for n_vars in n_varss {
77 *n_multilins_by_vars.get_mut(n_vars) += 1;
78 }
79 Self::new(n_multilins_by_vars.into_vec())
80 }
81
82 pub fn max_n_vars(&self) -> usize {
84 self.n_multilins_by_vars.len().saturating_sub(1)
85 }
86
87 pub fn n_multilins_by_vars(&self) -> &[usize] {
90 &self.n_multilins_by_vars
91 }
92
93 pub fn range_by_vars(&self, n_vars: usize) -> Range<usize> {
95 let start = self.offsets_by_vars[n_vars];
96 start..start + self.n_multilins_by_vars[n_vars]
97 }
98}
99
/// A sumcheck claim over the product of one committed and one transparent multilinear.
///
/// `make_sumcheck_claim_descs` turns each claim into a `BivariateProduct` composite-sum claim
/// asserting that the product of the two multilinears sums to `sum`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PIOPSumcheckClaim<F: Field> {
    /// Number of variables of the claim; both referenced multilinears must have this many.
    pub n_vars: usize,
    /// Index of the committed multilinear within the whole committed batch.
    pub committed: usize,
    /// Index of the transparent multilinear within the list of transparents.
    pub transparent: usize,
    /// Claimed sum of the product.
    pub sum: F,
}
116
117fn make_commit_params_with_constant_arity<F, FEncode>(
118 commit_meta: &CommitMeta,
119 security_bits: usize,
120 log_inv_rate: usize,
121 arity: usize,
122) -> Result<FRIParams<F, FEncode>, Error>
123where
124 F: BinaryField + ExtensionField<FEncode>,
125 FEncode: BinaryField,
126{
127 assert!(arity > 0);
128
129 let log_dim = commit_meta.total_vars.saturating_sub(arity);
130 let log_batch_size = commit_meta.total_vars.min(arity);
131 let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate)?;
132 let n_test_queries = fri::calculate_n_test_queries::<F, _>(security_bits, &rs_code)?;
133
134 let cap_height = log2_ceil_usize(n_test_queries);
135 let fold_arities = std::iter::repeat_n(
136 arity,
137 (commit_meta.total_vars + log_inv_rate).saturating_sub(cap_height) / arity,
138 )
139 .collect::<Vec<_>>();
140 let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?;
165 Ok(fri_params)
166}
167
168pub fn make_commit_params_with_optimal_arity<F, FEncode, MTScheme>(
169 commit_meta: &CommitMeta,
170 _merkle_scheme: &MTScheme,
171 security_bits: usize,
172 log_inv_rate: usize,
173) -> Result<FRIParams<F, FEncode>, Error>
174where
175 F: BinaryField + ExtensionField<FEncode>,
176 FEncode: BinaryField,
177 MTScheme: MerkleTreeScheme<F>,
178{
179 let arity = estimate_optimal_arity(
180 commit_meta.total_vars + log_inv_rate,
181 size_of::<MTScheme::Digest>(),
182 size_of::<F>(),
183 );
184 make_commit_params_with_constant_arity(commit_meta, security_bits, log_inv_rate, arity)
185}
186
/// The sumcheck claims over multilinears with one particular number of variables.
///
/// Each composition indexes into a combined multilinear list in which the committed
/// multilinears of this group come first, followed by the transparent ones.
#[derive(Debug, Clone)]
pub struct SumcheckClaimDesc<F: Field> {
    /// Indices, within the committed batch, of the committed multilinears in this group.
    pub committed_indices: Range<usize>,
    /// Indices, within the transparent list, of the transparent multilinears in this group.
    pub transparent_indices: Range<usize>,
    /// One bivariate-product composite sum claim per PIOP claim in this group.
    pub composite_sums: Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>>,
}
198
199impl<F: Field> SumcheckClaimDesc<F> {
200 pub fn n_committed(&self) -> usize {
201 self.committed_indices.len()
202 }
203
204 pub fn n_transparent(&self) -> usize {
205 self.transparent_indices.len()
206 }
207}
208
209pub fn make_sumcheck_claim_descs<F: Field>(
210 commit_meta: &CommitMeta,
211 transparent_n_vars_iter: impl Iterator<Item = usize>,
212 claims: &[PIOPSumcheckClaim<F>],
213) -> Result<Vec<SumcheckClaimDesc<F>>, Error> {
214 let mut sumcheck_claim_descs = vec![
216 SumcheckClaimDesc {
217 committed_indices: 0..0,
218 transparent_indices: 0..0,
219 composite_sums: vec![],
220 };
221 commit_meta.max_n_vars() + 1
222 ];
223
224 let mut last_offset = 0;
226 for (&n_multilins, claim_desc) in
227 iter::zip(commit_meta.n_multilins_by_vars(), &mut sumcheck_claim_descs)
228 {
229 claim_desc.committed_indices.start = last_offset;
230 last_offset += n_multilins;
231 claim_desc.committed_indices.end = last_offset;
232 }
233
234 let mut current_n_vars = 0;
237 for transparent_n_vars in transparent_n_vars_iter {
238 match transparent_n_vars.cmp(¤t_n_vars) {
239 Ordering::Less => return Err(Error::TransparentsNotSorted),
240 Ordering::Greater => {
241 let current_desc = &sumcheck_claim_descs[current_n_vars];
242 let offset = current_desc.transparent_indices.end;
243
244 current_n_vars = transparent_n_vars;
245 let next_desc = &mut sumcheck_claim_descs[current_n_vars];
246 next_desc.transparent_indices = offset..offset;
247 }
248 _ => {}
249 }
250
251 sumcheck_claim_descs[current_n_vars].transparent_indices.end += 1;
252 }
253
254 for (i, claim) in claims.iter().enumerate() {
258 let claim_desc = &mut sumcheck_claim_descs[claim.n_vars];
259
260 if !claim_desc.committed_indices.contains(&claim.committed) {
263 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
264 }
265 if !claim_desc.transparent_indices.contains(&claim.transparent) {
266 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
267 }
268
269 let composition = IndexComposition::new(
270 claim_desc.committed_indices.len() + claim_desc.transparent_indices.len(),
271 [
272 claim.committed - claim_desc.committed_indices.start,
273 claim_desc.committed_indices.len() + claim.transparent
274 - claim_desc.transparent_indices.start,
275 ],
276 BivariateProduct::default(),
277 )
278 .expect(
279 "claim.committed and claim.transparent are checked to be in the correct ranges above",
280 );
281 claim_desc.composite_sums.push(CompositeSumClaim {
282 sum: claim.sum,
283 composition,
284 });
285 }
286
287 Ok(sumcheck_claim_descs)
288}
289
/// Verifies a batch of PIOP sumcheck claims against a single committed batch of multilinears.
///
/// The claims are grouped by number of variables (see [`make_sumcheck_claim_descs`]) and
/// checked with a batched sumcheck interleaved with FRI over `commitment`. The evaluations
/// claimed at the end of the sumcheck are then checked in two parts:
/// * transparent evaluations are recomputed directly from `transparents`;
/// * committed evaluations are combined with [`evaluate_piecewise_multilinear`] and compared
///   with the final value returned by the FRI verifier.
///
/// # Arguments
///
/// * `commit_meta` - layout of the committed multilinears.
/// * `merkle_scheme` / `fri_params` - Merkle scheme and FRI parameters for the commitment.
/// * `commitment` - Merkle root of the committed codeword.
/// * `transparents` - transparent multilinears, sorted by non-decreasing number of variables.
/// * `claims` - sumcheck claims, each pairing one committed and one transparent multilinear.
/// * `transcript` - verifier transcript from which the proof is read.
///
/// # Errors
///
/// Returns an error if the claims are inconsistent with the layout or the transparents, if
/// reading the transcript or any sub-protocol fails, if a transparent evaluation is incorrect,
/// or `VerificationError::IncorrectSumcheckEvaluation` if the committed evaluations do not
/// agree with the FRI final value.
#[instrument("piop::verify", skip_all)]
pub fn verify<'a, F, FEncode, Challenger_, MTScheme>(
    commit_meta: &CommitMeta,
    merkle_scheme: &MTScheme,
    fri_params: &FRIParams<F, FEncode>,
    commitment: &MTScheme::Digest,
    transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'a>],
    claims: &[PIOPSumcheckClaim<F>],
    transcript: &mut VerifierTranscript<Challenger_>,
) -> Result<(), Error>
where
    F: TowerField + ExtensionField<FEncode>,
    FEncode: BinaryField,
    Challenger_: Challenger,
    MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
{
    // Group the claims by number of variables; fails if transparents are unsorted or a claim
    // references multilinears of the wrong size.
    let sumcheck_claim_descs = make_sumcheck_claim_descs(
        commit_meta,
        transparents.iter().map(|poly| poly.borrow().n_vars()),
        claims,
    )?;

    // Only variable counts with at least one committed multilinear take part in the sumcheck.
    // This (lazy) iterator is cloned below so the same sequence, in ascending n_vars order,
    // is reused by `verify_transparent_evals`.
    let non_empty_sumcheck_descs = sumcheck_claim_descs
        .iter()
        .enumerate()
        .filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
    let sumcheck_claims = non_empty_sumcheck_descs
        .clone()
        .map(|(n_vars, desc)| {
            SumcheckClaim::new(
                n_vars,
                desc.committed_indices.len() + desc.transparent_indices.len(),
                desc.composite_sums.clone(),
            )
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Runs `commit_meta.total_vars()` sumcheck rounds interleaved with reading the FRI
    // commitments, then the FRI query phase.
    let BatchInterleavedSumcheckFRIOutput {
        challenges,
        multilinear_evals,
        fri_final,
    } = verify_interleaved_fri_sumcheck(
        commit_meta.total_vars(),
        fri_params,
        merkle_scheme,
        &sumcheck_claims,
        commitment,
        transcript,
    )?;

    // Checks the transparent evaluations and returns the committed evaluations, concatenated
    // in ascending n_vars order.
    let mut piecewise_evals = verify_transparent_evals(
        commit_meta,
        non_empty_sumcheck_descs,
        multilinear_evals,
        transparents,
        &challenges,
    )?;

    // NOTE(review): the committed evaluations were collected in ascending n_vars order and are
    // reversed here; presumably this matches the piece order expected by
    // `evaluate_piecewise_multilinear` — confirm against that function's contract.
    piecewise_evals.reverse();
    let n_pieces_by_vars = sumcheck_claim_descs
        .iter()
        .map(|desc| desc.n_committed())
        .collect::<Vec<_>>();
    let piecewise_eval =
        evaluate_piecewise_multilinear(&challenges, &n_pieces_by_vars, &mut piecewise_evals)?;
    // The combined committed evaluation must equal the value FRI opened the commitment to.
    if piecewise_eval != fri_final {
        return Err(VerificationError::IncorrectSumcheckEvaluation.into());
    }

    Ok(())
}
381
382#[instrument(skip_all, level = "debug")]
384fn verify_transparent_evals<'a, 'b, F: Field>(
385 commit_meta: &CommitMeta,
386 sumcheck_descs: impl Iterator<Item = (usize, &'a SumcheckClaimDesc<F>)>,
387 multilinear_evals: Vec<Vec<F>>,
388 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'b>],
389 challenges: &[F],
390) -> Result<Vec<F>, Error> {
391 let mut challenges_rev = challenges.to_vec();
394 challenges_rev.reverse();
395 let n_challenges = challenges.len();
396
397 let mut piecewise_evals = Vec::with_capacity(commit_meta.total_multilins());
398 for ((n_vars, desc), multilinear_evals) in iter::zip(sumcheck_descs, multilinear_evals) {
399 let (committed_evals, transparent_evals) = multilinear_evals.split_at(desc.n_committed());
400 piecewise_evals.extend_from_slice(committed_evals);
401
402 assert_eq!(transparent_evals.len(), desc.n_transparent());
403 for (i, (&claimed_eval, transparent)) in
404 iter::zip(transparent_evals, &transparents[desc.transparent_indices.clone()])
405 .enumerate()
406 {
407 let computed_eval = transparent
408 .borrow()
409 .evaluate(&challenges_rev[n_challenges - n_vars..])?;
410 if claimed_eval != computed_eval {
411 return Err(VerificationError::IncorrectTransparentEvaluation {
412 index: desc.transparent_indices.start + i,
413 }
414 .into());
415 }
416 }
417 }
418 Ok(piecewise_evals)
419}
420
/// Output of [`verify_interleaved_fri_sumcheck`].
#[derive(Debug)]
struct BatchInterleavedSumcheckFRIOutput<F> {
    /// Sumcheck challenges, one per round, in the order they were sampled.
    challenges: Vec<F>,
    /// Claimed multilinear evaluations of each sumcheck claim, in the order the claims finished.
    multilinear_evals: Vec<Vec<F>>,
    /// Final value returned by the FRI verifier.
    fri_final: F,
}
427
/// Runs the batched sumcheck verifier for `n_rounds` rounds, interleaved with reading the FRI
/// fold-round commitments, and then verifies the FRI proof.
///
/// In each round, the verifier first reads the evaluations of any claims that have finished.
/// It then reads the round message and samples one challenge. FRI commitments are read
/// on the following schedule:
/// * the first is read once `fold_arities[0]` rounds have completed;
/// * each later one is read after the next arity's number of further rounds.
///
/// After the last round, the evaluations of the remaining claims are read. The FRI verifier is
/// then run against `codeword_commitment`, the collected FRI commitments, and all sampled
/// challenges.
///
/// # Errors
///
/// Propagates sumcheck, transcript, and FRI verification failures.
#[instrument(skip_all)]
fn verify_interleaved_fri_sumcheck<F, FEncode, Challenger_, MTScheme>(
    n_rounds: usize,
    fri_params: &FRIParams<F, FEncode>,
    merkle_scheme: &MTScheme,
    claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
    codeword_commitment: &MTScheme::Digest,
    proof: &mut VerifierTranscript<Challenger_>,
) -> Result<BatchInterleavedSumcheckFRIOutput<F>, Error>
where
    F: TowerField + ExtensionField<FEncode>,
    FEncode: BinaryField,
    Challenger_: Challenger,
    MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
{
    let mut arities_iter = fri_params.fold_arities().iter();
    let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles());
    // 1-based round count after which the next FRI commitment is read; `None` once all
    // fold arities have been consumed.
    let mut next_commit_round = arities_iter.next().copied();

    let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, proof)?;
    let mut multilinear_evals = Vec::with_capacity(claims.len());
    let mut challenges = Vec::with_capacity(n_rounds);
    for round_no in 0..n_rounds {
        let mut reader = proof.message();
        // Claims that have run out of rounds emit their multilinear evaluations first.
        while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
            multilinear_evals.push(claim_multilinear_evals);
        }
        sumcheck_verifier.receive_round_proof(&mut reader)?;

        let challenge = proof.sample();
        challenges.push(challenge);

        sumcheck_verifier.finish_round(challenge)?;

        // After `round_no + 1` rounds, read the FRI commitment scheduled for this point and
        // schedule the next one `arity` rounds later.
        let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1);
        if observe_fri_comm {
            let comm = proof
                .message()
                .read()
                .map_err(VerificationError::Transcript)?;
            fri_commitments.push(comm);
            next_commit_round = arities_iter.next().map(|arity| round_no + 1 + arity);
        }
    }

    // Collect evaluations of the claims that finish after the final round.
    let mut reader = proof.message();
    while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
        multilinear_evals.push(claim_multilinear_evals);
    }
    sumcheck_verifier.finish()?;

    let verifier = FRIVerifier::new(
        fri_params,
        merkle_scheme,
        codeword_commitment,
        &fri_commitments,
        &challenges,
    )?;
    let fri_final = verifier.verify(proof)?;

    Ok(BatchInterleavedSumcheckFRIOutput {
        challenges,
        multilinear_evals,
        fri_final,
    })
}
501
#[cfg(test)]
mod tests {
    use binius_field::BinaryField128b;

    use super::*;

    #[test]
    fn test_commit_meta_new_empty() {
        let n_multilins_by_vars = vec![];
        let commit_meta = CommitMeta::new(n_multilins_by_vars);

        assert_eq!(commit_meta.total_vars, 0);
        assert_eq!(commit_meta.total_multilins, 0);
        assert!(commit_meta.n_multilins_by_vars.is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
        assert_eq!(commit_meta.max_n_vars(), 0);
    }

    #[test]
    fn test_commit_meta_new_single_variable() {
        let n_multilins_by_vars = vec![4];
        let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

        assert_eq!(commit_meta.total_vars, 2);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    fn test_commit_meta_new_multiple_variables() {
        let n_multilins_by_vars = vec![3, 5, 2];

        let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

        // 3 * 2^0 + 5 * 2^1 + 2 * 2^2 = 21 elements, rounded up to 32 = 2^5.
        assert_eq!(commit_meta.total_vars, 5);
        assert_eq!(commit_meta.total_multilins, 10);
        assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 3, 8]);
    }

    #[test]
    #[allow(clippy::identity_op)]
    fn test_commit_meta_new_large_numbers() {
        let n_multilins_by_vars = vec![1_000_000, 2_000_000];
        let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

        let expected_total_elems = 1_000_000 * (1 << 0) + 2_000_000 * (1 << 1) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 3_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1_000_000]);
    }

    #[test]
    fn test_with_vars_empty() {
        let commit_meta = CommitMeta::with_vars(vec![]);

        assert_eq!(commit_meta.total_vars, 0);
        assert_eq!(commit_meta.total_multilins, 0);
        assert!(commit_meta.n_multilins_by_vars().is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_with_vars_single_variable() {
        let commit_meta = CommitMeta::with_vars(vec![0, 0, 0, 0]);

        assert_eq!(commit_meta.total_vars, 2);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[4]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    #[allow(clippy::identity_op)]
    fn test_with_vars_multiple_variables() {
        let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

        let expected_total_elems = 1 * (1 << 2) + 2 * (1 << 3) + 1 * (1 << 4) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 0, 1, 2, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 0, 0, 1, 3]);
    }

    #[test]
    fn test_with_vars_large_numbers() {
        let vars = vec![0; 1_000_000];
        let commit_meta = CommitMeta::with_vars(vars);

        let expected_total_elems = 1_000_000 * (1 << 0) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 1_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1_000_000]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    #[allow(clippy::identity_op)]
    fn test_with_vars_mixed_variables() {
        let vars = vec![0, 1, 1, 2, 2, 2, 3];
        let commit_meta = CommitMeta::with_vars(vars);

        let expected_total_elems =
            1 * (1 << 0) + 2 * (1 << 1) + 3 * (1 << 2) + 1 * (1 << 3) as usize;
        let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

        assert_eq!(commit_meta.total_vars, expected_total_vars);
        assert_eq!(commit_meta.total_multilins, 7);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1, 2, 3, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_range_by_vars_and_max_n_vars() {
        let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

        assert_eq!(commit_meta.max_n_vars(), 4);
        assert_eq!(commit_meta.range_by_vars(0), 0..0);
        assert_eq!(commit_meta.range_by_vars(2), 0..1);
        assert_eq!(commit_meta.range_by_vars(3), 1..3);
        assert_eq!(commit_meta.range_by_vars(4), 3..4);
    }

    #[test]
    fn test_make_sumcheck_claim_descs_layout() {
        // One committed multilinear with 0 variables and one with 2 variables.
        let commit_meta = CommitMeta::with_vars([0, 2]);
        let result =
            make_sumcheck_claim_descs::<BinaryField128b>(&commit_meta, [0, 2, 2].into_iter(), &[]);
        let Ok(descs) = result else {
            panic!("sorted transparents with no claims must be accepted");
        };

        assert_eq!(descs.len(), 3);
        assert_eq!(descs[0].committed_indices, 0..1);
        assert_eq!(descs[0].transparent_indices, 0..1);
        assert_eq!(descs[1].committed_indices, 1..1);
        assert_eq!(descs[1].transparent_indices, 0..0);
        assert_eq!(descs[2].committed_indices, 1..2);
        assert_eq!(descs[2].transparent_indices, 1..3);
        assert!(descs.iter().all(|desc| desc.composite_sums.is_empty()));
    }

    #[test]
    fn test_make_sumcheck_claim_descs_unsorted_transparents() {
        let commit_meta = CommitMeta::with_vars([0, 2]);
        let result =
            make_sumcheck_claim_descs::<BinaryField128b>(&commit_meta, [2, 0].into_iter(), &[]);

        assert!(matches!(result, Err(Error::TransparentsNotSorted)));
    }
}