1use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range};
4
5use binius_field::{BinaryField, ExtensionField, Field, TowerField};
6use binius_math::evaluate_piecewise_multilinear;
7use binius_ntt::NTTOptions;
8use binius_utils::{bail, DeserializeBytes};
9use getset::CopyGetters;
10use tracing::instrument;
11
12use super::error::{Error, VerificationError};
13use crate::{
14 composition::{BivariateProduct, IndexComposition},
15 fiat_shamir::{CanSample, Challenger},
16 merkle_tree::MerkleTreeScheme,
17 piop::util::ResizeableIndex,
18 polynomial::MultivariatePoly,
19 protocols::{
20 fri::{self, estimate_optimal_arity, FRIParams, FRIVerifier},
21 sumcheck::{
22 front_loaded::BatchVerifier as SumcheckBatchVerifier, CompositeSumClaim, SumcheckClaim,
23 },
24 },
25 reed_solomon::reed_solomon::ReedSolomonCode,
26 transcript::VerifierTranscript,
27};
28
/// Layout metadata for the committed multilinears, grouped by number of variables.
///
/// Entry `n_vars` of the per-variable vectors describes the group of committed multilinears
/// with exactly `n_vars` variables. Committed multilinears are numbered globally and
/// consecutively, the group with the fewest variables first.
#[derive(Debug, CopyGetters)]
pub struct CommitMeta {
	/// Number of committed multilinears with each number of variables.
	n_multilins_by_vars: Vec<usize>,
	/// Global index of the first committed multilinear in each `n_vars` group.
	offsets_by_vars: Vec<usize>,
	/// Base-2 logarithm of the total number of evaluations (`count << n_vars` summed over all
	/// groups), after rounding that total up to a power of two.
	#[getset(get_copy = "pub")]
	total_vars: usize,
	/// Total number of committed multilinears across all groups.
	#[getset(get_copy = "pub")]
	total_multilins: usize,
}
47
48impl CommitMeta {
49 pub fn new(n_multilins_by_vars: Vec<usize>) -> Self {
56 let (offsets_by_vars, total_multilins, total_elems) =
57 n_multilins_by_vars.iter().enumerate().fold(
58 (Vec::with_capacity(n_multilins_by_vars.len()), 0, 0),
59 |(mut offsets, total_multilins, total_elems), (n_vars, &count)| {
60 offsets.push(total_multilins);
61 (offsets, total_multilins + count, total_elems + (count << n_vars))
62 },
63 );
64
65 Self {
66 offsets_by_vars,
67 n_multilins_by_vars,
68 total_vars: total_elems.next_power_of_two().ilog2() as usize,
69 total_multilins,
70 }
71 }
72
73 pub fn with_vars(n_varss: impl IntoIterator<Item = usize>) -> Self {
76 let mut n_multilins_by_vars = ResizeableIndex::new();
77 for n_vars in n_varss {
78 *n_multilins_by_vars.get_mut(n_vars) += 1;
79 }
80 Self::new(n_multilins_by_vars.into_vec())
81 }
82
83 pub fn max_n_vars(&self) -> usize {
85 self.n_multilins_by_vars.len().saturating_sub(1)
86 }
87
88 pub fn n_multilins_by_vars(&self) -> &[usize] {
91 &self.n_multilins_by_vars
92 }
93
94 pub fn range_by_vars(&self, n_vars: usize) -> Range<usize> {
96 let start = self.offsets_by_vars[n_vars];
97 start..start + self.n_multilins_by_vars[n_vars]
98 }
99}
100
/// A sumcheck claim relating one committed multilinear and one transparent polynomial.
///
/// [`make_sumcheck_claim_descs`] turns each such claim into a composite sum claim whose
/// composition is the product (`BivariateProduct`) of the referenced committed multilinear and
/// transparent polynomial, with `sum` as the claimed sum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PIOPSumcheckClaim<F: Field> {
	/// Number of variables; both referenced polynomials must belong to the group with this
	/// many variables.
	pub n_vars: usize,
	/// Global index of the committed multilinear (numbered as in [`CommitMeta`]).
	pub committed: usize,
	/// Index of the transparent polynomial in the `transparents` slice given to [`verify`].
	pub transparent: usize,
	/// Claimed sum.
	pub sum: F,
}
117
118fn make_commit_params_with_constant_arity<F, FEncode>(
119 commit_meta: &CommitMeta,
120 security_bits: usize,
121 log_inv_rate: usize,
122 arity: usize,
123) -> Result<FRIParams<F, FEncode>, Error>
124where
125 F: BinaryField + ExtensionField<FEncode>,
126 FEncode: BinaryField,
127{
128 assert!(arity > 0);
129
130 let fold_arities = std::iter::repeat_n(arity, commit_meta.total_vars.saturating_sub(1) / arity)
132 .collect::<Vec<_>>();
133
134 let log_batch_size = fold_arities.first().copied().unwrap_or(0);
137 let log_dim = commit_meta.total_vars - log_batch_size;
138
139 let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, &NTTOptions::default())?;
140 let n_test_queries = fri::calculate_n_test_queries::<F, _>(security_bits, &rs_code)?;
141 let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?;
142 Ok(fri_params)
143}
144
145pub fn make_commit_params_with_optimal_arity<F, FEncode, MTScheme>(
146 commit_meta: &CommitMeta,
147 _merkle_scheme: &MTScheme,
148 security_bits: usize,
149 log_inv_rate: usize,
150) -> Result<FRIParams<F, FEncode>, Error>
151where
152 F: BinaryField + ExtensionField<FEncode>,
153 FEncode: BinaryField,
154 MTScheme: MerkleTreeScheme<F>,
155{
156 let arity = estimate_optimal_arity(
157 commit_meta.total_vars + log_inv_rate,
158 size_of::<MTScheme::Digest>(),
159 size_of::<F>(),
160 );
161 make_commit_params_with_constant_arity(commit_meta, security_bits, log_inv_rate, arity)
162}
163
/// The sumcheck claims over the polynomials that share one number of variables.
///
/// The multilinears of the resulting sumcheck claim are the committed multilinears of this
/// group followed by its transparent polynomials. Composition indices refer to that
/// concatenated list.
#[derive(Debug, Clone)]
pub struct SumcheckClaimDesc<F: Field> {
	/// Global index range of the committed multilinears in this group.
	pub committed_indices: Range<usize>,
	/// Index range of this group's polynomials in the `transparents` slice.
	pub transparent_indices: Range<usize>,
	/// Composite sum claims over the concatenated committed-then-transparent multilinears.
	pub composite_sums: Vec<CompositeSumClaim<F, IndexComposition<BivariateProduct, 2>>>,
}
175
176impl<F: Field> SumcheckClaimDesc<F> {
177 pub fn n_committed(&self) -> usize {
178 self.committed_indices.len()
179 }
180
181 pub fn n_transparent(&self) -> usize {
182 self.transparent_indices.len()
183 }
184}
185
186pub fn make_sumcheck_claim_descs<F: Field>(
187 commit_meta: &CommitMeta,
188 transparent_n_vars_iter: impl Iterator<Item = usize>,
189 claims: &[PIOPSumcheckClaim<F>],
190) -> Result<Vec<SumcheckClaimDesc<F>>, Error> {
191 let mut sumcheck_claim_descs = vec![
193 SumcheckClaimDesc {
194 committed_indices: 0..0,
195 transparent_indices: 0..0,
196 composite_sums: vec![],
197 };
198 commit_meta.max_n_vars() + 1
199 ];
200
201 let mut last_offset = 0;
203 for (&n_multilins, claim_desc) in
204 iter::zip(commit_meta.n_multilins_by_vars(), &mut sumcheck_claim_descs)
205 {
206 claim_desc.committed_indices.start = last_offset;
207 last_offset += n_multilins;
208 claim_desc.committed_indices.end = last_offset;
209 }
210
211 let mut current_n_vars = 0;
214 for transparent_n_vars in transparent_n_vars_iter {
215 match transparent_n_vars.cmp(¤t_n_vars) {
216 Ordering::Less => return Err(Error::TransparentsNotSorted),
217 Ordering::Greater => {
218 let current_desc = &sumcheck_claim_descs[current_n_vars];
219 let offset = current_desc.transparent_indices.end;
220
221 current_n_vars = transparent_n_vars;
222 let next_desc = &mut sumcheck_claim_descs[current_n_vars];
223 next_desc.transparent_indices = offset..offset;
224 }
225 _ => {}
226 }
227
228 sumcheck_claim_descs[current_n_vars].transparent_indices.end += 1;
229 }
230
231 for (i, claim) in claims.iter().enumerate() {
235 let claim_desc = &mut sumcheck_claim_descs[claim.n_vars];
236
237 if !claim_desc.committed_indices.contains(&claim.committed) {
240 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
241 }
242 if !claim_desc.transparent_indices.contains(&claim.transparent) {
243 bail!(Error::SumcheckClaimVariablesMismatch { index: i });
244 }
245
246 let composition = IndexComposition::new(
247 claim_desc.committed_indices.len() + claim_desc.transparent_indices.len(),
248 [
249 claim.committed - claim_desc.committed_indices.start,
250 claim_desc.committed_indices.len() + claim.transparent
251 - claim_desc.transparent_indices.start,
252 ],
253 BivariateProduct::default(),
254 )
255 .expect(
256 "claim.committed and claim.transparent are checked to be in the correct ranges above",
257 );
258 claim_desc.composite_sums.push(CompositeSumClaim {
259 sum: claim.sum,
260 composition,
261 });
262 }
263
264 Ok(sumcheck_claim_descs)
265}
266
267#[instrument("piop::verify", skip_all)]
280pub fn verify<'a, F, FEncode, Challenger_, MTScheme>(
281 commit_meta: &CommitMeta,
282 merkle_scheme: &MTScheme,
283 fri_params: &FRIParams<F, FEncode>,
284 commitment: &MTScheme::Digest,
285 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'a>],
286 claims: &[PIOPSumcheckClaim<F>],
287 transcript: &mut VerifierTranscript<Challenger_>,
288) -> Result<(), Error>
289where
290 F: TowerField + ExtensionField<FEncode>,
291 FEncode: BinaryField,
292 Challenger_: Challenger,
293 MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
294{
295 let sumcheck_claim_descs = make_sumcheck_claim_descs(
297 commit_meta,
298 transparents.iter().map(|poly| poly.borrow().n_vars()),
299 claims,
300 )?;
301
302 let non_empty_sumcheck_descs = sumcheck_claim_descs
303 .iter()
304 .enumerate()
305 .filter(|(_n_vars, desc)| !desc.composite_sums.is_empty());
306 let sumcheck_claims = non_empty_sumcheck_descs
307 .clone()
308 .map(|(n_vars, desc)| {
309 SumcheckClaim::new(
312 n_vars,
313 desc.committed_indices.len() + desc.transparent_indices.len(),
314 desc.composite_sums.clone(),
315 )
316 })
317 .collect::<Result<Vec<_>, _>>()?;
318
319 let BatchInterleavedSumcheckFRIOutput {
321 challenges,
322 multilinear_evals,
323 fri_final,
324 } = verify_interleaved_fri_sumcheck(
325 commit_meta.total_vars(),
326 fri_params,
327 merkle_scheme,
328 &sumcheck_claims,
329 commitment,
330 transcript,
331 )?;
332
333 let mut piecewise_evals = verify_transparent_evals(
334 commit_meta,
335 non_empty_sumcheck_descs,
336 multilinear_evals,
337 transparents,
338 &challenges,
339 )?;
340
341 piecewise_evals.reverse();
343 let n_pieces_by_vars = sumcheck_claim_descs
344 .iter()
345 .map(|desc| desc.n_committed())
346 .collect::<Vec<_>>();
347 let piecewise_eval =
348 evaluate_piecewise_multilinear(&challenges, &n_pieces_by_vars, &mut piecewise_evals)?;
349 if piecewise_eval != fri_final {
350 return Err(VerificationError::IncorrectSumcheckEvaluation.into());
351 }
352
353 Ok(())
354}
355
356#[instrument(skip_all, level = "debug")]
358fn verify_transparent_evals<'a, 'b, F: Field>(
359 commit_meta: &CommitMeta,
360 sumcheck_descs: impl Iterator<Item = (usize, &'a SumcheckClaimDesc<F>)>,
361 multilinear_evals: Vec<Vec<F>>,
362 transparents: &[impl Borrow<dyn MultivariatePoly<F> + 'b>],
363 challenges: &[F],
364) -> Result<Vec<F>, Error> {
365 let mut piecewise_evals = Vec::with_capacity(commit_meta.total_multilins());
366 for ((n_vars, desc), multilinear_evals) in iter::zip(sumcheck_descs, multilinear_evals) {
367 let (committed_evals, transparent_evals) = multilinear_evals.split_at(desc.n_committed());
368 piecewise_evals.extend_from_slice(committed_evals);
369
370 assert_eq!(transparent_evals.len(), desc.n_transparent());
371 for (i, (&claimed_eval, transparent)) in
372 iter::zip(transparent_evals, &transparents[desc.transparent_indices.clone()])
373 .enumerate()
374 {
375 let computed_eval = transparent.borrow().evaluate(&challenges[..n_vars])?;
376 if claimed_eval != computed_eval {
377 return Err(VerificationError::IncorrectTransparentEvaluation {
378 index: desc.transparent_indices.start + i,
379 }
380 .into());
381 }
382 }
383 }
384 Ok(piecewise_evals)
385}
386
/// Result of [`verify_interleaved_fri_sumcheck`].
#[derive(Debug)]
struct BatchInterleavedSumcheckFRIOutput<F> {
	/// Challenges sampled from the transcript, one per sumcheck round.
	challenges: Vec<F>,
	/// For each finished sumcheck claim, the multilinear evaluations read when it finished, in
	/// the order the claims were finished.
	multilinear_evals: Vec<Vec<F>>,
	/// Value returned by the FRI verifier.
	fri_final: F,
}
393
/// Runs the batched sumcheck verifier interleaved with reading the FRI fold commitments, then
/// runs the FRI verifier.
///
/// Each of the `n_rounds` rounds proceeds as follows:
/// 1. Open one transcript message.
/// 2. Read the final multilinear evaluations of any claims that `try_finish_claim` reports as
///    finished.
/// 3. Read the round proof.
/// 4. Sample the round challenge.
/// 5. Read one FRI commitment from a separate message if the number of completed rounds equals
///    the next scheduled commit round. The first scheduled round is the first fold arity; each
///    later one advances by the next fold arity.
///
/// # Arguments
///
/// * `n_rounds` - number of sumcheck rounds (`verify` passes `commit_meta.total_vars()`)
/// * `fri_params` - FRI parameters supplying the fold arities
/// * `merkle_scheme` - Merkle tree scheme handed to the FRI verifier
/// * `claims` - sumcheck claims to verify in one batch
/// * `codeword_commitment` - codeword commitment handed to the FRI verifier
/// * `proof` - verifier transcript
///
/// # Returns
///
/// The sampled challenges, the per-claim multilinear evaluations, and the final FRI value.
#[instrument(skip_all)]
fn verify_interleaved_fri_sumcheck<F, FEncode, Challenger_, MTScheme>(
	n_rounds: usize,
	fri_params: &FRIParams<F, FEncode>,
	merkle_scheme: &MTScheme,
	claims: &[SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>],
	codeword_commitment: &MTScheme::Digest,
	proof: &mut VerifierTranscript<Challenger_>,
) -> Result<BatchInterleavedSumcheckFRIOutput<F>, Error>
where
	F: TowerField + ExtensionField<FEncode>,
	FEncode: BinaryField,
	Challenger_: Challenger,
	MTScheme: MerkleTreeScheme<F, Digest: DeserializeBytes>,
{
	let mut arities_iter = fri_params.fold_arities().iter();
	let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles());
	// Number of completed rounds after which the next FRI commitment is read.
	let mut next_commit_round = arities_iter.next().copied();

	let mut sumcheck_verifier = SumcheckBatchVerifier::new(claims, proof)?;
	let mut multilinear_evals = Vec::with_capacity(claims.len());
	let mut challenges = Vec::with_capacity(n_rounds);
	for round_no in 0..n_rounds {
		// Claims that finished before this round store their final evaluations ahead of the
		// round proof in the same message.
		let mut reader = proof.message();
		while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
			multilinear_evals.push(claim_multilinear_evals);
		}
		sumcheck_verifier.receive_round_proof(&mut reader)?;

		let challenge = proof.sample();
		challenges.push(challenge);

		sumcheck_verifier.finish_round(challenge)?;

		// `round_no + 1` rounds are now complete.
		let observe_fri_comm = next_commit_round.is_some_and(|round| round == round_no + 1);
		if observe_fri_comm {
			let comm = proof
				.message()
				.read()
				.map_err(VerificationError::Transcript)?;
			fri_commitments.push(comm);
			next_commit_round = arities_iter.next().map(|arity| round_no + 1 + arity);
		}
	}

	// Read the final evaluations of the claims that finished in the last round.
	let mut reader = proof.message();
	while let Some(claim_multilinear_evals) = sumcheck_verifier.try_finish_claim(&mut reader)? {
		multilinear_evals.push(claim_multilinear_evals);
	}
	sumcheck_verifier.finish()?;

	// The sumcheck challenges double as the FRI folding challenges.
	let verifier = FRIVerifier::new(
		fri_params,
		merkle_scheme,
		codeword_commitment,
		&fri_commitments,
		&challenges,
	)?;
	let fri_final = verifier.verify(proof)?;

	Ok(BatchInterleavedSumcheckFRIOutput {
		challenges,
		multilinear_evals,
		fri_final,
	})
}
467
#[cfg(test)]
mod tests {
	use binius_field::BinaryField128b;

