binius_m3/gadgets/merkle_tree/
trace.rs

1// Copyright 2025 Irreducible Inc.
2
//! High-level model for binary Merkle trees using the Grøstl-256 output transformation as a
//! 2-to-1 compression function.
5use std::{
6	collections::{HashMap, HashSet},
7	hash::Hash,
8};
9
10use binius_field::AESTowerField8b;
11use binius_hash::groestl::{GroestlShortImpl, GroestlShortInternal};
12
13use crate::{builder::B8, emulate::Channel};
14
/// Signature of the Nodes channel: (Root ID, Data, Depth, Index)
///
/// Each token identifies one node of one of the Merkle trees being verified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeFlushToken {
	/// Identifier of the root (tree) this node belongs to.
	pub root_id: u8,
	/// The 32-byte digest stored at this node.
	pub data: [u8; 32],
	/// Depth of the node in the tree; the root is at depth 0.
	pub depth: usize,
	/// Position of the node within its layer. The children of the node at index `i` sit at
	/// indices `2 * i` and `2 * i + 1` of the next deeper layer.
	pub index: usize,
}
23
/// Signature of the Roots channel: (Root ID, Root digest)
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RootFlushToken {
	/// Identifier of the root.
	pub root_id: u8,
	/// The 32-byte root digest.
	pub data: [u8; 32],
}
30
/// A Merkle inclusion path: the ID of the root it is claimed against, the index of the leaf,
/// the leaf itself, and the siblings on the path to the root from the leaf.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MerklePath {
	/// Identifier of the root this path is claimed to lead to.
	pub root_id: u8,
	/// Index of the leaf within the leaf layer.
	pub index: usize,
	/// The 32-byte leaf value.
	pub leaf: [u8; 32],
	/// Sibling digests, ordered from the leaf layer upwards towards the root.
	pub nodes: Vec<[u8; 32]>,
}
41
/// A struct whose fields contain the channels involved in the trace to verify merkle paths for
/// a binary merkle tree.
#[allow(clippy::tabs_in_doc_comments)]
pub struct MerkleTreeChannels {
	/// This channel gets flushed with tokens during "intermediate" steps of the verification,
	/// where each token carries a node digest along with its position information: the root
	/// it is associated to, its depth and its index.
	nodes: Channel<NodeFlushToken>,

