1use std::{
6 collections::{HashMap, HashSet},
7 hash::Hash,
8};
9
10use binius_field::AESTowerField8b;
11use binius_hash::groestl::{GroestlShortImpl, GroestlShortInternal};
12
13use crate::{builder::B8, emulate::Channel};
14
/// Token exchanged on the node channel: one Merkle tree node and its 32-byte digest.
///
/// A node is addressed by the tree it belongs to (`root_id`), its `depth` (0 is the top node,
/// leaves sit at depth equal to the path length) and its `index` within that depth.
/// The derived ordering compares fields lexicographically in declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeFlushToken {
    /// Which tree (root) this node belongs to.
    pub root_id: u8,
    /// Digest stored at this node.
    pub data: [u8; 32],
    /// Level of the node, counted downward from the top node at depth 0.
    pub depth: usize,
    /// Position of the node within its level.
    pub index: usize,
}
23
/// Token exchanged on the root channel: the digest claimed as the root of tree `root_id`.
///
/// The derived ordering compares `root_id` first, then `data` byte by byte.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RootFlushToken {
    /// Which tree this root digest belongs to.
    pub root_id: u8,
    /// The root digest.
    pub data: [u8; 32],
}
30
/// An inclusion proof for a single leaf.
///
/// `nodes` holds the sibling digests ordered from the leaf's level upward: `nodes[0]` is the
/// leaf's sibling and the last entry is the sibling of the top node's child. `root_id` selects
/// the expected root from the `roots` vector passed to `MerkleTreeTrace::generate`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct MerklePath {
    /// Which tree (root) this proof is checked against.
    pub root_id: u8,
    /// Position of the leaf within the leaf level.
    pub index: usize,
    /// The leaf digest being proven.
    pub leaf: [u8; 32],
    /// Sibling digests, leaf level first.
    pub nodes: Vec<[u8; 32]>,
}
41
42#[allow(clippy::tabs_in_doc_comments)]
45pub struct MerkleTreeChannels {
46 nodes: Channel<NodeFlushToken>,
51
52 roots: Channel<RootFlushToken>,
55}
56
57impl Default for MerkleTreeChannels {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl MerkleTreeChannels {
64 pub fn new() -> Self {
65 Self {
66 nodes: Channel::default(),
67 roots: Channel::default(),
68 }
69 }
70}
/// One internal node of a tree touched by at least one inclusion path.
///
/// When fired, the event pushes the parent node and pulls whichever children are flagged
/// for flushing. The children live one level below the parent, at indices
/// `2 * parent_index` (left) and `2 * parent_index + 1` (right).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MerklePathEvent {
    /// Which tree this node belongs to.
    pub root_id: u8,
    /// Digest of the left child.
    pub left: [u8; 32],
    /// Digest of the right child.
    pub right: [u8; 32],
    /// Digest of this node; `MerkleTreeTrace::generate` sets it to `compress(left, right)`.
    pub parent: [u8; 32],
    /// Depth of this node (0 is the top node).
    pub parent_depth: usize,
    /// Index of this node within its level.
    pub parent_index: usize,
    /// Whether the left child is pulled from the node channel when fired.
    pub flush_left: bool,
    /// Whether the right child is pulled from the node channel when fired.
    pub flush_right: bool,
}
84
/// Ties the top node of tree `root_id` to its claimed root digest.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MerkleRootEvent {
    /// Which tree this root belongs to.
    pub root_id: u8,
    /// The claimed root digest.
    pub digest: [u8; 32],
}