	use super::*;

	#[test]
	fn test_commit_meta_new_empty() {
		let n_multilins_by_vars = vec![];
		let commit_meta = CommitMeta::new(n_multilins_by_vars);

		assert_eq!(commit_meta.total_vars, 0);
		assert_eq!(commit_meta.total_multilins, 0);
		assert_eq!(commit_meta.n_multilins_by_vars, vec![]);
		assert_eq!(commit_meta.offsets_by_vars, vec![]);
	}

	#[test]
	fn test_commit_meta_new_single_variable() {
		let n_multilins_by_vars = vec![4];
		let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

		assert_eq!(commit_meta.total_vars, 2);
		assert_eq!(commit_meta.total_multilins, 4);
		assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
		assert_eq!(commit_meta.offsets_by_vars, vec![0]);
	}

	#[test]
	fn test_commit_meta_new_multiple_variables() {
		let n_multilins_by_vars = vec![3, 5, 2];

		let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

		// 3 * 1 + 5 * 2 + 2 * 4 = 21 elements, rounded up to 32 = 2^5.
		assert_eq!(commit_meta.total_vars, 5);
		assert_eq!(commit_meta.total_multilins, 10);
		assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
		assert_eq!(commit_meta.offsets_by_vars, vec![0, 3, 8]);
	}

