1use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range};
4
5use binius_field::{BinaryField, ExtensionField, Field, TowerField};
6use binius_math::evaluate_piecewise_multilinear;
7use binius_ntt::NTTOptions;
8use binius_utils::{bail, DeserializeBytes};
9use getset::CopyGetters;
10use tracing::instrument;
11
12use super::error::{Error, VerificationError};
13use crate::{
14 composition::{BivariateProduct, IndexComposition},
15 fiat_shamir::{CanSample, Challenger},
16 merkle_tree::MerkleTreeScheme,
17 piop::util::ResizeableIndex,
18 polynomial::MultivariatePoly,
19 protocols::{
20 fri::{self, estimate_optimal_arity, FRIParams, FRIVerifier},
21 sumcheck::{
22 front_loaded::BatchVerifier as SumcheckBatchVerifier, CompositeSumClaim, SumcheckClaim,
23 },
24 },
25 reed_solomon::reed_solomon::ReedSolomonCode,
26 transcript::VerifierTranscript,
27};
28
29#[derive(Debug, CopyGetters)]
37pub struct CommitMeta {
38 n_multilins_by_vars: Vec<usize>,
39 offsets_by_vars: Vec<usize>,
40 #[getset(get_copy = "pub")]
42 total_vars: usize,
43 #[getset(get_copy = "pub")]
45 total_multilins: usize,
46}
47
48impl CommitMeta {
49 pub fn new(n_multilins_by_vars: Vec<usize>) -> Self {
56 let (offsets_by_vars, total_multilins, total_elems) =
57 n_multilins_by_vars.iter().enumerate().fold(
58 (Vec::with_capacity(n_multilins_by_vars.len()), 0, 0),
59 |(mut offsets, total_multilins, total_elems), (n_vars, &count)| {
60 offsets.push(total_multilins);
61 (offsets, total_multilins + count, total_elems + (count << n_vars))
62 },
63 );
64
65 Self {
66 offsets_by_vars,
67 n_multilins_by_vars,
68 total_vars: total_elems.next_power_of_two().ilog2() as usize,
69 total_multilins,
70 }
71 }
72
73 pub fn with_vars(n_varss: impl IntoIterator<Item = usize>) -> Self {
76 let mut n_multilins_by_vars = ResizeableIndex::new();
77 for n_vars in n_varss {
78 *n_multilins_by_vars.get_mut(n_vars) += 1;
79 }
80 Self::new(n_multilins_by_vars.into_vec())
81 }
82
83 pub fn max_n_vars(&self) -> usize {
85 self.n_multilins_by_vars.len().saturating_sub(1)
86 }
87
88 pub fn n_multilins_by_vars(&self) -> &[usize] {
91 &self.n_multilins_by_vars
92 }
93
94 pub fn range_by_vars(&self, n_vars: usize) -> Range<usize> {
96 let start = self.offsets_by_vars[n_vars];
97 start..start + self.n_multilins_by_vars[n_vars]
98 }
99}
100
101#[derive(Debug, Clone, PartialEq, Eq)]
106pub struct PIOPSumcheckClaim<F: Field> {
107 pub n_vars: usize,
109 pub committed: usize,
111 pub transparent: usize,
113 pub sum: F,
116}
117
118fn make_commit_params_with_constant_arity<F, FEncode>(
119 commit_meta: &CommitMeta,
120 security_bits: usize,
121 log_inv_rate: usize,
122 arity: usize,
123) -> Result<FRIParams<F, FEncode>, Error>
124where
125 F: BinaryField + ExtensionField<FEncode>,
126 FEncode: BinaryField,
127{
128 assert!(arity > 0);
129
130 let fold_arities = std::iter::repeat_n(arity, commit_meta.total_vars.saturating_sub(1) / arity)
132 .collect::<Vec<_>>();
133
134 let log_batch_size = fold_arities.first().copied().unwrap_or(0);
137 let log_dim = commit_meta.total_vars - log_batch_size;
138
139 let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, &NTTOptions::default())?;
140 let n_test_queries = fri::calculate_n_test_queries::<F, _>(security_bits, &rs_code)?;
141 let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?;
142 Ok(fri_params)
143}
144
145pub fn make_commit_params_with_optimal_arity<F, FEncode, MTScheme>(
146 commit_meta: &CommitMeta,
147 _merkle_scheme: &MTScheme,
148 security_bits: usize,
149 log_inv_rate: usize,
150) -> Result<FRIParams<F, FEncode>, Error>
151where
152 F: BinaryField + ExtensionField<FEncode>,
153 FEncode: BinaryField,
154 MTScheme: MerkleTreeScheme<F>,
155{
156 let arity = estimate_optimal_arity(
157 commit_meta.total_vars + log_inv_rate,
158 size_of::<MTScheme::Digest>(),
159 size_of::<F>(),
160 );
161 make_commit_params_with_constant_arity(commit_meta, security_bits, log_inv_rate, arity)
162}
163
164#[derive(Debug, Clone)]
170pub struct SumcheckClaimDesc<F: Field> {
171 pub committed_indices: Range<usize>,
172 pub transparent_indices: Range<usize>,
173 pub composite_sums: Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>>,
174}
175
176impl<F: Field> SumcheckClaimDesc<F> {
177 pub fn n_committed(&self) -> usize {
178 self.committed_indices.len()
179 }
180
181 pub fn n_transparent(&self) -> usize {
182 self.transparent_indices.len()
183 }
184}
185
186pub fn make_sumcheck_claim_descs<F: Field>(
187 commit_meta: &CommitMeta,
188 transparent_n_vars_iter: impl Iterator<Item = usize>,
189 claims: &[PIOPSumcheckClaim<F>],
190) -> Result<Vec<SumcheckClaimDesc<F>>, Error> {
191 let mut sumcheck_claim_descs = vec![
193 SumcheckClaimDesc {
194 committed_indices: 0..0,
195 transparent_indices: 0..0,
196 composite_sums: vec![],
197 };
198 commit_meta.max_n_vars() + 1
199 ];
200
201 let mut last_offset = 0;
203 for (&n_multilins, claim_desc) in
204 iter::zip(commit_meta.n_multilins_by_vars(), &mut sumcheck_claim_descs)
205 {
206 claim_desc.committed_indices.start = last_offset;
207 last_offset += n_multilins;
208 claim_desc.committed_indices.end = last_offset;
209 }
210
211 let mut current_n_vars = 0;
214 for transparent_n_vars in transparent_n_vars_iter {
215 match transparent_n_vars.cmp(¤t_n_vars) {
216 Ordering::Less => return Err(Error::TransparentsNotSorted),
217 Ordering::Greater => {
218 let current_desc = &sumcheck_claim_descs[current_n_vars];
219 let offset = current_desc.transparent_indices.end;
220
221 current_n_vars = transparent_n_vars;
222 let next_desc = &mut sumcheck_claim_descs[current_n_vars];
223 next_desc.transparent_indices = offset..offset;
224 }
225 _ => {}
226 }
227
228 sumcheck_claim_descs[current_n_vars].transparent_indices.end += 1;
229 }
230
231 for (i, claim) in claims.iter().enumerate() {
235 let claim_desc = &mut sumcheck_claim_descs[claim.n_vars];
236
237 if !claim_desc.committed_indices.contains(&claim.committed) {
240 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
241 }
242 if !claim_desc.transparent_indices.contains(&claim.transparent) {
243 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
244 }
245
246 let composition = IndexComposition::new(
247 claim_desc.committed_indices.len() + claim_desc.transparent_indices.len(),
248 [
249 claim.committed - claim_desc.committed_indices.start,
250 claim_desc.committed_indices.len() + claim.transparent
251 - claim_desc.transparent_indices.start,
252 ],
253 BivariateProduct::default(),
254 )
255 .expect(
256 "claim.committed and claim.transparent are checked to be in the correct ranges above",
257 );
258 claim_desc.composite_sums.push(CompositeSumClaim {
259 sum: claim.sum,
260 composition,
261 });
262 }
263
264 Ok(sumcheck_claim_descs)
265}
266
267#[instrument("piop::verify", skip_all)]
280pub fn verify<'a, F, FEncode, Challenger_, MTScheme>(
281 commit_meta: &CommitMeta,
282 merkle_scheme: &MTScheme,
283 fri_params: &FRIParams<F, FEncode>,
284 commitment: &MTScheme::Digest,
285 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'a>],
286 claims: &[PIOPSumcheckClaim<F>],
287 transcript: &mut VerifierTranscript<Challenger_>,
288) -> Result<(), Error>
289where
290 F: TowerField + ExtensionField<FEncode>,
291 FEncode: BinaryField,
292 Challenger_: Challenger,
293 MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
294{
295 let sumcheck_claim_descs = make_sumcheck_claim_descs(
297 commit_meta,
298 transparents.iter().map(|poly| poly.borrow().n_vars()),
299 claims,
300 )?;
301
302 let non_empty_sumcheck_descs = sumcheck_claim_descs
303 .iter()
304 .enumerate()
305 .filter(|(_n_vars, desc)| !desc.composite_sums.is_empty());
306 let sumcheck_claims = non_empty_sumcheck_descs
307 .clone()
308 .map(|(n_vars, desc)| {
309 SumcheckClaim::new(
312 n_vars,
313 desc.committed_indices.len() + desc.transparent_indices.len(),
314 desc.composite_sums.clone(),
315 )
316 })
317 .collect::<Result<Vec<_>, _>>()?;
318
319 let BatchInterleavedSumcheckFRIOutput {
321 challenges,
322 multilinear_evals,
323 fri_final,
324 } = verify_interleaved_fri_sumcheck(
325 commit_meta.total_vars(),
326 fri_params,
327 merkle_scheme,
328 &sumcheck_claims,
329 commitment,
330 transcript,
331 )?;
332
333 let mut piecewise_evals = verify_transparent_evals(
334 commit_meta,
335 non_empty_sumcheck_descs,
336 multilinear_evals,
337 transparents,
338 &challenges,
339 )?;
340
341 piecewise_evals.reverse();
343 let n_pieces_by_vars = sumcheck_claim_descs
344 .iter()
345 .map(|desc| desc.n_committed())
346 .collect::<Vec<_>>();
347 let piecewise_eval =
348 evaluate_piecewise_multilinear(&challenges, &n_pieces_by_vars, &mut piecewise_evals)?;
349 if piecewise_eval != fri_final {
350 return Err(VerificationError::IncorrectSumcheckEvaluation.into());
351 }
352
353 Ok(())
354}
355
356#[instrument(skip_all, level = "debug")]
358fn verify_transparent_evals<'a, 'b, F: Field>(
359 commit_meta: &CommitMeta,
360 sumcheck_descs: impl Iterator<Item = (usize, &'a SumcheckClaimDesc<F>)>,
361 multilinear_evals: Vec<Vec<F>>,
362 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'b>],
363 challenges: &[F],
364) -> Result<Vec<F>, Error> {
365 let mut challenges_rev = challenges.to_vec();
368 challenges_rev.reverse();
369 let n_challenges = challenges.len();
370
371 let mut piecewise_evals = Vec::with_capacity(commit_meta.total_multilins());
372 for ((n_vars, desc), multilinear_evals) in iter::zip(sumcheck_descs, multilinear_evals) {
373 let (committed_evals, transparent_evals) = multilinear_evals.split_at(desc.n_committed());
374 piecewise_evals.extend_from_slice(committed_evals);
375
376 assert_eq!(transparent_evals.len(), desc.n_transparent());
377 for (i, (&claimed_eval, transparent)) in
378 iter::zip(transparent_evals, &transparents[desc.transparent_indices.clone()])
379 .enumerate()
380 {
381 let computed_eval = transparent
382 .borrow()
383 .evaluate(&challenges_rev[n_challenges - n_vars..])?;
384 if claimed_eval != computed_eval {
385 return Err(VerificationError::IncorrectTransparentEvaluation {
386 index: desc.transparent_indices.start + i,
387 }
388 .into());
389 }
390 }
391 }
392 Ok(piecewise_evals)
393}
394
/// Output of the interleaved sumcheck and FRI verification.
#[derive(Debug)]
struct BatchInterleavedSumcheckFRIOutput<F> {
    /// Challenges sampled during the sumcheck rounds, in sampling order.
    challenges: Vec<F>,
    /// Multilinear evaluations reported for each sumcheck claim, in the order the claims were
    /// finished.
    multilinear_evals: Vec<Vec<F>>,
    /// Final value returned by the FRI verifier.
    fri_final: F,
}
401
402#[instrument(skip_all)]
410fn verify_interleaved_fri_sumcheck<F, FEncode, Challenger_, MTScheme>(
411 n_rounds: usize,
412 fri_params: &FRIParams<F, FEncode>,
413 merkle_scheme: &MTScheme,
414 claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
415 codeword_commitment: &MTScheme::Digest,
416 proof: &mut VerifierTranscript<Challenger_>,
417) -> Result<BatchInterleavedSumcheckFRIOutput<F>, Error>
418where
419 F: TowerField + ExtensionField<FEncode>,
420 FEncode: BinaryField,
421 Challenger_: Challenger,
422 MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
423{
424 let mut arities_iter = fri_params.fold_arities().iter();
425 let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles());
426 let mut next_commit_round = arities_iter.next().copied();
427
428 let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, proof)?;
429 let mut multilinear_evals = Vec::with_capacity(claims.len());
430 let mut challenges = Vec::with_capacity(n_rounds);
431 for round_no in 0..n_rounds {
432 let mut reader = proof.message();
433 while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
434 multilinear_evals.push(claim_multilinear_evals);
435 }
436 sumcheck_verifier.receive_round_proof(&mut reader)?;
437
438 let challenge = proof.sample();
439 challenges.push(challenge);
440
441 sumcheck_verifier.finish_round(challenge)?;
442
443 let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1);
444 if observe_fri_comm {
445 let comm = proof
446 .message()
447 .read()
448 .map_err(VerificationError::Transcript)?;
449 fri_commitments.push(comm);
450 next_commit_round = arities_iter.next().map(|arity| round_no + 1 + arity);
451 }
452 }
453
454 let mut reader = proof.message();
455 while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
456 multilinear_evals.push(claim_multilinear_evals);
457 }
458 sumcheck_verifier.finish()?;
459
460 let verifier = FRIVerifier::new(
461 fri_params,
462 merkle_scheme,
463 codeword_commitment,
464 &fri_commitments,
465 &challenges,
466 )?;
467 let fri_final = verifier.verify(proof)?;
468
469 Ok(BatchInterleavedSumcheckFRIOutput {
470 challenges,
471 multilinear_evals,
472 fri_final,
473 })
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `total_vars` for a batch given as `(count, n_vars)` groups: the logarithm of
    /// the total element count after rounding it up to a power of two.
    fn expected_total_vars(groups: &[(usize, usize)]) -> usize {
        let total_elems: usize = groups
            .iter()
            .map(|&(count, n_vars)| count << n_vars)
            .sum();
        total_elems.next_power_of_two().ilog2() as usize
    }

    #[test]
    fn test_commit_meta_new_empty() {
        let commit_meta = CommitMeta::new(vec![]);

        assert_eq!(commit_meta.total_vars, 0);
        assert_eq!(commit_meta.total_multilins, 0);
        assert!(commit_meta.n_multilins_by_vars.is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_commit_meta_new_single_variable() {
        let counts = vec![4];
        let commit_meta = CommitMeta::new(counts.clone());

        assert_eq!(commit_meta.total_vars, 2);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars, counts);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    fn test_commit_meta_new_multiple_variables() {
        let counts = vec![3, 5, 2];
        let commit_meta = CommitMeta::new(counts.clone());

        // 3 * 2^0 + 5 * 2^1 + 2 * 2^2 = 21 elements, rounded up to 32 = 2^5.
        assert_eq!(commit_meta.total_vars, 5);
        assert_eq!(commit_meta.total_multilins, 10);
        assert_eq!(commit_meta.n_multilins_by_vars, counts);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 3, 8]);
    }

    #[test]
    fn test_commit_meta_new_large_numbers() {
        let counts = vec![1_000_000, 2_000_000];
        let commit_meta = CommitMeta::new(counts.clone());

        let expected = expected_total_vars(&[(1_000_000, 0), (2_000_000, 1)]);

        assert_eq!(commit_meta.total_vars, expected);
        assert_eq!(commit_meta.total_multilins, 3_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars, counts);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1_000_000]);
    }

    #[test]
    fn test_with_vars_empty() {
        let commit_meta = CommitMeta::with_vars(vec![]);

        assert_eq!(commit_meta.total_vars, 0);
        assert_eq!(commit_meta.total_multilins, 0);
        assert!(commit_meta.n_multilins_by_vars().is_empty());
        assert!(commit_meta.offsets_by_vars.is_empty());
    }

    #[test]
    fn test_with_vars_single_variable() {
        let commit_meta = CommitMeta::with_vars(vec![0, 0, 0, 0]);

        assert_eq!(commit_meta.total_vars, 2);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[4]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    fn test_with_vars_multiple_variables() {
        let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

        let expected = expected_total_vars(&[(1, 2), (2, 3), (1, 4)]);

        assert_eq!(commit_meta.total_vars, expected);
        assert_eq!(commit_meta.total_multilins, 4);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 0, 1, 2, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 0, 0, 1, 3]);
    }

    #[test]
    fn test_with_vars_large_numbers() {
        let commit_meta = CommitMeta::with_vars(vec![0; 1_000_000]);

        let expected = expected_total_vars(&[(1_000_000, 0)]);

        assert_eq!(commit_meta.total_vars, expected);
        assert_eq!(commit_meta.total_multilins, 1_000_000);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1_000_000]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0]);
    }

    #[test]
    fn test_with_vars_mixed_variables() {
        let commit_meta = CommitMeta::with_vars(vec![0, 1, 1, 2, 2, 2, 3]);

        let expected = expected_total_vars(&[(1, 0), (2, 1), (3, 2), (1, 3)]);

        assert_eq!(commit_meta.total_vars, expected);
        assert_eq!(commit_meta.total_multilins, 7);
        assert_eq!(commit_meta.n_multilins_by_vars(), &[1, 2, 3, 1]);
        assert_eq!(commit_meta.offsets_by_vars, vec![0, 1, 3, 6]);
    }
}