binius_core/merkle_tree/binary_merkle_tree.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_field::TowerField;
use binius_hash::{multi_digest::ParallelDigest, PseudoCompressionFunction};
use binius_maybe_rayon::{prelude::*, slice::ParallelSlice};
use binius_utils::{bail, checked_arithmetics::log2_strict_usize};
use digest::{crypto_common::BlockSizeUser, FixedOutputReset, Output};
use tracing::instrument;

use super::errors::Error;

/// A binary Merkle tree that commits batches of vectors.
///
/// The vector entries at each index in a batch are hashed together into leaf digests. Then a
/// Merkle tree is constructed over the leaf digests. The implementation requires that the vector
/// lengths are all equal to each other and a power of two.
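///
/// For example, with `log_len = 2`, `inner_nodes` holds 2^3 - 1 = 7 digests: the four leaf
/// digests at indices 0..4, the two middle nodes at indices 4..6, and the root at index 6.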
#[derive(Debug, Clone)]
pub struct BinaryMerkleTree<D> {
	/// Base-2 logarithm of the number of leaves
	pub log_len: usize,
	/// The inner nodes, arranged as a flattened array of layers with the root at the end
	pub inner_nodes: Vec<D>,
}

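/// Builds a Merkle tree committing a batch of `batch_size` interleaved vectors.
///
/// The `batch_size` consecutive entries of `elements` at each leaf index are hashed together
/// into that leaf's digest, so the tree has `elements.len() / batch_size` leaves.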
pub fn build<F, H, C>(
	compression: &C,
	elements: &[F],
	batch_size: usize,
) -> Result<BinaryMerkleTree<Output<H::Digest>>, Error>
where
	F: TowerField,
	H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
	C: PseudoCompressionFunction<Output<H::Digest>, 2> + Sync,
{
	if elements.len() % batch_size != 0 {
		bail!(Error::IncorrectBatchSize);
	}

	let len = elements.len() / batch_size;

	if !len.is_power_of_two() {
		bail!(Error::PowerOfTwoLengthRequired);
	}

	let log_len = log2_strict_usize(len);

	internal_build(
		compression,
		|inner_nodes| hash_interleaved::<_, H>(elements, inner_nodes),
		log_len,
	)
}

fn internal_build<Digest, C>(
	compression: &C,
	// Must either successfully initialize the passed-in slice or return an error
	hash_leaves: impl FnOnce(&mut [MaybeUninit<Digest>]) -> Result<(), Error>,
	log_len: usize,
) -> Result<BinaryMerkleTree<Digest>, Error>
where
	Digest: Clone + Send + Sync,
	C: PseudoCompressionFunction<Digest, 2> + Sync,
{
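	// A perfect binary tree with 2^log_len leaves has 2^(log_len + 1) - 1 nodes in total.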
	let total_length = (1 << (log_len + 1)) - 1;
	let mut inner_nodes = Vec::with_capacity(total_length);

	hash_leaves(&mut inner_nodes.spare_capacity_mut()[..(1 << log_len)])?;

	let (prev_layer, mut remaining) = inner_nodes.spare_capacity_mut().split_at_mut(1 << log_len);

	let mut prev_layer = unsafe {
		// SAFETY: prev_layer was initialized by hash_leaves
		slice_assume_init_mut(prev_layer)
	};
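	// Compress pairs of digests layer by layer; layer i (counting up from the leaves at
	// layer 0) has 2^(log_len - i) nodes.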
	for i in 1..=log_len {
		let (next_layer, next_remaining) = remaining.split_at_mut(1 << (log_len - i));
		remaining = next_remaining;

		compress_layer(compression, prev_layer, next_layer);

		prev_layer = unsafe {
			// SAFETY: next_layer was just initialized by compress_layer
			slice_assume_init_mut(next_layer)
		};
	}

	unsafe {
		// SAFETY: inner_nodes should be entirely initialized by now
		// Note that we don't incrementally update inner_nodes.len() since
		// that doesn't play well with using split_at_mut on spare capacity.
		inner_nodes.set_len(total_length);
	}
	Ok(BinaryMerkleTree {
		log_len,
		inner_nodes,
	})
}

#[instrument("BinaryMerkleTree::build", skip_all, level = "debug")]
pub fn build_from_iterator<F, H, C, ParIter>(
	compression: &C,
	iterated_chunks: ParIter,
	log_len: usize,
) -> Result<BinaryMerkleTree<Output<H::Digest>>, Error>
where
	F: TowerField,
	H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
	C: PseudoCompressionFunction<Output<H::Digest>, 2> + Sync,
	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
	internal_build(
		compression,
		|inner_nodes| hash_iterated::<F, H, _>(iterated_chunks, inner_nodes),
		log_len,
	)
}

impl<D: Clone> BinaryMerkleTree<D> {
	pub fn root(&self) -> D {
		self.inner_nodes
			.last()
			.expect("MerkleTree inner nodes can't be empty")
			.clone()
	}

	pub fn layer(&self, layer_depth: usize) -> Result<&[D], Error> {
		if layer_depth > self.log_len {
			bail!(Error::IncorrectLayerDepth);
		}
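		// The layers at depths 0..=layer_depth sit at the end of `inner_nodes` and contain
		// 2^(layer_depth + 1) - 1 nodes in total, so the layer at `layer_depth` starts at
		// len + 1 - 2^(layer_depth + 1).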
		let range_start = self.inner_nodes.len() + 1 - (1 << (layer_depth + 1));

		Ok(&self.inner_nodes[range_start..range_start + (1 << layer_depth)])
	}

	/// Get a Merkle branch for the given index.
	///
	/// Returns an error if the index or the layer depth is out of range.
	pub fn branch(&self, index: usize, layer_depth: usize) -> Result<Vec<D>, Error> {
		if index >= 1 << self.log_len || layer_depth > self.log_len {
			return Err(Error::IndexOutOfRange {
				max: (1 << self.log_len) - 1,
			});
		}

		let branch = (0..self.log_len - layer_depth)
			.map(|j| {
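				// Layer j (counting from the leaves) starts at offset
				// ((1 << j) - 1) << (log_len + 1 - j), the combined size of layers 0..j.
				// Within layer j, (index >> j) ^ 1 is the sibling of the ancestor of
				// `index`; note that `^` binds tighter than `|`.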
				let node_index = (((1 << j) - 1) << (self.log_len + 1 - j)) | (index >> j) ^ 1;
				self.inner_nodes[node_index].clone()
			})
			.collect();

		Ok(branch)
	}
}

#[tracing::instrument("MerkleTree::compress_layer", skip_all, level = "debug")]
fn compress_layer<D, C>(compression: &C, prev_layer: &[D], next_layer: &mut [MaybeUninit<D>])
where
	D: Clone + Send + Sync,
	C: PseudoCompressionFunction<D, 2> + Sync,
{
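	// Compress each adjacent pair of digests in `prev_layer` into one digest of
	// `next_layer`, in parallel.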
	prev_layer
		.par_chunks_exact(2)
		.zip(next_layer.par_iter_mut())
		.for_each(|(prev_pair, next_digest)| {
			next_digest.write(compression.compress(array::from_fn(|i| prev_pair[i].clone())));
		})
}

/// Hashes the elements of a vector into digests, one chunk per digest.
///
/// Given a vector of elements and an output buffer of N hash digests, this splits the elements
/// into N equal-sized chunks and hashes each chunk into the corresponding output digest. Returns
/// an error if the number of elements is not divisible by the number of digests.
#[tracing::instrument("hash_interleaved", skip_all, level = "debug")]
fn hash_interleaved<F, H>(
	elems: &[F],
	digests: &mut [MaybeUninit<Output<H::Digest>>],
) -> Result<(), Error>
where
	F: TowerField,
	H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
{
	if elems.len() % digests.len() != 0 {
		return Err(Error::IncorrectVectorLen {
			expected: digests.len(),
		});
	}

	let hash_data_iter = elems
		.par_chunks(elems.len() / digests.len())
		.map(|s| s.iter().copied());
	hash_iterated::<_, H, _>(hash_data_iter, digests)
}

fn hash_iterated<F, H, ParIter>(
	iterated_chunks: ParIter,
	digests: &mut [MaybeUninit<Output<H::Digest>>],
) -> Result<(), Error>
where
	F: TowerField,
	H: ParallelDigest<Digest: BlockSizeUser + FixedOutputReset>,
	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
	let hasher = H::new();
	hasher.digest(iterated_chunks, digests);

	Ok(())
}

/// This can be removed when MaybeUninit::slice_assume_init_mut is stabilized
/// <https://github.com/rust-lang/rust/issues/63569>
///
/// # Safety
///
/// It is up to the caller to guarantee that the `MaybeUninit<T>` elements
/// really are in an initialized state.
/// Calling this when the content is not yet fully initialized causes undefined behavior.
///
/// See [`assume_init_mut`] for more details and examples.
///
/// [`assume_init_mut`]: MaybeUninit::assume_init_mut
pub const unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
	std::mem::transmute(slice)
}
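
// Illustrative sanity check for `slice_assume_init_mut`: a minimal sketch that fully
// initializes a buffer before calling the unsafe helper, mirroring how `internal_build`
// uses it on spare capacity.
#[cfg(test)]
mod slice_assume_init_tests {
	use super::*;

	#[test]
	fn initialized_slice_can_be_assumed_init() {
		let mut buf: Vec<MaybeUninit<u32>> = (0..4).map(|_| MaybeUninit::uninit()).collect();
		for (i, slot) in buf.iter_mut().enumerate() {
			slot.write(i as u32);
		}
		// SAFETY: every element was just initialized by the loop above.
		let vals = unsafe { slice_assume_init_mut(&mut buf) };
		assert_eq!(vals, &mut [0u32, 1, 2, 3][..]);
	}
}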