	#[test]
	// The `count * (1 << n_vars)` terms are spelled out even when they are identities.
	#[allow(clippy::identity_op)]
	fn test_commit_meta_new_large_numbers() {
		let n_multilins_by_vars = vec![1_000_000, 2_000_000];
		let commit_meta = CommitMeta::new(n_multilins_by_vars.clone());

		let expected_total_elems = 1_000_000 * (1 << 0) + 2_000_000 * (1 << 1) as usize;
		let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

		assert_eq!(commit_meta.total_vars, expected_total_vars);
		assert_eq!(commit_meta.total_multilins, 3_000_000);
		assert_eq!(commit_meta.n_multilins_by_vars, n_multilins_by_vars);
		assert_eq!(commit_meta.offsets_by_vars, vec![0, 1_000_000]);
	}

	#[test]
	fn test_with_vars_empty() {
		let commit_meta = CommitMeta::with_vars(vec![]);

		assert_eq!(commit_meta.total_vars, 0);
		assert_eq!(commit_meta.total_multilins, 0);
		assert_eq!(commit_meta.n_multilins_by_vars(), &[]);
		assert_eq!(commit_meta.offsets_by_vars, vec![]);
	}

	#[test]
	fn test_with_vars_single_variable() {
		let commit_meta = CommitMeta::with_vars(vec![0, 0, 0, 0]);

		assert_eq!(commit_meta.total_vars, 2);
		assert_eq!(commit_meta.total_multilins, 4);
		assert_eq!(commit_meta.n_multilins_by_vars(), &[4]);
		assert_eq!(commit_meta.offsets_by_vars, vec![0]);
	}

