Skip to main content

binius_hash/
binary_merkle_tree.rs

1// Copyright 2024-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{fmt::Debug, mem::MaybeUninit};
5
6use binius_field::Field;
7use binius_utils::{
8	checked_arithmetics::log2_strict_usize,
9	mem::slice_assume_init_mut,
10	rand::par_rand,
11	rayon::{prelude::*, slice::ParallelSlice},
12};
13use digest::{Digest, FixedOutputReset, Output, block_api::BlockSizeUser};
14use rand::{CryptoRng, Rng, rngs::StdRng};
15
16use super::{
17	compress::PseudoCompressionFunction, parallel_compression::ParallelPseudoCompression,
18	parallel_digest::ParallelDigest,
19};
20
21/// A bundle of hash and compression types used to build and verify a binary Merkle tree.
22///
23/// Most callers want to vary the underlying hash family (SHA-256, etc.) as a single unit
24/// rather than independently picking a leaf hash, a compression function, and their parallel
25/// counterparts. `HashSuite` bundles the four related types so that user-facing prover and
26/// verifier APIs can take a single `H: HashSuite` parameter instead of two or three loose hash
27/// trait parameters.
28pub trait HashSuite {
29	/// Sequential hash used to compute leaf digests during verification.
30	type LeafHash: Digest + BlockSizeUser + FixedOutputReset + Send;
31	/// Sequential 2-to-1 compression used to fold inner Merkle nodes during verification.
32	type Compression: PseudoCompressionFunction<Output<Self::LeafHash>, 2> + Default;
33	/// Parallel counterpart of [`Self::LeafHash`] used during proving.
34	type ParLeafHash: ParallelDigest<Digest = Self::LeafHash> + Default;
35	/// Parallel counterpart of [`Self::Compression`] used during proving.
36	type ParCompression: ParallelPseudoCompression<Output<Self::LeafHash>, 2, Compression = Self::Compression>
37		+ Default;
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum Error {
42	#[error("Index exceeds Merkle tree base size: {max}")]
43	IndexOutOfRange { max: usize },
44	#[error("values length must be a multiple of the batch size")]
45	IncorrectBatchSize,
46	#[error("The argument length must be a power of two.")]
47	PowerOfTwoLengthRequired,
48	#[error("The layer does not exist in the Merkle tree")]
49	IncorrectLayerDepth,
50}
51
/// Merkle tree commitment over a batch of vectors.
///
/// At every position, the entries of all vectors in the batch are hashed together into one leaf
/// digest, and a binary Merkle tree is then built over those digests. All vectors must have the
/// same length, and that length must be a power of two.
#[derive(Clone, Debug)]
pub struct BinaryMerkleTree<D, F> {
	/// `log2` of the number of leaves.
	pub log_len: usize,
	/// All tree nodes as one flat array of layers, with the root stored last.
	pub inner_nodes: Vec<D>,
	/// Per-leaf salt values for hiding commitments; empty when no salt is used.
	pub salts: Vec<F>,
}
66
67pub fn build<F, H, R>(
68	elements: &[F],
69	batch_size: usize,
70	salt_len: usize,
71	rng: R,
72) -> Result<BinaryMerkleTree<Output<H::LeafHash>, F>, Error>
73where
74	F: Field,
75	H: HashSuite,
76	R: Rng + CryptoRng,
77{
78	if !elements.len().is_multiple_of(batch_size) {
79		return Err(Error::IncorrectBatchSize);
80	}
81
82	let len = elements.len() / batch_size;
83
84	if !len.is_power_of_two() {
85		return Err(Error::PowerOfTwoLengthRequired);
86	}
87
88	build_from_iterator::<_, H, _, _>(
89		elements
90			.par_chunks(batch_size)
91			.map(|chunk| chunk.iter().copied()),
92		batch_size,
93		salt_len,
94		rng,
95	)
96}
97
98pub fn build_from_iterator<F, H, R, ParIter>(
99	iterated_chunks: ParIter,
100	n_items_per_input: usize,
101	salt_len: usize,
102	mut rng: R,
103) -> Result<BinaryMerkleTree<Output<H::LeafHash>, F>, Error>
104where
105	F: Field,
106	H: HashSuite,
107	R: Rng + CryptoRng,
108	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F, IntoIter: Send>>,
109{
110	let log_len = log2_strict_usize(iterated_chunks.len()); // precondition
111
112	// Generate salts if needed
113	let salts =
114		par_rand::<StdRng, _, _>(salt_len << log_len, &mut rng, F::random).collect::<Vec<_>>();
115
116	let total_length = (1 << (log_len + 1)) - 1;
117	let mut inner_nodes = Vec::with_capacity(total_length);
118	hash_leaves::<F, H, _>(
119		iterated_chunks,
120		n_items_per_input,
121		&mut inner_nodes.spare_capacity_mut()[..(1 << log_len)],
122		&salts,
123	);
124
125	let (prev_layer, mut remaining) = inner_nodes.spare_capacity_mut().split_at_mut(1 << log_len);
126
127	let mut prev_layer = unsafe {
128		// SAFETY: prev-layer was initialized by hash_leaves
129		slice_assume_init_mut(prev_layer)
130	};
131	let parallel_compression = H::ParCompression::default();
132	for i in 1..(log_len + 1) {
133		let (next_layer, next_remaining) = remaining.split_at_mut(1 << (log_len - i));
134		remaining = next_remaining;
135
136		parallel_compression.parallel_compress(prev_layer, next_layer);
137
138		prev_layer = unsafe {
139			// SAFETY: next_layer was just initialized by compress_layer
140			slice_assume_init_mut(next_layer)
141		};
142	}
143
144	unsafe {
145		// SAFETY: inner_nodes should be entirely initialized by now
146		// Note that we don't incrementally update inner_nodes.len() since
147		// that doesn't play well with using split_at_mut on spare capacity.
148		inner_nodes.set_len(total_length);
149	}
150	Ok(BinaryMerkleTree {
151		log_len,
152		inner_nodes,
153		salts,
154	})
155}
156
157impl<D: Clone, F> BinaryMerkleTree<D, F> {
158	pub fn root(&self) -> D {
159		self.inner_nodes
160			.last()
161			.expect("MerkleTree inner nodes can't be empty")
162			.clone()
163	}
164
165	/// Returns the salt values associated with a specific leaf index in the Merkle tree.
166	///
167	/// # Arguments
168	/// * `index` - The index of the leaf. Must be less than 2^log_len (the total number of leaves).
169	pub fn get_salt(&self, index: usize) -> &[F] {
170		assert!(index < (1 << self.log_len));
171		let salt_len = self.salts.len() >> self.log_len;
172		&self.salts[index * salt_len..(index + 1) * salt_len]
173	}
174
175	pub fn layer(&self, layer_depth: usize) -> Result<&[D], Error> {
176		if layer_depth > self.log_len {
177			return Err(Error::IncorrectLayerDepth);
178		}
179		let range_start = self.inner_nodes.len() + 1 - (1 << (layer_depth + 1));
180
181		Ok(&self.inner_nodes[range_start..range_start + (1 << layer_depth)])
182	}
183
184	/// Get a Merkle branch for the given index
185	///
186	/// Throws if the index is out of range
187	pub fn branch(&self, index: usize, layer_depth: usize) -> Result<Vec<D>, Error> {
188		if index >= 1 << self.log_len || layer_depth > self.log_len {
189			return Err(Error::IndexOutOfRange {
190				max: (1 << self.log_len) - 1,
191			});
192		}
193
194		let branch = (0..self.log_len - layer_depth)
195			.map(|j| {
196				let node_index = (((1 << j) - 1) << (self.log_len + 1 - j)) | (index >> j) ^ 1;
197				self.inner_nodes[node_index].clone()
198			})
199			.collect();
200
201		Ok(branch)
202	}
203}
204
205/// Hashes the elements in chunks of a vector into digests.
206///
207/// Given a vector of elements and an output buffer of N hash digests, this splits the elements
208/// into N equal-sized chunks and hashes each chunks into the corresponding output digest.
209///
210/// Each leaf is built from exactly `n_items_per_input` data elements (plus the per-leaf salt, when
211/// salts are present), so the leaf byte length is constant. This is passed to
212/// [`ParallelDigest::digest_with_const_len`] so the hasher can specialize for short leaves.
213///
214/// # Preconditions
215/// - Each iterator in `iterated_chunks` yields exactly `n_items_per_input` elements.
216#[tracing::instrument("hash_leaves", skip_all, level = "debug")]
217fn hash_leaves<F, H, ParIter>(
218	iterated_chunks: ParIter,
219	n_items_per_input: usize,
220	digests: &mut [MaybeUninit<Output<H::LeafHash>>],
221	salts: &[F],
222) where
223	F: Field,
224	H: HashSuite,
225	ParIter: IndexedParallelIterator<Item: IntoIterator<Item = F, IntoIter: Send>>,
226{
227	if salts.is_empty() {
228		// Need special-case handling when salts is empty, otherwise salt_len is 0 and par_chunks
229		// cannot handle chunk size of 0.
230		let hasher = H::ParLeafHash::default();
231		hasher.digest_with_const_len(n_items_per_input, iterated_chunks, digests);
232	} else {
233		assert!(salts.len().is_multiple_of(digests.len()));
234
235		let salt_len = salts.len() / digests.len();
236
237		// Create an iterator that chains each chunk with its salt
238		let salted_iter = iterated_chunks
239			.zip(salts.par_chunks(salt_len))
240			.map(|(chunk, salt)| chunk.into_iter().chain(salt.iter().copied()));
241
242		// Each salted leaf yields the data elements followed by the salt elements.
243		let hasher = H::ParLeafHash::default();
244		hasher.digest_with_const_len(n_items_per_input + salt_len, salted_iter, digests);
245	}
246}