binius_core/merkle_tree/
binary_merkle_tree.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_field::TowerField;
use binius_hash::{HashBuffer, PseudoCompressionFunction};
use binius_maybe_rayon::{prelude::*, slice::ParallelSlice};
use binius_utils::{
	bail, checked_arithmetics::log2_strict_usize, SerializationMode, SerializeBytes,
};
use digest::{crypto_common::BlockSizeUser, Digest, FixedOutputReset, Output};
use tracing::instrument;

use super::errors::Error;

/// A binary Merkle tree that commits batches of vectors.
///
/// The vector entries at each index in a batch are hashed together into leaf digests. Then a
/// Merkle tree is constructed over the leaf digests. The implementation requires that all vectors
/// have the same power-of-two length.
#[derive(Debug, Clone)]
pub struct BinaryMerkleTree<D> {
	/// Base-2 logarithm of the number of leaves
	pub log_len: usize,
	/// The inner nodes, arranged as a flattened array of layers with the root at the end
	pub inner_nodes: Vec<D>,
}

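/// Builds a [`BinaryMerkleTree`] over `elements`, hashing each consecutive chunk of `batch_size`
/// field elements into one leaf digest.
///
/// # Example
///
/// A minimal usage sketch; `MyHasher` and `MyCompression` are hypothetical placeholders for any
/// types satisfying the `H` and `C` bounds below.
///
/// ```ignore
/// // With `elements.len() == 32` and `batch_size == 4`, the tree has 8 leaves.
/// let tree = build::<_, MyHasher, _>(&MyCompression::default(), &elements, 4)?;
/// assert_eq!(tree.log_len, 3);
/// let root = tree.root();
/// ```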
pub fn build<F, H, C>(
	compression: &C,
	elements: &[F],
	batch_size: usize,
) -> Result<BinaryMerkleTree<Output<H>>, Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
	C: PseudoCompressionFunction<Output<H>, 2> + Sync,
{
	if elements.len() % batch_size != 0 {
		bail!(Error::IncorrectBatchSize);
	}

	let len = elements.len() / batch_size;

	if !len.is_power_of_two() {
		bail!(Error::PowerOfTwoLengthRequired);
	}

	let log_len = log2_strict_usize(len);

	internal_build(
		compression,
		|inner_nodes| hash_interleaved::<_, H>(elements, inner_nodes),
		log_len,
	)
}

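/// Builds the tree from already-hashed leaves, one compression layer at a time.
///
/// Layout sketch (follows from the loop below): for `log_len = 2` the tree stores
/// `2^(log_len + 1) - 1 = 7` digests, leaves first and the root last:
///
/// ```text
/// inner_nodes = [leaf0, leaf1, leaf2, leaf3, node01, node23, root]
/// ```
///
/// where `node01 = compress([leaf0, leaf1])`, `node23 = compress([leaf2, leaf3])`, and
/// `root = compress([node01, node23])`.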
fn internal_build<Digest, C>(
	compression: &C,
	// Must either successfully initialize the passed-in slice or return an error
	hash_leaves: impl FnOnce(&mut [MaybeUninit<Digest>]) -> Result<(), Error>,
	log_len: usize,
) -> Result<BinaryMerkleTree<Digest>, Error>
where
	Digest: Clone + Send + Sync,
	C: PseudoCompressionFunction<Digest, 2> + Sync,
{
	let total_length = (1 << (log_len + 1)) - 1;
	let mut inner_nodes = Vec::with_capacity(total_length);

	hash_leaves(&mut inner_nodes.spare_capacity_mut()[..(1 << log_len)])?;

	let (prev_layer, mut remaining) = inner_nodes.spare_capacity_mut().split_at_mut(1 << log_len);

	let mut prev_layer = unsafe {
		// SAFETY: prev_layer was initialized by hash_leaves
		slice_assume_init_mut(prev_layer)
	};
	for i in 1..(log_len + 1) {
		let (next_layer, next_remaining) = remaining.split_at_mut(1 << (log_len - i));
		remaining = next_remaining;

		compress_layer(compression, prev_layer, next_layer);

		prev_layer = unsafe {
			// SAFETY: next_layer was just initialized by compress_layer
			slice_assume_init_mut(next_layer)
		};
	}

	unsafe {
		// SAFETY: inner_nodes is entirely initialized by now.
		// Note that we don't incrementally update inner_nodes.len() since
		// that doesn't play well with using split_at_mut on spare capacity.
		inner_nodes.set_len(total_length);
	}
	Ok(BinaryMerkleTree {
		log_len,
		inner_nodes,
	})
}

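/// Builds a [`BinaryMerkleTree`] whose `i`-th leaf is the hash of the `i`-th chunk yielded by
/// `iterated_chunks`; the iterator is expected to yield exactly `2^log_len` chunks.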
#[instrument("BinaryMerkleTree::build", skip_all, level = "debug")]
pub fn build_from_iterator<F, H, C, ParIter>(
	compression: &C,
	iterated_chunks: ParIter,
	log_len: usize,
) -> Result<BinaryMerkleTree<Output<H>>, Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
	C: PseudoCompressionFunction<Output<H>, 2> + Sync,
	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
	internal_build(
		compression,
		|inner_nodes| hash_iterated::<F, H, _>(iterated_chunks, inner_nodes),
		log_len,
	)
}

impl<D: Clone> BinaryMerkleTree<D> {
	pub fn root(&self) -> D {
		self.inner_nodes
			.last()
			.expect("MerkleTree inner nodes can't be empty")
			.clone()
	}

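	/// Returns the layer of digests at `layer_depth`, counted from the root: depth 0 is the root
	/// itself and depth `log_len` is the leaf layer, so the returned slice has `2^layer_depth`
	/// entries.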
	pub fn layer(&self, layer_depth: usize) -> Result<&[D], Error> {
		if layer_depth > self.log_len {
			bail!(Error::IncorrectLayerDepth);
		}
		let range_start = self.inner_nodes.len() + 1 - (1 << (layer_depth + 1));

		Ok(&self.inner_nodes[range_start..range_start + (1 << layer_depth)])
	}

	/// Get a Merkle branch (authentication path) for the leaf at the given index, ending at the
	/// layer at `layer_depth`.
	///
	/// Returns an error if the index or layer depth is out of range.
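	///
	/// # Verification sketch
	///
	/// Illustrative only (verification is not implemented in this module): the siblings are
	/// ordered from the leaf layer upward, so a verifier holding the leaf digest can fold the
	/// branch back up with the same compression function used to build the tree:
	///
	/// ```ignore
	/// let mut digest = leaf_digest;
	/// for (j, sibling) in branch.iter().enumerate() {
	///     digest = if (index >> j) & 1 == 0 {
	///         compression.compress([digest, sibling.clone()])
	///     } else {
	///         compression.compress([sibling.clone(), digest])
	///     };
	/// }
	/// // `digest` should now equal `tree.layer(layer_depth)?[index >> branch.len()]`.
	/// ```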
	pub fn branch(&self, index: usize, layer_depth: usize) -> Result<Vec<D>, Error> {
		if index >= 1 << self.log_len || layer_depth > self.log_len {
			return Err(Error::IndexOutOfRange {
				max: (1 << self.log_len) - 1,
			});
		}

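		// For step `j`, the sibling lives in layer `j` (leaves are layer 0), which starts at
		// offset `(2^j - 1) * 2^(log_len + 1 - j)` in `inner_nodes`; the sibling's position
		// within that layer is `(index >> j) ^ 1`. The bitwise OR below acts as addition because
		// the two terms occupy disjoint bit ranges.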
		let branch = (0..self.log_len - layer_depth)
			.map(|j| {
				let node_index = (((1 << j) - 1) << (self.log_len + 1 - j)) | (index >> j) ^ 1;
				self.inner_nodes[node_index].clone()
			})
			.collect();

		Ok(branch)
	}
}

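/// Compresses adjacent pairs of digests from `prev_layer` into `next_layer` in parallel;
/// `next_layer` is expected to be exactly half the length of `prev_layer`.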
#[tracing::instrument("MerkleTree::compress_layer", skip_all, level = "debug")]
fn compress_layer<D, C>(compression: &C, prev_layer: &[D], next_layer: &mut [MaybeUninit<D>])
where
	D: Clone + Send + Sync,
	C: PseudoCompressionFunction<D, 2> + Sync,
{
	prev_layer
		.par_chunks_exact(2)
		.zip(next_layer.par_iter_mut())
		.for_each(|(prev_pair, next_digest)| {
			next_digest.write(compression.compress(array::from_fn(|i| prev_pair[i].clone())));
		})
}

/// Hashes the elements in chunks of a vector into digests.
///
/// Given a vector of elements and an output buffer of N hash digests, this splits the elements
/// into N equal-sized chunks and hashes each chunk into the corresponding output digest. For
/// example, 16 elements and 4 output digests give chunks of 4 elements each. Returns an error if
/// the number of elements is not a multiple of the number of digests.
#[tracing::instrument("hash_interleaved", skip_all, level = "debug")]
fn hash_interleaved<F, H>(elems: &[F], digests: &mut [MaybeUninit<Output<H>>]) -> Result<(), Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
{
	if elems.len() % digests.len() != 0 {
		return Err(Error::IncorrectVectorLen {
			expected: digests.len(),
		});
	}
	let batch_size = elems.len() / digests.len();
	hash_iterated::<F, H, _>(
		elems
			.par_chunks(batch_size)
			.map(|chunk| chunk.iter().copied()),
		digests,
	)
}

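/// Hashes each chunk yielded by `iterated_chunks` into the corresponding output digest,
/// serializing field elements in canonical tower form before feeding them to the hasher.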
fn hash_iterated<F, H, ParIter>(
	iterated_chunks: ParIter,
	digests: &mut [MaybeUninit<Output<H>>],
) -> Result<(), Error>
where
	F: TowerField,
	H: Digest + BlockSizeUser + FixedOutputReset,
	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F>>,
{
	digests
		.par_iter_mut()
		.zip(iterated_chunks)
		.for_each_init(H::new, |hasher, (digest, elems)| {
			{
				let mut hash_buffer = HashBuffer::new(hasher);
				for elem in elems {
					let mode = SerializationMode::CanonicalTower;
					SerializeBytes::serialize(&elem, &mut hash_buffer, mode)
						.expect("HashBuffer has infinite capacity");
				}
			}
			digest.write(Digest::finalize_reset(hasher));
		});
	Ok(())
}

/// This can be removed when MaybeUninit::slice_assume_init_mut is stabilized
/// <https://github.com/rust-lang/rust/issues/63569>
///
/// # Safety
///
/// It is up to the caller to guarantee that the `MaybeUninit<T>` elements
/// really are in an initialized state.
/// Calling this when the content is not yet fully initialized causes undefined behavior.
///
/// See [`assume_init_mut`] for more details and examples.
///
/// [`assume_init_mut`]: MaybeUninit::assume_init_mut
pub const unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
	std::mem::transmute(slice)
}