	#[test]
	// The `count * (1 << n_vars)` terms are spelled out even when they are identities.
	#[allow(clippy::identity_op)]
	fn test_with_vars_multiple_variables() {
		let commit_meta = CommitMeta::with_vars(vec![2, 3, 3, 4]);

		let expected_total_elems = 1 * (1 << 2) + 2 * (1 << 3) + 1 * (1 << 4) as usize;
		let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

		assert_eq!(commit_meta.total_vars, expected_total_vars);
		assert_eq!(commit_meta.total_multilins, 4);
		assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 0, 1, 2, 1]);
		assert_eq!(commit_meta.offsets_by_vars, vec![0, 0, 0, 1, 3]);
	}

	#[test]
	fn test_with_vars_large_numbers() {
		let vars = vec![0; 1_000_000];
		let commit_meta = CommitMeta::with_vars(vars);

		let expected_total_elems = 1_000_000 * (1 << 0) as usize;
		let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

		assert_eq!(commit_meta.total_vars, expected_total_vars);
		assert_eq!(commit_meta.total_multilins, 1_000_000);
		assert_eq!(commit_meta.n_multilins_by_vars(), &[1_000_000]);
		assert_eq!(commit_meta.offsets_by_vars, vec![0]);
	}

	#[test]
	// The `count * (1 << n_vars)` terms are spelled out even when they are identities.
	#[allow(clippy::identity_op)]
	fn test_with_vars_mixed_variables() {
		let vars = vec![0, 1, 1, 2, 2, 2, 3];
		let commit_meta = CommitMeta::with_vars(vars);

		let expected_total_elems =
			1 * (1 << 0) + 2 * (1 << 1) + 3 * (1 << 2) + 1 * (1 << 3) as usize;
		let expected_total_vars = expected_total_elems.next_power_of_two().ilog2() as usize;

		assert_eq!(commit_meta.total_vars, expected_total_vars);
		assert_eq!(commit_meta.total_multilins, 7);
		assert_eq!(commit_meta.n_multilins_by_vars(), &[1, 2, 3, 1]);
		assert_eq!(commit_meta.offsets_by_vars, vec![0, 1, 3, 6]);
	}

	#[test]
	fn test_commit_meta_range_by_vars() {
		let commit_meta = CommitMeta::new(vec![3, 0, 2]);

		assert_eq!(commit_meta.max_n_vars(), 2);
		assert_eq!(commit_meta.range_by_vars(0), 0..3);
		assert_eq!(commit_meta.range_by_vars(1), 3..3);
		assert_eq!(commit_meta.range_by_vars(2), 3..5);
	}

	#[test]
	fn test_commit_meta_max_n_vars_empty() {
		assert_eq!(CommitMeta::new(vec![]).max_n_vars(), 0);
	}

	#[test]
	fn test_make_sumcheck_claim_descs_groups_by_n_vars() {
		let commit_meta = CommitMeta::with_vars([0, 1]);
		let claims = [PIOPSumcheckClaim {
			n_vars: 1,
			committed: 1,
			transparent: 1,
			sum: BinaryField128b::ZERO,
		}];

		let descs = make_sumcheck_claim_descs(&commit_meta, [0, 1].into_iter(), &claims).unwrap();

		assert_eq!(descs.len(), 2);
		assert_eq!(descs[0].committed_indices, 0..1);
		assert_eq!(descs[0].transparent_indices, 0..1);
		assert!(descs[0].composite_sums.is_empty());
		assert_eq!(descs[1].committed_indices, 1..2);
		assert_eq!(descs[1].transparent_indices, 1..2);
		assert_eq!(descs[1].composite_sums.len(), 1);
	}

	#[test]
	fn test_make_sumcheck_claim_descs_rejects_unsorted_transparents() {
		let commit_meta = CommitMeta::with_vars([0, 1]);

		let result =
			make_sumcheck_claim_descs::<BinaryField128b>(&commit_meta, [1, 0].into_iter(), &[]);

		assert!(matches!(result, Err(Error::TransparentsNotSorted)));
	}
}