	/// This channel contains flushes that validate that the "final" digest obtained in a
	/// merkle path matches one of the claimed roots, which are pushed as boundary values.
	roots: Channel<RootFlushToken>,
}
56
57impl Default for MerkleTreeChannels {
58	fn default() -> Self {
59		Self::new()
60	}
61}
62
63impl MerkleTreeChannels {
64	pub fn new() -> Self {
65		Self {
66			nodes: Channel::default(),
67			roots: Channel::default(),
68		}
69	}
70}
/// A table representing a step in verifying a merkle path for inclusion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MerklePathEvent {
	/// Identifier of the root (tree) this step belongs to.
	pub root_id: u8,
	/// Digest of the left child.
	pub left: [u8; 32],
	/// Digest of the right child.
	pub right: [u8; 32],
	/// Digest of the parent of `left` and `right`.
	pub parent: [u8; 32],
	/// Depth of the parent node; the root is at depth 0.
	pub parent_depth: usize,
	/// Index of the parent node within its layer.
	pub parent_index: usize,
	/// Whether the left child is pulled from the nodes channel when the event fires.
	pub flush_left: bool,
	/// Whether the right child is pulled from the nodes channel when the event fires.
	pub flush_right: bool,
}
84
/// A table representing the final step of comparing the claimed root.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MerkleRootEvent {
	/// Identifier of the root being compared.
	pub root_id: u8,
	/// The 32-byte root digest.
	pub digest: [u8; 32],
}
92
93impl MerkleRootEvent {
94	pub fn new(root_id: u8, digest: [u8; 32]) -> Self {
95		Self { root_id, digest }
96	}
97}
98
99/// Uses the Groestl256 compression function to compress two 32-byte inputs into a single
100/// 32-byte. It is assumed the bytes are in the Fan-Paar basis and so are transformed to the
101/// AESTowerField8b basis before the Groestl compression function is applied to agree with the
102/// Groestl gadget.
103fn compress(left: &[u8], right: &[u8], output: &mut [u8]) {
104	let mut state_bytes = [0u8; 64];
105	let (half0, half1) = state_bytes.split_at_mut(32);
106	half0.copy_from_slice(left);
107	half1.copy_from_slice(right);
108	state_bytes = state_bytes.map(|b| AESTowerField8b::from(B8::new(b)).val());
109	let input = GroestlShortImpl::state_from_bytes(&state_bytes);
110	let mut state = input;
111	GroestlShortImpl::p_perm(&mut state);
112	GroestlShortImpl::xor_state(&mut state, &input);
113	state_bytes = GroestlShortImpl::state_to_bytes(&state);
114	state_bytes = state_bytes.map(|b| B8::from(AESTowerField8b::new(b)).val());
115	output.copy_from_slice(&state_bytes[32..]);
116}
117
/// Merkle tree implementation for the model, assumes the leaf layer consists of 32 byte blobs.
/// The tree is built in a flattened manner, where the leaves are at the beginning of the vector
/// and layers are placed adjacent to each other.
pub struct MerkleTree {
	/// Number of layers above the leaves, i.e. log2 of the number of leaves.
	depth: usize,
	/// All nodes of the tree, flattened: the leaves first, then each parent layer in turn,
	/// ending with the root.
	nodes: Vec<[u8; 32]>,
	/// The root digest (a copy of the last entry of `nodes`).
	root: [u8; 32],
}
126
127impl MerkleTree {
128	/// Constructs a Merkle tree from the given leaf nodes that uses the Groestl output
129	/// transformation (Groestl-P permutation + XOR) as a digest compression function.
130	pub fn new(leaves: &[[u8; 32]]) -> Self {
131		assert!(leaves.len().is_power_of_two(), "Length of leaves needs to be a power of 2.");
132		let depth = leaves.len().ilog2() as usize;
133		let mut nodes = vec![[0u8; 32]; 2 * leaves.len() - 1];
134
135		// Fill the leaves in the flattened tree.
136		nodes[0..leaves.len()].copy_from_slice(leaves);
137
138		// Marks the beginning of the layer in the flattened tree.
139		let mut current_depth_marker = 0;
140		let mut parent_depth_marker = 0;
141		// Build the tree from the leaves up to the root.
142		for i in (0..depth).rev() {
143			let level_size = 1 << (i + 1);
144			let next_level_size = 1 << i;
145			parent_depth_marker += level_size;
146
147			let (current_layer, parent_layer) = nodes
148				[current_depth_marker..parent_depth_marker + next_level_size]
149				.split_at_mut(level_size);
150
151			for j in 0..next_level_size {
152				let left = &current_layer[2 * j];
153				let right = &current_layer[2 * j + 1];
154				compress(left, right, &mut parent_layer[j])
155			}
156			// Move the marker to the next level.
157			current_depth_marker = parent_depth_marker;
158		}
159		// The root of the tree is the last node in the flattened tree.
160		let root = *nodes.last().expect("Merkle tree should not be empty");
161		Self { depth, nodes, root }
162	}
163
164	/// Returns a merkle path for the given index.
165	pub fn merkle_path(&self, index: usize) -> Vec<[u8; 32]> {
166		assert!(index < 1 << self.depth, "Index out of range.");
167		(0..self.depth)
168			.map(|j| {
169				let node_index = (((1 << j) - 1) << (self.depth + 1 - j)) | (index >> j) ^ 1;
170				self.nodes[node_index]
171			})
172			.collect()
173	}
174
175	/// Verifies a merkle path for inclusion in the tree.
176	pub fn verify_path(path: &[[u8; 32]], root: [u8; 32], leaf: [u8; 32], index: usize) {
177		assert!(index < 1 << path.len(), "Index out of range.");
178		let mut current_hash = leaf;
179		let mut next_hash = [0u8; 32];
180		for (i, node) in path.iter().enumerate() {
181			if (index >> i) & 1 == 0 {
182				compress(&current_hash, node, &mut next_hash);
183			} else {
184				compress(node, &current_hash, &mut next_hash);
185			}
186			current_hash = next_hash;
187		}
188		assert_eq!(current_hash, root);
189	}
190
191	// Returns the root of the merkle tree.
192	pub fn root(&self) -> [u8; 32] {
193		self.root
194	}
195}
196
197impl MerklePathEvent {
198	/// Method to fire the event, pushing the parent digest if the parent flag is set and
199	/// pulling the left or right child depending on the flush flags.
200	pub fn fire(&self, node_channel: &mut Channel<NodeFlushToken>) {
201		// Push the parent digest to the nodes channel and optionally pull the left or right
202		// child depending on the flush flags.
203		node_channel.push(NodeFlushToken {
204			root_id: self.root_id,
205			data: self.parent,
206			depth: self.parent_depth,
207			index: self.parent_index,
208		});
209
210		if self.flush_left {
211			node_channel.pull(NodeFlushToken {
212				root_id: self.root_id,
213				data: self.left,
214				depth: self.parent_depth + 1,
215				index: 2 * self.parent_index,
216			});
217		}
218		if self.flush_right {
219			node_channel.pull(NodeFlushToken {
220				root_id: self.root_id,
221				data: self.right,
222				depth: self.parent_depth + 1,
223				index: 2 * self.parent_index + 1,
224			});
225		}
226	}
227}
228
229impl MerkleRootEvent {
230	pub fn fire(
231		&self,
232		node_channel: &mut Channel<NodeFlushToken>,
233		root_channel: &mut Channel<RootFlushToken>,
234	) {
235		// Pull the root node value presumed to have been pushed to the nodes channel from the
236		// merkle path table.
237		node_channel.pull(NodeFlushToken {
238			root_id: self.root_id,
239			data: self.digest,
240			depth: 0,
241			index: 0,
242		});
243		// Pull the root node from the roots channel, presumed to have been pushed as a boundary
244		// value.
245		root_channel.pull(RootFlushToken {
246			root_id: self.root_id,
247			data: self.digest,
248		});
249	}
250}
251
/// Struct representing the boundary values of merkle tree inclusion proof statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleBoundaries {
	/// Leaf tokens of the statement; a set, so identical leaf claims are stored once.
	pub leaf: HashSet<NodeFlushToken>,
	/// Root tokens of the statement; a set, so each claimed root is stored once.
	pub root: HashSet<RootFlushToken>,
}
258
259impl Default for MerkleBoundaries {
260	fn default() -> Self {
261		Self::new()
262	}
263}
264impl MerkleBoundaries {
265	pub fn new() -> Self {
266		Self {
267			leaf: HashSet::new(),
268			root: HashSet::new(),
269		}
270	}
271
272	pub fn insert(&mut self, leaf: NodeFlushToken, root: RootFlushToken) {
273		self.leaf.insert(leaf);
274		self.root.insert(root);
275	}
276}
277
/// Struct representing the trace of the merkle tree inclusion proof statement.
pub struct MerkleTreeTrace {
	/// Boundary values (leaf and root tokens) of the statement.
	pub boundaries: MerkleBoundaries,
	/// One event per distinct internal node visited by the paths.
	pub nodes: Vec<MerklePathEvent>,
	/// One event per distinct root referenced by the paths.
	pub root: HashSet<MerkleRootEvent>,
}
284impl MerkleTreeTrace {
285	/// Method to generate the trace given the witness values. The function assumes that the
286	/// root_id is the index of the root in the roots vector and that the paths and leaves are
287	/// passed in with their assigned root_id.
288	pub fn generate(roots: Vec<[u8; 32]>, paths: &[MerklePath]) -> Self {
289		let mut path_nodes = Vec::new();
290		let mut root_nodes = HashSet::new();
291		let mut boundaries = MerkleBoundaries::new();
292		// Number of times each root is referenced in the paths. Since internal nodes have been
293		// deduped, these need to be pushed into the nodes channel as many times as they are
294		// referenced in the paths.
295
296		// Tracks the filled nodes in the tree
297		let mut filled_nodes = HashMap::new();
298
299		for path in paths.iter() {
300			let MerklePath {
301				root_id,
302				index,
303				leaf,
304				nodes,
305			} = path;
306
307			let mut current_child = *leaf;
308
309			// Populate the boundary values for the statement.
310			boundaries.insert(
311				NodeFlushToken {
312					root_id: *root_id,
313					data: current_child,
314					depth: nodes.len(),
315					index: *index,
316				},
317				RootFlushToken {
318					root_id: *root_id,
319					data: roots[*root_id as usize],
320				},
321			);
322
323			// Populate the root table's events.
324			root_nodes.insert(MerkleRootEvent::new(*root_id, roots[*root_id as usize]));
325
326			let mut parent_node = [0u8; 32];
327			for (i, &node) in nodes.iter().enumerate() {
328				match filled_nodes.get_mut(&(*root_id, index >> (i + 1), nodes.len() - i - 1)) {
329					Some((_, _, parent, flush_left, flush_right)) => {
330						if (index >> i) & 1 == 0 {
331							parent_node = *parent;
332							*flush_left = true;
333						} else {
334							parent_node = *parent;
335							*flush_right = true;
336						}
337					}
338					None => {
339						if (index >> i) & 1 == 0 {
340							compress(&current_child, &node, &mut parent_node);
341							filled_nodes.insert(
342								(*root_id, index >> (i + 1), nodes.len() - i - 1),
343								(current_child, node, parent_node, true, false),
344							);
345						} else {
346							compress(&node, &current_child, &mut parent_node);
347							filled_nodes.insert(
348								(*root_id, index >> (i + 1), nodes.len() - i - 1),
349								(node, current_child, parent_node, false, true),
350							);
351						}
352					}
353				}
354				current_child = parent_node
355			}
356		}
357
358		path_nodes.extend(filled_nodes.iter().map(|(key, value)| {
359			let &(root_id, parent_index, parent_depth) = key;
360			let &(left, right, parent, flush_left, flush_right) = value;
361			MerklePathEvent {
362				root_id,
363				left,
364				right,
365				parent,
366				parent_depth,
367				parent_index,
368				flush_left,
369				flush_right,
370			}
371		}));
372
373		Self {
374			boundaries,
375			nodes: path_nodes,
376			root: root_nodes,
377		}
378	}
379
380	pub fn validate(&self) {
381		let mut channels = MerkleTreeChannels::new();
382
383		// Push the boundary values to the nodes and roots channels.
384		for leaf in &self.boundaries.leaf {
385			channels.nodes.push(*leaf);
386		}
387		for root in &self.boundaries.root {
388			channels.roots.push(*root);
389		}
390
391		// Push the roots to the roots channel.
392		for root in &self.root {
393			root.fire(&mut channels.nodes, &mut channels.roots);
394		}
395
396		// Push the nodes to the nodes channel.
397		for node in &self.nodes {
398			node.fire(&mut channels.nodes);
399		}
400
401		// Assert that the nodes and roots channels are balanced.
402		channels.nodes.assert_balanced();
403		channels.roots.assert_balanced();
404	}
405}
406
#[cfg(test)]
mod tests {
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	// Builds an 8-leaf tree, opens the path for leaf 0 and checks that it verifies against the
	// tree's root.
	#[test]
	fn test_merkle_tree() {
		let leaves = vec![
			[0u8; 32], [1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32], [5u8; 32], [6u8; 32], [7u8; 32],
		];
		let tree = MerkleTree::new(&leaves);
		let path = tree.merkle_path(0);
		let root = tree.root();
		let leaf = leaves[0];
		MerkleTree::verify_path(&path, root, leaf, 0);

		// 8 leaves => log2(8) layers above the leaves.
		assert_eq!(tree.depth, 3);
	}

