use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_field::TowerField;
use binius_hash::{HashBuffer, PseudoCompressionFunction};
use binius_maybe_rayon::{prelude::*, slice::ParallelSlice};
use binius_utils::{
	bail, checked_arithmetics::log2_strict_usize, SerializationMode, SerializeBytes,
};
use digest::{crypto_common::BlockSizeUser, Digest, FixedOutputReset, Output};
use tracing::instrument;

use super::errors::Error;

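/// A binary Merkle tree of digests of type `D`, stored as a single flat vector.
///
/// The leaf digests come first, followed by each successively smaller layer of
/// internal nodes, with the root stored last. For `log_len` levels of leaves the
/// vector holds `2^(log_len + 1) - 1` nodes in total.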
#[derive(Debug, Clone)]
pub struct BinaryMerkleTree<D> {
	/// Base-2 logarithm of the number of leaf digests.
	pub log_len: usize,
	/// Flattened layers: leaf digests first, root last; `2^(log_len + 1) - 1` entries.
	pub inner_nodes: Vec<D>,
}

/// Builds a binary Merkle tree over `elements`, hashing `batch_size` field
/// elements into each leaf digest; the resulting number of leaves
/// (`elements.len() / batch_size`) must be a power of two.
pub fn build<F, H, C>(
	compression: &C,
	elements: &[F],
	batch_size: usize,
) -> Result<BinaryMerkleTree<Output<H>>, Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
	C: PseudoCompressionFunction<Output<H>, 2> + Sync,
{
	if elements.len() % batch_size != 0 {
		bail!(Error::IncorrectBatchSize);
	}

	let len = elements.len() / batch_size;

	if !len.is_power_of_two() {
		bail!(Error::PowerOfTwoLengthRequired);
	}

	let log_len = log2_strict_usize(len);

	internal_build(
		compression,
		|inner_nodes| hash_interleaved::<_, H>(elements, inner_nodes),
		log_len,
	)
}

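/// Shared tree-construction core: `hash_leaves` fills the first `2^log_len`
/// slots of the node vector with leaf digests, then each subsequent layer is
/// produced by compressing adjacent pairs of the previous layer until a single
/// root remains.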
fn internal_build<Digest, C>(
	compression: &C,
	hash_leaves: impl FnOnce(&mut [MaybeUninit<Digest>]) -> Result<(), Error>,
	log_len: usize,
) -> Result<BinaryMerkleTree<Digest>, Error>
where
	Digest: Clone + Send + Sync,
	C: PseudoCompressionFunction<Digest, 2> + Sync,
{
	let total_length = (1 << (log_len + 1)) - 1;
	let mut inner_nodes = Vec::with_capacity(total_length);

	hash_leaves(&mut inner_nodes.spare_capacity_mut()[..(1 << log_len)])?;

	let (prev_layer, mut remaining) = inner_nodes.spare_capacity_mut().split_at_mut(1 << log_len);

	// SAFETY: the leaf layer was fully initialized by `hash_leaves` above.
	let mut prev_layer = unsafe { slice_assume_init_mut(prev_layer) };
	for i in 1..=log_len {
		let (next_layer, next_remaining) = remaining.split_at_mut(1 << (log_len - i));
		remaining = next_remaining;

		compress_layer(compression, prev_layer, next_layer);

		// SAFETY: `compress_layer` wrote a digest into every slot of `next_layer`.
		prev_layer = unsafe { slice_assume_init_mut(next_layer) };
	}

	// SAFETY: all `total_length` nodes (the leaves plus every compressed layer)
	// have been initialized at this point.
	unsafe {
		inner_nodes.set_len(total_length);
	}
	Ok(BinaryMerkleTree {
		log_len,
		inner_nodes,
	})
}

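/// Builds a Merkle tree whose leaves are obtained by hashing the chunks yielded
/// by a parallel iterator; the iterator is expected to yield `2^log_len` chunks.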
#[instrument("BinaryMerkleTree::build", skip_all, level = "debug")]
pub fn build_from_iterator<F, H, C, ParIter>(
	compression: &C,
	iterated_chunks: ParIter,
	log_len: usize,
) -> Result<BinaryMerkleTree<Output<H>>, Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
	C: PseudoCompressionFunction<Output<H>, 2> + Sync,
	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
	internal_build(
		compression,
		|inner_nodes| hash_iterated::<F, H, _>(iterated_chunks, inner_nodes),
		log_len,
	)
}

impl<D: Clone> BinaryMerkleTree<D> {
	/// Returns the root digest of the tree.
	pub fn root(&self) -> D {
		self.inner_nodes
			.last()
			.expect("MerkleTree inner nodes can't be empty")
			.clone()
	}

	/// Returns the layer of nodes at `layer_depth`, where depth 0 is the root
	/// layer and depth `log_len` is the leaf layer.
	pub fn layer(&self, layer_depth: usize) -> Result<&[D], Error> {
		if layer_depth > self.log_len {
			bail!(Error::IncorrectLayerDepth);
		}
		let range_start = self.inner_nodes.len() + 1 - (1 << (layer_depth + 1));

		Ok(&self.inner_nodes[range_start..range_start + (1 << layer_depth)])
	}

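	/// Returns the authentication path for the leaf at `index`: one sibling
	/// digest per level, ascending from the leaf layer and stopping at
	/// `layer_depth`, for `log_len - layer_depth` digests in total.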
	pub fn branch(&self, index: usize, layer_depth: usize) -> Result<Vec<D>, Error> {
		if index >= 1 << self.log_len || layer_depth > self.log_len {
			return Err(Error::IndexOutOfRange {
				max: (1 << self.log_len) - 1,
			});
		}

		let branch = (0..self.log_len - layer_depth)
			.map(|j| {
				// Layer `j` starts at offset `((1 << j) - 1) << (log_len + 1 - j)`
				// in the flat node vector; `(index >> j) ^ 1` selects the sibling
				// of the leaf's ancestor within that layer.
				let node_index = (((1 << j) - 1) << (self.log_len + 1 - j)) | (index >> j) ^ 1;
				self.inner_nodes[node_index].clone()
			})
			.collect();

		Ok(branch)
	}
}

#[tracing::instrument("MerkleTree::compress_layer", skip_all, level = "debug")]
fn compress_layer<D, C>(compression: &C, prev_layer: &[D], next_layer: &mut [MaybeUninit<D>])
where
	D: Clone + Send + Sync,
	C: PseudoCompressionFunction<D, 2> + Sync,
{
	prev_layer
		.par_chunks_exact(2)
		.zip(next_layer.par_iter_mut())
		.for_each(|(prev_pair, next_digest)| {
			next_digest.write(compression.compress(array::from_fn(|i| prev_pair[i].clone())));
		})
}

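/// Hashes `elems` into `digests.len()` leaf digests: the elements are split into
/// equally sized chunks, one chunk per output digest. Fails if `elems.len()` is
/// not a multiple of `digests.len()`.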
#[tracing::instrument("hash_interleaved", skip_all, level = "debug")]
fn hash_interleaved<F, H>(elems: &[F], digests: &mut [MaybeUninit<Output<H>>]) -> Result<(), Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
{
	if elems.len() % digests.len() != 0 {
		return Err(Error::IncorrectVectorLen {
			expected: digests.len(),
		});
	}
	let batch_size = elems.len() / digests.len();
	hash_iterated::<F, H, _>(
		elems
			.par_chunks(batch_size)
			.map(|chunk| chunk.iter().copied()),
		digests,
	)
}

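/// Hashes each chunk yielded by the parallel iterator into the corresponding
/// output digest slot, serializing elements in canonical tower order.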
fn hash_iterated<F, H, ParIter>(
	iterated_chunks: ParIter,
	digests: &mut [MaybeUninit<Output<H>>],
) -> Result<(), Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
	digests
		.par_iter_mut()
		.zip(iterated_chunks)
		.for_each_init(H::new, |hasher, (digest, elems)| {
			// Scope the `HashBuffer` so its mutable borrow of `hasher` ends
			// before the digest is finalized below.
			{
				let mut hash_buffer = HashBuffer::new(hasher);
				for elem in elems {
					let mode = SerializationMode::CanonicalTower;
					SerializeBytes::serialize(&elem, &mut hash_buffer, mode)
						.expect("HashBuffer has infinite capacity");
				}
			}
			digest.write(Digest::finalize_reset(hasher));
		});
	Ok(())
}

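/// Reinterprets a slice of `MaybeUninit<T>` as a slice of `T`, assuming every
/// element has been initialized (an equivalent of `MaybeUninit::slice_assume_init_mut`).
///
/// # Safety
///
/// The caller must guarantee that every element of `slice` is in an initialized
/// state; calling this on uninitialized contents is undefined behavior.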
pub const unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
	std::mem::transmute(slice)
}
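
// A minimal shape-check sketch, kept behind `cfg(test)`. The concrete choices
// below are assumptions rather than types used elsewhere in this module:
// `sha2::Sha256` as the leaf hasher (it implements `Digest + BlockSizeUser +
// FixedOutputReset`), `BinaryField8b`/`Field::ZERO` from `binius_field`, and a
// toy XOR compression that is NOT collision resistant and only exercises the
// tree-shape logic. It also assumes `PseudoCompressionFunction` requires only
// `compress`, as implied by `compress_layer` above; adjust to whatever the
// workspace actually provides.
#[cfg(test)]
mod sketch_tests {
	use binius_field::{BinaryField8b, Field};
	use sha2::Sha256;

	use super::*;

	/// Toy compression: byte-wise XOR of the two child digests (test-only).
	struct XorCompression;

	impl PseudoCompressionFunction<Output<Sha256>, 2> for XorCompression {
		fn compress(&self, input: [Output<Sha256>; 2]) -> Output<Sha256> {
			let [left, right] = input;
			let mut out = left;
			for (o, b) in out.iter_mut().zip(right.iter()) {
				*o ^= b;
			}
			out
		}
	}

	#[test]
	fn builds_tree_with_expected_shape() {
		// 64 elements in batches of 4 -> 16 leaves -> log_len = 4.
		let elements = vec![BinaryField8b::ZERO; 64];
		let tree = build::<_, Sha256, _>(&XorCompression, &elements, 4).unwrap();

		assert_eq!(tree.log_len, 4);
		assert_eq!(tree.inner_nodes.len(), (1 << 5) - 1);
		// Depth 0 is the root layer; a branch from a leaf spans `log_len` levels.
		assert_eq!(tree.layer(0).unwrap().len(), 1);
		assert_eq!(tree.branch(0, 0).unwrap().len(), 4);
	}
}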