binius_core/merkle_tree/
scheme.rs1use std::{array, fmt::Debug, marker::PhantomData};
4
5use binius_field::TowerField;
6use binius_hash::{HashBuffer, PseudoCompressionFunction};
7use binius_utils::{
8 SerializationMode, SerializeBytes, bail,
9 checked_arithmetics::{log2_ceil_usize, log2_strict_usize},
10};
11use bytes::Buf;
12use digest::{Digest, Output, core_api::BlockSizeUser};
13use getset::Getters;
14
15use super::{
16 errors::{Error, VerificationError},
17 merkle_tree_vcs::MerkleTreeScheme,
18};
19use crate::transcript::TranscriptReader;
20
21#[derive(Debug, Getters)]
22pub struct BinaryMerkleTreeScheme<T, H, C> {
23 #[getset(get = "pub")]
24 compression: C,
25 _phantom: PhantomData<fn() -> (T, H)>,
28}
29
30impl<T, H, C> BinaryMerkleTreeScheme<T, H, C> {
31 pub fn new(compression: C) -> Self {
32 Self {
33 compression,
34 _phantom: PhantomData,
35 }
36 }
37}
38
39impl<F, H, C> MerkleTreeScheme<F> for BinaryMerkleTreeScheme<F, H, C>
40where
41 F: TowerField,
42 H: Digest + BlockSizeUser,
43 C: PseudoCompressionFunction<Output<H>, 2> + Sync,
44{
45 type Digest = Output<H>;
46
47 fn optimal_verify_layer(&self, n_queries: usize, tree_depth: usize) -> usize {
49 log2_ceil_usize(n_queries).min(tree_depth)
50 }
51
52 fn proof_size(&self, len: usize, n_queries: usize, layer_depth: usize) -> Result<usize, Error> {
53 if !len.is_power_of_two() {
54 bail!(Error::PowerOfTwoLengthRequired)
55 }
56
57 let log_len = log2_strict_usize(len);
58
59 if layer_depth > log_len {
60 bail!(Error::IncorrectLayerDepth)
61 }
62
63 Ok(((log_len - layer_depth - 1) * n_queries + (1 << layer_depth))
64 * <H as Digest>::output_size())
65 }
66
67 fn verify_vector(
68 &self,
69 root: &Self::Digest,
70 data: &[F],
71 batch_size: usize,
72 ) -> Result<(), Error> {
73 if data.len() % batch_size != 0 {
74 bail!(Error::IncorrectBatchSize);
75 }
76
77 let mut digests = data
78 .chunks(batch_size)
79 .map(|chunk| hash_field_elems::<_, H>(chunk))
80 .collect::<Vec<_>>();
81
82 fold_digests_vector_inplace(&self.compression, &mut digests)?;
83 if digests[0] != *root {
84 bail!(VerificationError::InvalidProof)
85 }
86 Ok(())
87 }
88
89 fn verify_layer(
90 &self,
91 root: &Self::Digest,
92 layer_depth: usize,
93 layer_digests: &[Self::Digest],
94 ) -> Result<(), Error> {
95 if 1 << layer_depth != layer_digests.len() {
96 bail!(VerificationError::IncorrectVectorLength)
97 }
98
99 let mut digests = layer_digests.to_owned();
100
101 fold_digests_vector_inplace(&self.compression, &mut digests)?;
102
103 if digests[0] != *root {
104 bail!(VerificationError::InvalidProof)
105 }
106 Ok(())
107 }
108
109 fn verify_opening<B: Buf>(
110 &self,
111 mut index: usize,
112 values: &[F],
113 layer_depth: usize,
114 tree_depth: usize,
115 layer_digests: &[Self::Digest],
116 proof: &mut TranscriptReader<B>,
117 ) -> Result<(), Error> {
118 if (1 << layer_depth) != layer_digests.len() {
119 bail!(VerificationError::IncorrectVectorLength);
120 }
121
122 if index >= (1 << tree_depth) {
123 bail!(Error::IndexOutOfRange {
124 max: (1 << tree_depth) - 1
125 });
126 }
127
128 let mut leaf_digest = hash_field_elems::<_, H>(values);
129 for branch_node in proof.read_vec(tree_depth - layer_depth)? {
130 leaf_digest = self.compression.compress(if index & 1 == 0 {
131 [leaf_digest, branch_node]
132 } else {
133 [branch_node, leaf_digest]
134 });
135 index >>= 1;
136 }
137
138 (leaf_digest == layer_digests[index])
139 .then_some(())
140 .ok_or_else(|| VerificationError::InvalidProof.into())
141 }
142}
143
144fn fold_digests_vector_inplace<C, D>(compression: &C, digests: &mut [D]) -> Result<(), Error>
146where
147 C: PseudoCompressionFunction<D, 2> + Sync,
148 D: Clone + Default + Send + Sync + Debug,
149{
150 if !digests.len().is_power_of_two() {
151 bail!(Error::PowerOfTwoLengthRequired);
152 }
153
154 let mut len = digests.len() / 2;
155
156 while len != 0 {
157 for i in 0..len {
158 digests[i] = compression.compress(array::from_fn(|j| digests[2 * i + j].clone()));
159 }
160 len /= 2;
161 }
162
163 Ok(())
164}
165
166fn hash_field_elems<F, H>(elems: &[F]) -> Output<H>
168where
169 F: TowerField,
170 H: Digest + BlockSizeUser,
171{
172 let mut hasher = H::new();
173 {
174 let mut buffer = HashBuffer::new(&mut hasher);
175 for elem in elems {
176 let mode = SerializationMode::CanonicalTower;
177 SerializeBytes::serialize(elem, &mut buffer, mode)
178 .expect("HashBuffer has infinite capacity");
179 }
180 }
181 hasher.finalize()
182}