impl MerkleRootEvent {
    /// Builds a root event from its tree id and digest.
    pub fn new(root_id: u8, digest: [u8; 32]) -> Self {
        MerkleRootEvent { root_id, digest }
    }
}
98
99fn compress(left: &[u8], right: &[u8], output: &mut [u8]) {
104 let mut state_bytes = [0u8; 64];
105 let (half0, half1) = state_bytes.split_at_mut(32);
106 half0.copy_from_slice(left);
107 half1.copy_from_slice(right);
108 state_bytes = state_bytes.map(|b| AESTowerField8b::from(B8::new(b)).val());
109 let input = GroestlShortImpl::state_from_bytes(&state_bytes);
110 let mut state = input;
111 GroestlShortImpl::p_perm(&mut state);
112 GroestlShortImpl::xor_state(&mut state, &input);
113 state_bytes = GroestlShortImpl::state_to_bytes(&state);
114 state_bytes = state_bytes.map(|b| B8::from(AESTowerField8b::new(b)).val());
115 output.copy_from_slice(&state_bytes[32..]);
116}
117
118pub struct MerkleTree {
122 depth: usize,
123 nodes: Vec<[u8; 32]>,
124 root: [u8; 32],
125}
126
127impl MerkleTree {
128 pub fn new(leaves: &[[u8; 32]]) -> Self {
131 assert!(leaves.len().is_power_of_two(), "Length of leaves needs to be a power of 2.");
132 let depth = leaves.len().ilog2() as usize;
133 let mut nodes = vec![[0u8; 32]; 2 * leaves.len() - 1];
134
135 nodes[0..leaves.len()].copy_from_slice(leaves);
137
138 let mut current_depth_marker = 0;
140 let mut parent_depth_marker = 0;
141 for i in (0..depth).rev() {
143 let level_size = 1 << (i + 1);
144 let next_level_size = 1 << i;
145 parent_depth_marker += level_size;
146
147 let (current_layer, parent_layer) = nodes
148 [current_depth_marker..parent_depth_marker + next_level_size]
149 .split_at_mut(level_size);
150
151 for j in 0..next_level_size {
152 let left = ¤t_layer[2 * j];
153 let right = ¤t_layer[2 * j + 1];
154 compress(left, right, &mut parent_layer[j])
155 }
156 current_depth_marker = parent_depth_marker;
158 }
159 let root = *nodes.last().expect("Merkle tree should not be empty");
161 Self { depth, nodes, root }
162 }
163
164 pub fn merkle_path(&self, index: usize) -> Vec<[u8; 32]> {
166 assert!(index < 1 << self.depth, "Index out of range.");
167 (0..self.depth)
168 .map(|j| {
169 let node_index = (((1 << j) - 1) << (self.depth + 1 - j)) | (index >> j) ^ 1;
170 self.nodes[node_index]
171 })
172 .collect()
173 }
174
175 pub fn verify_path(path: &[[u8; 32]], root: [u8; 32], leaf: [u8; 32], index: usize) {
177 assert!(index < 1 << path.len(), "Index out of range.");
178 let mut current_hash = leaf;
179 let mut next_hash = [0u8; 32];
180 for (i, node) in path.iter().enumerate() {
181 if (index >> i) & 1 == 0 {
182 compress(¤t_hash, node, &mut next_hash);
183 } else {
184 compress(node, ¤t_hash, &mut next_hash);
185 }
186 current_hash = next_hash;
187 }
188 assert_eq!(current_hash, root);
189 }
190
191 pub fn root(&self) -> [u8; 32] {
193 self.root
194 }
195}
196
197impl MerklePathEvent {
198 pub fn fire(&self, node_channel: &mut Channel<NodeFlushToken>) {
201 node_channel.push(NodeFlushToken {
204 root_id: self.root_id,
205 data: self.parent,
206 depth: self.parent_depth,
207 index: self.parent_index,
208 });
209
210 if self.flush_left {
211 node_channel.pull(NodeFlushToken {
212 root_id: self.root_id,
213 data: self.left,
214 depth: self.parent_depth + 1,
215 index: 2 * self.parent_index,
216 });
217 }
218 if self.flush_right {
219 node_channel.pull(NodeFlushToken {
220 root_id: self.root_id,
221 data: self.right,
222 depth: self.parent_depth + 1,
223 index: 2 * self.parent_index + 1,
224 });
225 }
226 }
227}
228
229impl MerkleRootEvent {
230 pub fn fire(
231 &self,
232 node_channel: &mut Channel<NodeFlushToken>,
233 root_channel: &mut Channel<RootFlushToken>,
234 ) {
235 node_channel.pull(NodeFlushToken {
238 root_id: self.root_id,
239 data: self.digest,
240 depth: 0,
241 index: 0,
242 });
243 root_channel.pull(RootFlushToken {
246 root_id: self.root_id,
247 data: self.digest,
248 });
249 }
250}
251
252#[derive(Debug, Clone, PartialEq, Eq)]
254pub struct MerkleBoundaries {
255 pub leaf: HashSet<NodeFlushToken>,
256 pub root: HashSet<RootFlushToken>,
257}
258
259impl Default for MerkleBoundaries {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264impl MerkleBoundaries {
265 pub fn new() -> Self {
266 Self {
267 leaf: HashSet::new(),
268 root: HashSet::new(),
269 }
270 }
271
272 pub fn insert(&mut self, leaf: NodeFlushToken, root: RootFlushToken) {
273 self.leaf.insert(leaf);
274 self.root.insert(root);
275 }
276}
277
278pub struct MerkleTreeTrace {
280 pub boundaries: MerkleBoundaries,
281 pub nodes: Vec<MerklePathEvent>,
282 pub root: HashSet<MerkleRootEvent>,
283}
284impl MerkleTreeTrace {
285 pub fn generate(roots: Vec<[u8; 32]>, paths: &[MerklePath]) -> Self {
289 let mut path_nodes = Vec::new();
290 let mut root_nodes = HashSet::new();
291 let mut boundaries = MerkleBoundaries::new();
292 let mut filled_nodes = HashMap::new();
298
299 for path in paths.iter() {
300 let MerklePath {
301 root_id,
302 index,
303 leaf,
304 nodes,
305 } = path;
306
307 let mut current_child = *leaf;
308
309 boundaries.insert(
311 NodeFlushToken {
312 root_id: *root_id,
313 data: current_child,
314 depth: nodes.len(),
315 index: *index,
316 },
317 RootFlushToken {
318 root_id: *root_id,
319 data: roots[*root_id as usize],
320 },
321 );
322
323 root_nodes.insert(MerkleRootEvent::new(*root_id, roots[*root_id as usize]));
325
326 let mut parent_node = [0u8; 32];
327 for (i, &node) in nodes.iter().enumerate() {
328 match filled_nodes.get_mut(&(*root_id, index >> (i + 1), nodes.len() - i - 1)) {
329 Some((_, _, parent, flush_left, flush_right)) => {
330 if (index >> i) & 1 == 0 {
331 parent_node = *parent;
332 *flush_left = true;
333 } else {
334 parent_node = *parent;
335 *flush_right = true;
336 }
337 }
338 None => {
339 if (index >> i) & 1 == 0 {
340 compress(¤t_child, &node, &mut parent_node);
341 filled_nodes.insert(
342 (*root_id, index >> (i + 1), nodes.len() - i - 1),
343 (current_child, node, parent_node, true, false),
344 );
345 } else {
346 compress(&node, ¤t_child, &mut parent_node);
347 filled_nodes.insert(
348 (*root_id, index >> (i + 1), nodes.len() - i - 1),
349 (node, current_child, parent_node, false, true),
350 );
351 }
352 }
353 }
354 current_child = parent_node
355 }
356 }
357
358 path_nodes.extend(filled_nodes.iter().map(|(key, value)| {
359 let &(root_id, parent_index, parent_depth) = key;
360 let &(left, right, parent, flush_left, flush_right) = value;
361 MerklePathEvent {
362 root_id,
363 left,
364 right,
365 parent,
366 parent_depth,
367 parent_index,
368 flush_left,
369 flush_right,
370 }
371 }));
372
373 Self {
374 boundaries,
375 nodes: path_nodes,
376 root: root_nodes,
377 }
378 }
379
380 pub fn validate(&self) {
381 let mut channels = MerkleTreeChannels::new();
382
383 for leaf in &self.boundaries.leaf {
385 channels.nodes.push(*leaf);
386 }
387 for root in &self.boundaries.root {
388 channels.roots.push(*root);
389 }
390
391 for root in &self.root {
393 root.fire(&mut channels.nodes, &mut channels.roots);
394 }
395
396 for node in &self.nodes {
398 node.fire(&mut channels.nodes);
399 }
400
401 channels.nodes.assert_balanced();
403 channels.roots.assert_balanced();
404 }
405}
406
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Eight distinct leaves `[i; 32]` for `i` in `0..8`, giving a depth-3 tree.
    fn small_leaves() -> Vec<[u8; 32]> {
        (0..8u8).map(|i| [i; 32]).collect()
    }

    #[test]
    fn test_merkle_tree() {
        let leaves = small_leaves();
        let tree = MerkleTree::new(&leaves);
        let root = tree.root();
        assert_eq!(tree.depth, 3);

        // Check every leaf, not only index 0, so all sibling-index arithmetic is exercised.
        for (index, &leaf) in leaves.iter().enumerate() {
            let path = tree.merkle_path(index);
            assert_eq!(path.len(), tree.depth);
            MerkleTree::verify_path(&path, root, leaf, index);
        }
    }

    #[test]
    fn test_merkle_tree_single_leaf() {
        let leaf = [42u8; 32];
        let tree = MerkleTree::new(&[leaf]);
        assert_eq!(tree.depth, 0);
        assert_eq!(tree.root(), leaf);
        assert!(tree.merkle_path(0).is_empty());
        MerkleTree::verify_path(&[], leaf, leaf, 0);
    }

    #[test]
    #[should_panic]
    fn test_merkle_tree_rejects_wrong_leaf() {
        let leaves = small_leaves();
        let tree = MerkleTree::new(&leaves);
        let path = tree.merkle_path(2);
        // Leaf 3 is in the tree, but not at index 2.
        MerkleTree::verify_path(&path, tree.root(), leaves[3], 2);
    }

    #[test]
    #[should_panic(expected = "Length of leaves needs to be a power of 2.")]
    fn test_merkle_tree_rejects_non_power_of_two() {
        let _ = MerkleTree::new(&[[0u8; 32]; 3]);
    }

    #[test]
    fn test_high_level_model_inclusion() {
        let mut rng = StdRng::from_seed([0; 32]);
        let path_index = rng.random_range(0..1 << 10);
        let leaves = (0..1 << 10)
            .map(|_| rng.random::<[u8; 32]>())
            .collect::<Vec<_>>();

        let tree = MerkleTree::new(&leaves);
        let root = tree.root();
        let path = tree.merkle_path(path_index);
        let path_root_id = 0;
        let merkle_tree_trace = MerkleTreeTrace::generate(
            vec![root],
            &[MerklePath {
                root_id: path_root_id,
                index: path_index,
                leaf: leaves[path_index],
                nodes: path,
            }],
        );
        merkle_tree_trace.validate();
    }

    #[test]
    fn test_high_level_model_inclusion_multiple_paths() {
        let mut rng = StdRng::from_seed([0; 32]);

        let leaves = (0..1 << 4)
            .map(|_| rng.random::<[u8; 32]>())
            .collect::<Vec<_>>();

        let tree = MerkleTree::new(&leaves);
        let root = tree.root();
        let paths = (0..2)
            .map(|_| {
                let path_index = rng.random_range(0..1 << 4);
                MerklePath {
                    root_id: 0u8,
                    index: path_index,
                    leaf: leaves[path_index],
                    nodes: tree.merkle_path(path_index),
                }
            })
            .collect::<Vec<_>>();
        let merkle_tree_trace = MerkleTreeTrace::generate(vec![root], &paths);
        merkle_tree_trace.validate();
    }

    #[test]
    fn test_high_level_model_inclusion_shared_nodes() {
        let leaves = small_leaves();
        let tree = MerkleTree::new(&leaves);
        // Indices 4 and 5 are siblings, so their paths share every internal node. The repeated
        // index 4 exercises deduplication of identical paths.
        let paths = [4usize, 5, 4]
            .iter()
            .map(|&index| MerklePath {
                root_id: 0,
                index,
                leaf: leaves[index],
                nodes: tree.merkle_path(index),
            })
            .collect::<Vec<_>>();
        let trace = MerkleTreeTrace::generate(vec![tree.root()], &paths);

        assert_eq!(trace.nodes.len(), tree.depth);
        assert_eq!(trace.boundaries.leaf.len(), 2);
        assert_eq!(trace.root.len(), 1);
        let bottom = trace
            .nodes
            .iter()
            .find(|event| event.parent_depth == 2)
            .expect("the common parent of leaves 4 and 5 is recorded");
        assert_eq!(bottom.parent_index, 2);
        assert!(bottom.flush_left && bottom.flush_right);

        trace.validate();
    }

    #[test]
    fn test_high_level_model_inclusion_multiple_roots() {
        let mut rng = StdRng::from_seed([0; 32]);
        let path_index = rng.random_range(0..1 << 10);
        let leaves = (0..3)
            .map(|_| {
                (0..1 << 10)
                    .map(|_| rng.random::<[u8; 32]>())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        let trees = (0..3)
            .map(|i| MerkleTree::new(&leaves[i]))
            .collect::<Vec<_>>();
        let roots = (0..3).map(|i| trees[i].root()).collect::<Vec<_>>();
        let paths = trees
            .iter()
            .enumerate()
            .map(|(i, tree)| MerklePath {
                root_id: i as u8,
                index: path_index,
                leaf: leaves[i][path_index],
                nodes: tree.merkle_path(path_index),
            })
            .collect::<Vec<_>>();

        let merkle_tree_trace = MerkleTreeTrace::generate(roots, &paths);
        merkle_tree_trace.validate();
    }
}