1use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range};
4
5use binius_field::{BinaryField, ExtensionField, Field, TowerField};
6use binius_math::evaluate_piecewise_multilinear;
7use binius_ntt::{AdditiveNTT, SingleThreadedNTT};
8use binius_utils::{DeserializeBytes, bail};
9use getset::CopyGetters;
10use tracing::instrument;
11
12use super::error::{Error, VerificationError};
13use crate::{
14 composition::{BivariateProduct, IndexComposition},
15 fiat_shamir::{CanSample, Challenger},
16 merkle_tree::MerkleTreeScheme,
17 piop::util::ResizeableIndex,
18 polynomial::MultivariatePoly,
19 protocols::{
20 fri::{FRIParams, FRIVerifier, estimate_optimal_arity},
21 sumcheck::{
22 CompositeSumClaim, SumcheckClaim, front_loaded::BatchVerifier as SumcheckBatchVerifier,
23 },
24 },
25 transcript::VerifierTranscript,
26};
27
28#[derive(Debug, CopyGetters)]
36pub struct CommitMeta {
37 n_multilins_by_vars: Vec<usize>,
38 offsets_by_vars: Vec<usize>,
39 #[getset(get_copy = "pub")]
41 total_vars: usize,
42 #[getset(get_copy = "pub")]
44 total_multilins: usize,
45}
46
47impl CommitMeta {
48 pub fn new(n_multilins_by_vars: Vec<usize>) -> Self {
55 let (offsets_by_vars, total_multilins, total_elems) =
56 n_multilins_by_vars.iter().enumerate().fold(
57 (Vec::with_capacity(n_multilins_by_vars.len()), 0, 0),
58 |(mut offsets, total_multilins, total_elems), (n_vars, &count)| {
59 offsets.push(total_multilins);
60 (offsets, total_multilins + count, total_elems + (count << n_vars))
61 },
62 );
63
64 Self {
65 offsets_by_vars,
66 n_multilins_by_vars,
67 total_vars: total_elems.next_power_of_two().ilog2() as usize,
68 total_multilins,
69 }
70 }
71
72 pub fn with_vars(n_varss: impl IntoIterator<Item = usize>) -> Self {
75 let mut n_multilins_by_vars = ResizeableIndex::new();
76 for n_vars in n_varss {
77 *n_multilins_by_vars.get_mut(n_vars) += 1;
78 }
79 Self::new(n_multilins_by_vars.into_vec())
80 }
81
82 pub fn max_n_vars(&self) -> usize {
84 self.n_multilins_by_vars.len().saturating_sub(1)
85 }
86
87 pub fn n_multilins_by_vars(&self) -> &[usize] {
90 &self.n_multilins_by_vars
91 }
92
93 pub fn range_by_vars(&self, n_vars: usize) -> Range<usize> {
95 let start = self.offsets_by_vars[n_vars];
96 start..start + self.n_multilins_by_vars[n_vars]
97 }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq)]
105pub struct PIOPSumcheckClaim<F: Field> {
106 pub n_vars: usize,
108 pub committed: usize,
110 pub transparent: usize,
112 pub sum: F,
115}
116
117fn make_commit_params_with_constant_arity<F, FEncode>(
118 ntt: &impl AdditiveNTT<FEncode>,
119 commit_meta: &CommitMeta,
120 security_bits: usize,
121 log_inv_rate: usize,
122 arity: usize,
123) -> Result<FRIParams<F, FEncode>, Error>
124where
125 F: BinaryField + ExtensionField<FEncode>,
126 FEncode: BinaryField,
127{
128 let params = FRIParams::choose_with_constant_fold_arity(
129 ntt,
130 commit_meta.total_vars(),
131 security_bits,
132 log_inv_rate,
133 arity,
134 )?;
135 Ok(params)
136}
137
138pub fn make_commit_params_with_optimal_arity<F, FEncode, MTScheme>(
147 commit_meta: &CommitMeta,
148 _merkle_scheme: &MTScheme,
149 security_bits: usize,
150 log_inv_rate: usize,
151) -> Result<FRIParams<F, FEncode>, Error>
152where
153 F: BinaryField + ExtensionField<FEncode>,
154 FEncode: BinaryField,
155 MTScheme: MerkleTreeScheme<F>,
156{
157 let ntt = SingleThreadedNTT::<FEncode>::new(FEncode::N_BITS)?;
161
162 let arity = estimate_optimal_arity(
163 commit_meta.total_vars + log_inv_rate,
164 size_of::<MTScheme::Digest>(),
165 size_of::<F>(),
166 );
167 make_commit_params_with_constant_arity(&ntt, commit_meta, security_bits, log_inv_rate, arity)
168}
169
170#[derive(Debug, Clone)]
176pub struct SumcheckClaimDesc<F: Field> {
177 pub committed_indices: Range<usize>,
178 pub transparent_indices: Range<usize>,
179 pub composite_sums: Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>>,
180}
181
182impl<F: Field> SumcheckClaimDesc<F> {
183 pub fn n_committed(&self) -> usize {
184 self.committed_indices.len()
185 }
186
187 pub fn n_transparent(&self) -> usize {
188 self.transparent_indices.len()
189 }
190}
191
192pub fn make_sumcheck_claim_descs<F: Field>(
193 commit_meta: &CommitMeta,
194 transparent_n_vars_iter: impl Iterator<Item = usize>,
195 claims: &[PIOPSumcheckClaim<F>],
196) -> Result<Vec<SumcheckClaimDesc<F>>, Error> {
197 let mut sumcheck_claim_descs = vec![
199 SumcheckClaimDesc {
200 committed_indices: 0..0,
201 transparent_indices: 0..0,
202 composite_sums: vec![],
203 };
204 commit_meta.max_n_vars() + 1
205 ];
206
207 let mut last_offset = 0;
209 for (&n_multilins, claim_desc) in
210 iter::zip(commit_meta.n_multilins_by_vars(), &mut sumcheck_claim_descs)
211 {
212 claim_desc.committed_indices.start = last_offset;
213 last_offset += n_multilins;
214 claim_desc.committed_indices.end = last_offset;
215 }
216
217 let mut current_n_vars = 0;
220 for transparent_n_vars in transparent_n_vars_iter {
221 match transparent_n_vars.cmp(¤t_n_vars) {
222 Ordering::Less => return Err(Error::TransparentsNotSorted),
223 Ordering::Greater => {
224 let current_desc = &sumcheck_claim_descs[current_n_vars];
225 let offset = current_desc.transparent_indices.end;
226
227 current_n_vars = transparent_n_vars;
228 let next_desc = &mut sumcheck_claim_descs[current_n_vars];
229 next_desc.transparent_indices = offset..offset;
230 }
231 _ => {}
232 }
233
234 sumcheck_claim_descs[current_n_vars].transparent_indices.end += 1;
235 }
236
237 for (i, claim) in claims.iter().enumerate() {
241 let claim_desc = &mut sumcheck_claim_descs[claim.n_vars];
242
243 if !claim_desc.committed_indices.contains(&claim.committed) {
246 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
247 }
248 if !claim_desc.transparent_indices.contains(&claim.transparent) {
249 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
250 }
251
252 let composition = IndexComposition::new(
253 claim_desc.committed_indices.len() + claim_desc.transparent_indices.len(),
254 [
255 claim.committed - claim_desc.committed_indices.start,
256 claim_desc.committed_indices.len() + claim.transparent
257 - claim_desc.transparent_indices.start,
258 ],
259 BivariateProduct::default(),
260 )
261 .expect(
262 "claim.committed and claim.transparent are checked to be in the correct ranges above",
263 );
264 claim_desc.composite_sums.push(CompositeSumClaim {
265 sum: claim.sum,
266 composition,
267 });
268 }
269
270 Ok(sumcheck_claim_descs)
271}
272
273#[instrument("piop::verify", skip_all)]
286pub fn verify<'a, F, FEncode, Challenger_, MTScheme>(
287 commit_meta: &CommitMeta,
288 merkle_scheme: &MTScheme,
289 fri_params: &FRIParams<F, FEncode>,
290 commitment: &MTScheme::Digest,
291 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'a>],
292 claims: &[PIOPSumcheckClaim<F>],
293 transcript: &mut VerifierTranscript<Challenger_>,
294) -> Result<(), Error>
295where
296 F: TowerField + ExtensionField<FEncode>,
297 FEncode: BinaryField,
298 Challenger_: Challenger,
299 MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
300{
301 let sumcheck_claim_descs = make_sumcheck_claim_descs(
303 commit_meta,
304 transparents.iter().map(|poly| poly.borrow().n_vars()),
305 claims,
306 )?;
307
308 let non_empty_sumcheck_descs = sumcheck_claim_descs
309 .iter()
310 .enumerate()
311 .filter(|(_n_vars, desc)| !desc.committed_indices.is_empty());
315 let sumcheck_claims = non_empty_sumcheck_descs
316 .clone()
317 .map(|(n_vars, desc)| {
318 SumcheckClaim::new(
321 n_vars,
322 desc.committed_indices.len() + desc.transparent_indices.len(),
323 desc.composite_sums.clone(),
324 )
325 })
326 .collect::<Result<Vec<_>, _>>()?;
327
328 let BatchInterleavedSumcheckFRIOutput {
330 challenges,
331 multilinear_evals,
332 fri_final,
333 } = verify_interleaved_fri_sumcheck(
334 commit_meta.total_vars(),
335 fri_params,
336 merkle_scheme,
337 &sumcheck_claims,
338 commitment,
339 transcript,
340 )?;
341
342 let mut piecewise_evals = verify_transparent_evals(
343 commit_meta,
344 non_empty_sumcheck_descs,
345 multilinear_evals,
346 transparents,
347 &challenges,
348 )?;
349
350 piecewise_evals.reverse();
352 let n_pieces_by_vars = sumcheck_claim_descs
353 .iter()
354 .map(|desc| desc.n_committed())
355 .collect::<Vec<_>>();
356 let piecewise_eval =
357 evaluate_piecewise_multilinear(&challenges, &n_pieces_by_vars, &mut piecewise_evals)?;
358 if piecewise_eval != fri_final {
359 return Err(VerificationError::IncorrectSumcheckEvaluation.into());
360 }
361
362 Ok(())
363}
364
365#[instrument(skip_all, level = "debug")]
367fn verify_transparent_evals<'a, 'b, F: Field>(
368 commit_meta: &CommitMeta,
369 sumcheck_descs: impl Iterator<Item = (usize, &'a SumcheckClaimDesc<F>)>,
370 multilinear_evals: Vec<Vec<F>>,
371 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'b>],
372 challenges: &[F],
373) -> Result<Vec<F>, Error> {
374 let mut challenges_rev = challenges.to_vec();
377 challenges_rev.reverse();
378 let n_challenges = challenges.len();
379
380 let mut piecewise_evals = Vec::with_capacity(commit_meta.total_multilins());
381 for ((n_vars, desc), multilinear_evals) in iter::zip(sumcheck_descs, multilinear_evals) {
382 let (committed_evals, transparent_evals) = multilinear_evals.split_at(desc.n_committed());
383 piecewise_evals.extend_from_slice(committed_evals);
384
385 assert_eq!(transparent_evals.len(), desc.n_transparent());
386 for (i, (&claimed_eval, transparent)) in
387 iter::zip(transparent_evals, &transparents[desc.transparent_indices.clone()])
388 .enumerate()
389 {
390 let computed_eval = transparent
391 .borrow()
392 .evaluate(&challenges_rev[n_challenges - n_vars..])?;
393 if claimed_eval != computed_eval {
394 return Err(VerificationError::IncorrectTransparentEvaluation {
395 index: desc.transparent_indices.start + i,
396 }
397 .into());
398 }
399 }
400 }
401 Ok(piecewise_evals)
402}
403
/// Values produced by [`verify_interleaved_fri_sumcheck`].
#[derive(Debug)]
struct BatchInterleavedSumcheckFRIOutput<F> {
    /// Challenges sampled from the transcript, one per round, in sampling order.
    challenges: Vec<F>,
    /// Multilinear evaluations of each finished sumcheck claim, in the order they finished.
    multilinear_evals: Vec<Vec<F>>,
    /// Final value returned by the FRI verifier.
    fri_final: F,
}
410
411#[instrument(skip_all)]
419fn verify_interleaved_fri_sumcheck<F, FEncode, Challenger_, MTScheme>(
420 n_rounds: usize,
421 fri_params: &FRIParams<F, FEncode>,
422 merkle_scheme: &MTScheme,
423 claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
424 codeword_commitment: &MTScheme::Digest,
425 proof: &mut VerifierTranscript<Challenger_>,
426) -> Result<BatchInterleavedSumcheckFRIOutput<F>, Error>
427where
428 F: TowerField + ExtensionField<FEncode>,
429 FEncode: BinaryField,
430 Challenger_: Challenger,
431 MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
432{
433 let mut arities_iter = fri_params.fold_arities().iter();
434 let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles());
435 let mut next_commit_round = arities_iter.next().copied();
436
437 let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, proof)?;
438 let mut multilinear_evals = Vec::with_capacity(claims.len());
439 let mut challenges = Vec::with_capacity(n_rounds);
440 for round_no in 0..n_rounds {
441 let mut reader = proof.message();
442 while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
443 multilinear_evals.push(claim_multilinear_evals);
444 }
445 sumcheck_verifier.receive_round_proof(&mut reader)?;
446
447 let challenge = proof.sample();
448 challenges.push(challenge);
449
450 sumcheck_verifier.finish_round(challenge)?;
451
452 let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1);
453 if observe_fri_comm {
454 let comm = proof
455 .message()
456 .read()
457 .map_err(VerificationError::Transcript)?;
458 fri_commitments.push(comm);
459 next_commit_round = arities_iter.next().map(|arity| round_no + 1 + arity);
460 }
461 }
462
463 let mut reader = proof.message();
464 while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
465 multilinear_evals.push(claim_multilinear_evals);
466 }
467 sumcheck_verifier.finish()?;
468
469 let verifier = FRIVerifier::new(
470 fri_params,
471 merkle_scheme,
472 codeword_commitment,
473 &fri_commitments,
474 &challenges,
475 )?;
476 let fri_final = verifier.verify(proof)?;
477
478 Ok(BatchInterleavedSumcheckFRIOutput {
479 challenges,
480 multilinear_evals,
481 fri_final,
482 })
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Log2 of `total_elems` rounded up to a power of two, as `CommitMeta::new` computes it.
    fn expected_total_vars(total_elems: usize) -> usize {
        total_elems.next_power_of_two().ilog2() as usize
    }

    #[test]
    fn test_commit_meta_new_empty() {
        let commit_meta = CommitMeta::new(vec![]);

        assert_eq!(commit_meta.total_vars(), 0);
        assert_eq!(commit_meta.total_multilins(), 0);
        assert!(commit_meta.n_multilins_by_vars.is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_commit_meta_new_single_variable() {
        let counts = vec![4];
        let commit_meta = CommitMeta::new(counts.clone());

        assert_eq!(commit_meta.total_vars(), 2);
        assert_eq!(commit_meta.total_multilins(), 4);
        assert_eq!(commit_meta.n_multilins_by_vars, counts);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    fn test_commit_meta_new_multiple_variables() {
        let counts = vec![3, 5, 2];
        let commit_meta = CommitMeta::new(counts.clone());

        // 3 * 1 + 5 * 2 + 2 * 4 = 21 elements, which rounds up to 32 = 2^5.
        assert_eq!(commit_meta.total_vars(), 5);
        assert_eq!(commit_meta.total_multilins(), 10);
        assert_eq!(commit_meta.n_multilins_by_vars, counts);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 3, 8]);
    }

    #[test]
    // The `* (1 << k)` factors are spelled out to mirror the element-count formula.
    #[allow(clippy::identity_op)]
    fn test_commit_meta_new_large_numbers() {
        let counts = vec![1_000_000, 2_000_000];
        let commit_meta = CommitMeta::new(counts.clone());

        let total_elems: usize = 1_000_000 * (1 << 0) + 2_000_000 * (1 << 1);

        assert_eq!(commit_meta.total_vars(), expected_total_vars(total_elems));
        assert_eq!(commit_meta.total_multilins(), 3_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars, counts);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1_000_000]);
    }

    #[test]
    fn test_with_vars_empty() {
        let commit_meta = CommitMeta::with_vars(vec![]);

        assert_eq!(commit_meta.total_vars(), 0);
        assert_eq!(commit_meta.total_multilins(), 0);
        assert!(commit_meta.n_multilins_by_vars().is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_with_vars_single_variable() {
        let commit_meta = CommitMeta::with_vars(vec![0, 0, 0, 0]);

        assert_eq!(commit_meta.total_vars(), 2);
        assert_eq!(commit_meta.total_multilins(), 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[4]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    // The `* (1 << k)` factors are spelled out to mirror the element-count formula.
    #[allow(clippy::identity_op)]
    fn test_with_vars_multiple_variables() {
        let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

        let total_elems: usize = 1 * (1 << 2) + 2 * (1 << 3) + 1 * (1 << 4);

        assert_eq!(commit_meta.total_vars(), expected_total_vars(total_elems));
        assert_eq!(commit_meta.total_multilins(), 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 0, 1, 2, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 0, 0, 1, 3]);
    }

    #[test]
    fn test_with_vars_large_numbers() {
        let commit_meta = CommitMeta::with_vars(vec![0; 1_000_000]);

        // One million zero-variable multilinears contribute one element each.
        let total_elems: usize = 1_000_000;

        assert_eq!(commit_meta.total_vars(), expected_total_vars(total_elems));
        assert_eq!(commit_meta.total_multilins(), 1_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1_000_000]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    // The `* (1 << k)` factors are spelled out to mirror the element-count formula.
    #[allow(clippy::identity_op)]
    fn test_with_vars_mixed_variables() {
        let commit_meta = CommitMeta::with_vars(vec![0, 1, 1, 2, 2, 2, 3]);

        let total_elems: usize = 1 * (1 << 0) + 2 * (1 << 1) + 3 * (1 << 2) + 1 * (1 << 3);

        assert_eq!(commit_meta.total_vars(), expected_total_vars(total_elems));
        assert_eq!(commit_meta.total_multilins(), 7);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1, 2, 3, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1, 3, 6]);
    }
}