	// Trace generation for a single path in a single tree of 2^10 random leaves; the
	// resulting trace must balance both channels.
	#[test]
	fn test_high_level_model_inclusion() {
		// Fixed seed keeps the test deterministic.
		let mut rng = StdRng::from_seed([0; 32]);
		let path_index = rng.random_range(0..1 << 10);
		let leaves = (0..1 << 10)
			.map(|_| rng.random::<[u8; 32]>())
			.collect::<Vec<_>>();

		let tree = MerkleTree::new(&leaves);
		let root = tree.root();
		let path = tree.merkle_path(path_index);
		let path_root_id = 0;
		let merkle_tree_trace = MerkleTreeTrace::generate(
			vec![root],
			&[MerklePath {
				root_id: path_root_id,
				index: path_index,
				leaf: leaves[path_index],
				nodes: path,
			}],
		);
		merkle_tree_trace.validate();
	}

	// Trace generation for two randomly chosen paths in the same 16-leaf tree; the paths
	// share at least the root node, which exercises the deduplication of internal nodes.
	#[test]
	fn test_high_level_model_inclusion_multiple_paths() {
		let mut rng = StdRng::from_seed([0; 32]);

		let leaves = (0..1 << 4)
			.map(|_| rng.random::<[u8; 32]>())
			.collect::<Vec<_>>();

		let tree = MerkleTree::new(&leaves);
		let root = tree.root();
		let paths = (0..2)
			.map(|_| {
				let path_index = rng.random_range(0..1 << 4);
				MerklePath {
					root_id: 0u8,
					index: path_index,
					leaf: leaves[path_index],
					nodes: tree.merkle_path(path_index),
				}
			})
			.collect::<Vec<_>>();
		let merkle_tree_trace = MerkleTreeTrace::generate(vec![root], &paths);
		merkle_tree_trace.validate();
	}

	// Trace generation for three independent trees, opening the same leaf index in each; the
	// root ID of every path is the position of its tree's root in `roots`.
	#[test]
	fn test_high_level_model_inclusion_multiple_roots() {
		let mut rng = StdRng::from_seed([0; 32]);
		let path_index = rng.random_range(0..1 << 10);
		let leaves = (0..3)
			.map(|_| {
				(0..1 << 10)
					.map(|_| rng.random::<[u8; 32]>())
					.collect::<Vec<_>>()
			})
			.collect::<Vec<_>>();

		let trees = (0..3)
			.map(|i| MerkleTree::new(&leaves[i]))
			.collect::<Vec<_>>();
		let roots = (0..3).map(|i| trees[i].root()).collect::<Vec<_>>();
		let paths = trees
			.iter()
			.enumerate()
			.map(|(i, tree)| MerklePath {
				root_id: i as u8,
				index: path_index,
				leaf: leaves[i][path_index],
				nodes: tree.merkle_path(path_index),
			})
			.collect::<Vec<_>>();

		let merkle_tree_trace = MerkleTreeTrace::generate(roots, &paths);
		merkle_tree_trace.validate();
	}
}