1use std::{fmt::Debug, mem::MaybeUninit};
5
6use binius_field::Field;
7use binius_utils::{
8 checked_arithmetics::log2_strict_usize,
9 mem::slice_assume_init_mut,
10 rand::par_rand,
11 rayon::{prelude::*, slice::ParallelSlice},
12};
13use digest::{Digest, FixedOutputReset, Output, block_api::BlockSizeUser};
14use rand::{CryptoRng, Rng, rngs::StdRng};
15
16use super::{
17 compress::PseudoCompressionFunction, parallel_compression::ParallelPseudoCompression,
18 parallel_digest::ParallelDigest,
19};
20
21pub trait HashSuite {
29 type LeafHash: Digest + BlockSizeUser + FixedOutputReset + Send;
31 type Compression: PseudoCompressionFunction<Output<Self::LeafHash>, 2> + Default;
33 type ParLeafHash: ParallelDigest<Digest = Self::LeafHash> + Default;
35 type ParCompression: ParallelPseudoCompression<Output<Self::LeafHash>, 2, Compression = Self::Compression>
37 + Default;
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum Error {
42 #[error("Index exceeds Merkle tree base size: {max}")]
43 IndexOutOfRange { max: usize },
44 #[error("values length must be a multiple of the batch size")]
45 IncorrectBatchSize,
46 #[error("The argument length must be a power of two.")]
47 PowerOfTwoLengthRequired,
48 #[error("The layer does not exist in the Merkle tree")]
49 IncorrectLayerDepth,
50}
51
/// A binary Merkle tree stored as one flat vector of nodes plus optional per-leaf salts.
///
/// `D` is the node (digest) type and `F` is the salt element type.
#[derive(Clone, Debug)]
pub struct BinaryMerkleTree<D, F> {
    /// Base-2 logarithm of the number of leaves.
    pub log_len: usize,

    /// All nodes of the tree in a single vector: the leaf layer first, followed by each
    /// successively smaller layer, with the root as the last element.
    pub inner_nodes: Vec<D>,

    /// Salt elements for all leaves, concatenated; each leaf owns a contiguous run of
    /// `salts.len() >> log_len` elements. Empty when the tree is unsalted.
    pub salts: Vec<F>,
}
66
67pub fn build<F, H, R>(
68 elements: &[F],
69 batch_size: usize,
70 salt_len: usize,
71 rng: R,
72) -> Result<BinaryMerkleTree<Output<H::LeafHash>, F>, Error>
73where
74 F: Field,
75 H: HashSuite,
76 R: Rng + CryptoRng,
77{
78 if !elements.len().is_multiple_of(batch_size) {
79 return Err(Error::IncorrectBatchSize);
80 }
81
82 let len = elements.len() / batch_size;
83
84 if !len.is_power_of_two() {
85 return Err(Error::PowerOfTwoLengthRequired);
86 }
87
88 build_from_iterator::<_, H, _, _>(
89 elements
90 .par_chunks(batch_size)
91 .map(|chunk| chunk.iter().copied()),
92 batch_size,
93 salt_len,
94 rng,
95 )
96}
97
98pub fn build_from_iterator<F, H, R, ParIter>(
99 iterated_chunks: ParIter,
100 n_items_per_input: usize,
101 salt_len: usize,
102 mut rng: R,
103) -> Result<BinaryMerkleTree<Output<H::LeafHash>, F>, Error>
104where
105 F: Field,
106 H: HashSuite,
107 R: Rng + CryptoRng,
108 ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F, IntoIter: Send>>,
109{
110 let log_len = log2_strict_usize(iterated_chunks.len()); let salts =
114 par_rand::<StdRng, _, _>(salt_len << log_len, &mut rng, F::random).collect::<Vec<_>>();
115
116 let total_length = (1 << (log_len + 1)) - 1;
117 let mut inner_nodes = Vec::with_capacity(total_length);
118 hash_leaves::<F, H, _>(
119 iterated_chunks,
120 n_items_per_input,
121 &mut inner_nodes.spare_capacity_mut()[..(1 << log_len)],
122 &salts,
123 );
124
125 let (prev_layer, mut remaining) = inner_nodes.spare_capacity_mut().split_at_mut(1 << log_len);
126
127 let mut prev_layer = unsafe {
128 slice_assume_init_mut(prev_layer)
130 };
131 let parallel_compression = H::ParCompression::default();
132 for i in 1..(log_len + 1) {
133 let (next_layer, next_remaining) = remaining.split_at_mut(1 << (log_len - i));
134 remaining = next_remaining;
135
136 parallel_compression.parallel_compress(prev_layer, next_layer);
137
138 prev_layer = unsafe {
139 slice_assume_init_mut(next_layer)
141 };
142 }
143
144 unsafe {
145 inner_nodes.set_len(total_length);
149 }
150 Ok(BinaryMerkleTree {
151 log_len,
152 inner_nodes,
153 salts,
154 })
155}
156
157impl<D: Clone, F> BinaryMerkleTree<D, F> {
158 pub fn root(&self) -> D {
159 self.inner_nodes
160 .last()
161 .expect("MerkleTree inner nodes can't be empty")
162 .clone()
163 }
164
165 pub fn get_salt(&self, index: usize) -> &[F] {
170 assert!(index < (1 << self.log_len));
171 let salt_len = self.salts.len() >> self.log_len;
172 &self.salts[index * salt_len..(index + 1) * salt_len]
173 }
174
175 pub fn layer(&self, layer_depth: usize) -> Result<&[D], Error> {
176 if layer_depth > self.log_len {
177 return Err(Error::IncorrectLayerDepth);
178 }
179 let range_start = self.inner_nodes.len() + 1 - (1 << (layer_depth + 1));
180
181 Ok(&self.inner_nodes[range_start..range_start + (1 << layer_depth)])
182 }
183
184 pub fn branch(&self, index: usize, layer_depth: usize) -> Result<Vec<D>, Error> {
188 if index >= 1 << self.log_len || layer_depth > self.log_len {
189 return Err(Error::IndexOutOfRange {
190 max: (1 << self.log_len) - 1,
191 });
192 }
193
194 let branch = (0..self.log_len - layer_depth)
195 .map(|j| {
196 let node_index = (((1 << j) - 1) << (self.log_len + 1 - j)) | (index >> j) ^ 1;
197 self.inner_nodes[node_index].clone()
198 })
199 .collect();
200
201 Ok(branch)
202 }
203}
204
205#[tracing::instrument("hash_leaves", skip_all, level = "debug")]
217fn hash_leaves<F, H, ParIter>(
218 iterated_chunks: ParIter,
219 n_items_per_input: usize,
220 digests: &mut [MaybeUninit<Output<H::LeafHash>>],
221 salts: &[F],
222) where
223 F: Field,
224 H: HashSuite,
225 ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F, IntoIter: Send>>,
226{
227 if salts.is_empty() {
228 let hasher = H::ParLeafHash::default();
231 hasher.digest_with_const_len(n_items_per_input, iterated_chunks, digests);
232 } else {
233 assert!(salts.len().is_multiple_of(digests.len()));
234
235 let salt_len = salts.len() / digests.len();
236
237 let salted_iter = iterated_chunks
239 .zip(salts.par_chunks(salt_len))
240 .map(|(chunk, salt)| chunk.into_iter().chain(salt.iter().copied()));
241
242 let hasher = H::ParLeafHash::default();
244 hasher.digest_with_const_len(n_items_per_input + salt_len, salted_iter, digests);
245 }
246}