binius_core/merkle_tree/scheme.rs

use std::{array, fmt::Debug, marker::PhantomData};

use binius_field::TowerField;
use binius_hash::{HashBuffer, PseudoCompressionFunction};
use binius_utils::{
    bail,
    checked_arithmetics::{log2_ceil_usize, log2_strict_usize},
    SerializationMode, SerializeBytes,
};
use bytes::Buf;
use digest::{core_api::BlockSizeUser, Digest, Output};
use getset::Getters;

use super::{
    errors::{Error, VerificationError},
    merkle_tree_vcs::MerkleTreeScheme,
};
use crate::transcript::TranscriptReader;

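/// A binary Merkle tree commitment scheme.
///
/// Leaves are produced by hashing batches of field elements with the digest `H`; sibling
/// digests are then combined pairwise with the 2-to-1 pseudo-compression function `C` up to
/// the root. `T` is the committed element type and appears only through `PhantomData`.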
#[derive(Debug, Getters)]
pub struct BinaryMerkleTreeScheme<T, H, C> {
    #[getset(get = "pub")]
    compression: C,
    _phantom: PhantomData<fn() -> (T, H)>,
}

impl<T, H, C> BinaryMerkleTreeScheme<T, H, C> {
    pub fn new(compression: C) -> Self {
        Self {
            compression,
            _phantom: PhantomData,
        }
    }
}

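// Verification-side operations of the scheme: sizing opening proofs and checking the root
// against a full vector, an internal layer, or a single opening.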
impl<F, H, C> MerkleTreeScheme<F> for BinaryMerkleTreeScheme<F, H, C>
where
    F: TowerField,
    H: Digest + BlockSizeUser,
    C: PseudoCompressionFunction<Output<H>, 2> + Sync,
{
    type Digest = Output<H>;

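    // Depth of the pre-committed layer the verifier should request: roughly log2(n_queries),
    // capped at the tree depth. Beyond that, the 2^depth layer digests cost at least as much
    // as the per-query branch digests they save (see `proof_size`).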
    fn optimal_verify_layer(&self, n_queries: usize, tree_depth: usize) -> usize {
        log2_ceil_usize(n_queries).min(tree_depth)
    }

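    // Total size in bytes of the opening proofs: the branch digests revealed across the
    // `n_queries` openings plus the `2^layer_depth` digests of the pre-committed layer, each
    // digest being `H::output_size()` bytes.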
    fn proof_size(&self, len: usize, n_queries: usize, layer_depth: usize) -> Result<usize, Error> {
        if !len.is_power_of_two() {
            bail!(Error::PowerOfTwoLengthRequired)
        }

        let log_len = log2_strict_usize(len);

        if layer_depth > log_len {
            bail!(Error::IncorrectLayerDepth)
        }

        Ok(((log_len - layer_depth - 1) * n_queries + (1 << layer_depth))
            * <H as Digest>::output_size())
    }

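    // Recomputes the root directly from the committed vector: hash each `batch_size` chunk
    // into a leaf digest, fold pairwise up to the root, and compare against `root`.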
    fn verify_vector(
        &self,
        root: &Self::Digest,
        data: &[F],
        batch_size: usize,
    ) -> Result<(), Error> {
        if data.len() % batch_size != 0 {
            bail!(Error::IncorrectBatchSize);
        }

        let mut digests = data
            .chunks(batch_size)
            .map(|chunk| hash_field_elems::<_, H>(chunk))
            .collect::<Vec<_>>();

        fold_digests_vector_inplace(&self.compression, &mut digests)?;
        if digests[0] != *root {
            bail!(VerificationError::InvalidProof)
        }
        Ok(())
    }

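    // Checks that a claimed internal layer of `2^layer_depth` digests folds up to the
    // committed root.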
    fn verify_layer(
        &self,
        root: &Self::Digest,
        layer_depth: usize,
        layer_digests: &[Self::Digest],
    ) -> Result<(), Error> {
        if 1 << layer_depth != layer_digests.len() {
            bail!(VerificationError::IncorrectVectorLength)
        }

        let mut digests = layer_digests.to_owned();

        fold_digests_vector_inplace(&self.compression, &mut digests)?;

        if digests[0] != *root {
            bail!(VerificationError::InvalidProof)
        }
        Ok(())
    }

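    // Verifies a single opening against a previously verified layer: hash the claimed leaf
    // values, then walk up the tree with the branch digests read from the transcript, using
    // the parity of `index` at each level to place the running digest as left or right child.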
    fn verify_opening<B: Buf>(
        &self,
        mut index: usize,
        values: &[F],
        layer_depth: usize,
        tree_depth: usize,
        layer_digests: &[Self::Digest],
        proof: &mut TranscriptReader<B>,
    ) -> Result<(), Error> {
        if (1 << layer_depth) != layer_digests.len() {
            bail!(VerificationError::IncorrectVectorLength);
        }

        if index >= (1 << tree_depth) {
            bail!(Error::IndexOutOfRange {
                max: (1 << tree_depth) - 1
            });
        }

        let mut leaf_digest = hash_field_elems::<_, H>(values);
        for branch_node in proof.read_vec(tree_depth - layer_depth)? {
            leaf_digest = self.compression.compress(if index & 1 == 0 {
                [leaf_digest, branch_node]
            } else {
                [branch_node, leaf_digest]
            });
            index >>= 1;
        }

        (leaf_digest == layer_digests[index])
            .then_some(())
            .ok_or_else(|| VerificationError::InvalidProof.into())
    }
}

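/// Folds a power-of-two-length vector of digests pairwise, in place, until the root is left
/// in `digests[0]`. Each pass writes the next layer into the front of the slice.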
fn fold_digests_vector_inplace<C, D>(compression: &C, digests: &mut [D]) -> Result<(), Error>
where
    C: PseudoCompressionFunction<D, 2> + Sync,
    D: Clone + Default + Send + Sync + Debug,
{
    if !digests.len().is_power_of_two() {
        bail!(Error::PowerOfTwoLengthRequired);
    }

    let mut len = digests.len() / 2;

    while len != 0 {
        for i in 0..len {
            digests[i] = compression.compress(array::from_fn(|j| digests[2 * i + j].clone()));
        }
        len /= 2;
    }

    Ok(())
}

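/// Hashes a slice of field elements by serializing each one in canonical tower form into the
/// digest state through a `HashBuffer`.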
fn hash_field_elems<F, H>(elems: &[F]) -> Output<H>
where
    F: TowerField,
    H: Digest + BlockSizeUser,
{
    let mut hasher = H::new();
    {
        let mut buffer = HashBuffer::new(&mut hasher);
        for elem in elems {
            let mode = SerializationMode::CanonicalTower;
            SerializeBytes::serialize(elem, &mut buffer, mode)
                .expect("HashBuffer has infinite capacity");
        }
    }
    hasher.finalize()
}

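// Hedged usage sketch, not part of the original module: a round trip through `verify_vector`
// using a toy compression built from Sha256. Assumptions beyond what this file shows: `sha2`
// is available as a dev-dependency, `binius_field::BinaryField128b::new(u128)` exists, and
// `PseudoCompressionFunction` has the `fn compress(&self, [T; N]) -> T` shape used above.
#[cfg(test)]
mod usage_sketch {
    use binius_field::BinaryField128b;
    use sha2::Sha256;

    use super::*;

    /// Toy 2-to-1 compression: hash the concatenation of the two child digests.
    /// For illustration only; not a vetted pseudo-compression function.
    #[derive(Clone)]
    struct ConcatCompression;

    impl PseudoCompressionFunction<Output<Sha256>, 2> for ConcatCompression {
        fn compress(&self, values: [Output<Sha256>; 2]) -> Output<Sha256> {
            let mut hasher = Sha256::new();
            hasher.update(&values[0]);
            hasher.update(&values[1]);
            hasher.finalize()
        }
    }

    #[test]
    fn verify_vector_accepts_recomputed_root() {
        // Eight field elements committed in batches of two, i.e. a tree with four leaves.
        let data: Vec<BinaryField128b> = (0u128..8).map(BinaryField128b::new).collect();

        // Recompute the root the same way `verify_vector` does internally.
        let mut digests = data
            .chunks(2)
            .map(|chunk| hash_field_elems::<_, Sha256>(chunk))
            .collect::<Vec<_>>();
        fold_digests_vector_inplace(&ConcatCompression, &mut digests).unwrap();

        let scheme = BinaryMerkleTreeScheme::<BinaryField128b, Sha256, _>::new(ConcatCompression);
        assert!(scheme.verify_vector(&digests[0], &data, 2).is_ok());
    }
}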