use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_field::TowerField;
use binius_hash::{multi_digest::ParallelDigest, PseudoCompressionFunction};
use binius_maybe_rayon::{prelude::*, slice::ParallelSlice};
use binius_utils::{bail, checked_arithmetics::log2_strict_usize};
use digest::{crypto_common::BlockSizeUser, FixedOutputReset, Output};
use tracing::instrument;

use super::errors::Error;

/// A binary Merkle tree of digests of type `D`.
#[derive(Debug, Clone)]
pub struct BinaryMerkleTree<D> {
    /// Base-2 logarithm of the number of leaf digests.
    pub log_len: usize,
    /// All `2^(log_len + 1) - 1` nodes of the tree, stored layer by layer: the leaf
    /// digests first, then each successively compressed layer, ending with the root.
    pub inner_nodes: Vec<D>,
}

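/// Builds a binary Merkle tree over a flat slice of field elements.
///
/// The slice is split into `elements.len() / batch_size` leaves of `batch_size`
/// elements each; the number of leaves must be a power of two. Each leaf is hashed
/// with `H`, and parent nodes are produced by `compression`.
///
/// A minimal usage sketch (not compiled): `MyField`, `MyHasher`, `MyCompression`, and
/// `some_committed_values` are placeholder names standing in for whichever tower
/// field, parallel digest, pseudo-compression function, and data source the caller
/// actually uses.
///
/// ```ignore
/// // 1024 elements with a batch size of 4 give 256 leaves, a power of two.
/// let elements: Vec<MyField> = some_committed_values();
/// let compression = MyCompression::default();
/// let tree = build::<MyField, MyHasher, _>(&compression, &elements, 4)?;
/// let root = tree.root();
/// ```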
pub fn build<F, H, C>(
    compression: &C,
    elements: &[F],
    batch_size: usize,
) -> Result<BinaryMerkleTree<Output<H::Digest>>, Error>
where
    F: TowerField,
    H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
    C: PseudoCompressionFunction<Output<H::Digest>, 2> + Sync,
{
    if elements.len() % batch_size != 0 {
        bail!(Error::IncorrectBatchSize);
    }

    let len = elements.len() / batch_size;

    if !len.is_power_of_two() {
        bail!(Error::PowerOfTwoLengthRequired);
    }

    let log_len = log2_strict_usize(len);

    internal_build(
        compression,
        |inner_nodes| hash_interleaved::<_, H>(elements, inner_nodes),
        log_len,
    )
}

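/// Shared tree-construction routine: allocates the flattened node array once, lets
/// `hash_leaves` fill the first `2^log_len` (still uninitialized) slots with leaf
/// digests, then repeatedly compresses each layer into the next slice of spare
/// capacity until only the root remains.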
fn internal_build<Digest, C>(
    compression: &C,
    hash_leaves: impl FnOnce(&mut [MaybeUninit<Digest>]) -> Result<(), Error>,
    log_len: usize,
) -> Result<BinaryMerkleTree<Digest>, Error>
where
    Digest: Clone + Send + Sync,
    C: PseudoCompressionFunction<Digest, 2> + Sync,
{
    // A complete binary tree with 2^log_len leaves has 2^(log_len + 1) - 1 nodes.
    let total_length = (1 << (log_len + 1)) - 1;
    let mut inner_nodes = Vec::with_capacity(total_length);

    // Fill the leaf layer directly into the vector's spare capacity.
    hash_leaves(&mut inner_nodes.spare_capacity_mut()[..(1 << log_len)])?;

    let (prev_layer, mut remaining) = inner_nodes.spare_capacity_mut().split_at_mut(1 << log_len);

    let mut prev_layer = unsafe {
        // SAFETY: `hash_leaves` initialized every element of the leaf layer above.
        slice_assume_init_mut(prev_layer)
    };
    for i in 1..=log_len {
        let (next_layer, next_remaining) = remaining.split_at_mut(1 << (log_len - i));
        remaining = next_remaining;

        compress_layer(compression, prev_layer, next_layer);

        prev_layer = unsafe {
            // SAFETY: `compress_layer` just wrote every element of `next_layer`.
            slice_assume_init_mut(next_layer)
        };
    }

    unsafe {
        // SAFETY: every layer, including the root, has been initialized by this point,
        // so the full capacity of `inner_nodes` is valid.
        inner_nodes.set_len(total_length);
    }
    Ok(BinaryMerkleTree {
        log_len,
        inner_nodes,
    })
}

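/// Builds a binary Merkle tree from a parallel iterator of leaf chunks.
///
/// Unlike [`build`], the leaf data does not have to live in one contiguous slice:
/// `iterated_chunks` should yield one chunk per leaf (`2^log_len` in total), each
/// chunk is hashed with `H`, and the layers are then compressed with `compression`.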
#[instrument("BinaryMerkleTree::build", skip_all, level = "debug")]
pub fn build_from_iterator<F, H, C, ParIter>(
    compression: &C,
    iterated_chunks: ParIter,
    log_len: usize,
) -> Result<BinaryMerkleTree<Output<H::Digest>>, Error>
where
    F: TowerField,
    H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
    C: PseudoCompressionFunction<Output<H::Digest>, 2> + Sync,
    ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
    internal_build(
        compression,
        |inner_nodes| hash_iterated::<F, H, _>(iterated_chunks, inner_nodes),
        log_len,
    )
}

impl<D: Clone> BinaryMerkleTree<D> {
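    /// Returns the root digest, which is stored as the last inner node.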
    pub fn root(&self) -> D {
        self.inner_nodes
            .last()
            .expect("MerkleTree inner nodes can't be empty")
            .clone()
    }

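    /// Returns the full layer of nodes `layer_depth` levels below the root, so depth 0
    /// is the root itself and depth `log_len` is the leaf layer.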
    pub fn layer(&self, layer_depth: usize) -> Result<&[D], Error> {
        if layer_depth > self.log_len {
            bail!(Error::IncorrectLayerDepth);
        }
        // The deepest `layer_depth + 1` layers counted from the root occupy the final
        // `2^(layer_depth + 1) - 1` slots of `inner_nodes`; the requested layer of
        // `2^layer_depth` nodes sits at the start of that block.
        let range_start = self.inner_nodes.len() + 1 - (1 << (layer_depth + 1));

        Ok(&self.inner_nodes[range_start..range_start + (1 << layer_depth)])
    }

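    /// Returns the authentication path for leaf `index` up to the layer `layer_depth`
    /// levels below the root: the sibling digest at each level, ordered from the leaf
    /// layer upwards, as needed to recompute a node of that layer from the leaf.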
    pub fn branch(&self, index: usize, layer_depth: usize) -> Result<Vec<D>, Error> {
        if index >= 1 << self.log_len || layer_depth > self.log_len {
            return Err(Error::IndexOutOfRange {
                max: (1 << self.log_len) - 1,
            });
        }

        let branch = (0..self.log_len - layer_depth)
            .map(|j| {
                // Layer `j` (counted from the leaves) starts at offset
                // `(2^j - 1) * 2^(log_len + 1 - j)`; within it, the sibling of the
                // leaf's ancestor sits at position `(index >> j) ^ 1`.
                let node_index = (((1 << j) - 1) << (self.log_len + 1 - j)) | ((index >> j) ^ 1);
                self.inner_nodes[node_index].clone()
            })
            .collect();

        Ok(branch)
    }
}

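/// Compresses pairs of adjacent digests from `prev_layer` into `next_layer`, which
/// must be exactly half as long; every element of `next_layer` is written.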
#[tracing::instrument("MerkleTree::compress_layer", skip_all, level = "debug")]
fn compress_layer<D, C>(compression: &C, prev_layer: &[D], next_layer: &mut [MaybeUninit<D>])
where
    D: Clone + Send + Sync,
    C: PseudoCompressionFunction<D, 2> + Sync,
{
    prev_layer
        .par_chunks_exact(2)
        .zip(next_layer.par_iter_mut())
        .for_each(|(prev_pair, next_digest)| {
            next_digest.write(compression.compress(array::from_fn(|i| prev_pair[i].clone())));
        })
}

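/// Hashes an interleaved slice of elements into `digests`: the slice is split into
/// `digests.len()` equally sized contiguous chunks and each chunk is hashed into the
/// corresponding (uninitialized) digest slot.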
#[tracing::instrument("hash_interleaved", skip_all, level = "debug")]
fn hash_interleaved<F, H>(
    elems: &[F],
    digests: &mut [MaybeUninit<Output<H::Digest>>],
) -> Result<(), Error>
where
    F: TowerField,
    H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
{
    if elems.len() % digests.len() != 0 {
        return Err(Error::IncorrectVectorLen {
            expected: digests.len(),
        });
    }

    let hash_data_iter = elems
        .par_chunks(elems.len() / digests.len())
        .map(|s| s.iter().copied());
    hash_iterated::<_, H, _>(hash_data_iter, digests)
}

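/// Hashes each chunk yielded by `iterated_chunks` into the corresponding slot of
/// `digests`, using the parallel digest implementation `H`.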
fn hash_iterated<F, H, ParIter>(
    iterated_chunks: ParIter,
    digests: &mut [MaybeUninit<Output<H::Digest>>],
) -> Result<(), Error>
where
    F: TowerField,
    H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
    ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
    let hasher = H::new();
    hasher.digest(iterated_chunks, digests);

    Ok(())
}

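/// Assumes a slice of `MaybeUninit<T>` is fully initialized and reborrows it as
/// `&mut [T]`, a local equivalent of the (currently unstable)
/// `MaybeUninit::slice_assume_init_mut`.
///
/// # Safety
///
/// The caller must guarantee that every element of `slice` is initialized; calling
/// this on uninitialized elements is undefined behavior.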
pub const unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and the caller guarantees
    // the elements are initialized, so the transmuted slice reference is valid.
    std::mem::transmute(slice)
}