1pub mod trace;
17
18use std::{array, cell::RefMut, cmp::Reverse};
19
20use array_util::ArrayExt;
21use binius_core::{constraint_system::channel::ChannelId, oracle::ShiftVariant};
22use binius_field::{
23 BinaryField8b, Field, PackedBinaryField4x8b, PackedBinaryField8x8b, PackedExtension,
24 PackedField, PackedFieldIndexable, PackedSubfield,
25 linear_transformation::PackedTransformationFactory,
26 packed::{get_packed_slice, set_packed_slice},
27 underlier::WithUnderlier,
28};
29use itertools::Itertools;
30use trace::{MerklePathEvent, MerkleRootEvent, MerkleTreeTrace, NodeFlushToken, RootFlushToken};
31
32use crate::{
33 builder::{
34 B1, B8, B32, B64, B128, Boundary, Col, ConstraintSystem, FlushDirection, TableBuilder,
35 TableFiller, TableId, TableWitnessSegment, WitnessIndex, tally, upcast_col,
36 },
37 gadgets::{
38 hash::groestl::{Permutation, PermutationVariant},
39 indexed_lookup::incr::{Incr, IncrIndexedLookup, IncrLookup, merge_incr_vals},
40 },
41};
/// Tables and channels that together make up the Merkle tree constraint system.
///
/// Constructed by [`MerkleTreeCS::new`], which registers everything on a
/// [`ConstraintSystem`].
pub struct MerkleTreeCS {
    /// Nodes table built with [`MerklePathPullChild::Left`]; `fill_tables` fills it with the
    /// events that have `flush_left` set and `flush_right` unset.
    pub merkle_path_table_left: NodesTable,
    /// Nodes table built with [`MerklePathPullChild::Right`]; `fill_tables` fills it with the
    /// events that have `flush_right` set and `flush_left` unset.
    pub merkle_path_table_right: NodesTable,
    /// Nodes table built with [`MerklePathPullChild::Both`]; `fill_tables` fills it with the
    /// events that have both `flush_left` and `flush_right` set.
    pub merkle_path_table_both: NodesTable,

    /// Table filled from the root events of the trace.
    pub root_table: RootTable,

    /// Increment lookup table, registered as `incr_lookup_table`.
    pub incr_table: IncrLookup,
    /// Channel registered as `merkle_tree_nodes`; node tokens are flushed on it.
    pub nodes_channel: ChannelId,

    /// Channel registered as `merkle_tree_roots`; root tokens are flushed on it.
    pub roots_channel: ChannelId,

    /// Channel registered as `incr_lookup`; shared by the nodes tables and the increment
    /// lookup table.
    pub lookup_channel: ChannelId,
}
69
70impl MerkleTreeCS {
71 pub fn new(cs: &mut ConstraintSystem) -> Self {
72 let nodes_channel = cs.add_channel("merkle_tree_nodes");
73 let roots_channel = cs.add_channel("merkle_tree_roots");
74 let lookup_channel = cs.add_channel("incr_lookup");
75 let permutation_channel = cs.add_channel("permutation");
76
77 let merkle_path_table_left =
78 NodesTable::new(cs, MerklePathPullChild::Left, nodes_channel, lookup_channel);
79 let merkle_path_table_right =
80 NodesTable::new(cs, MerklePathPullChild::Right, nodes_channel, lookup_channel);
81 let merkle_path_table_both =
82 NodesTable::new(cs, MerklePathPullChild::Both, nodes_channel, lookup_channel);
83
84 let root_table = RootTable::new(cs, nodes_channel, roots_channel);
85
86 let mut table = cs.add_table("incr_lookup_table");
87 let incr_table = IncrLookup::new(&mut table, lookup_channel, permutation_channel, 20);
88 Self {
89 merkle_path_table_left,
90 merkle_path_table_right,
91 merkle_path_table_both,
92 root_table,
93 nodes_channel,
94 roots_channel,
95 lookup_channel,
96 incr_table,
97 }
98 }
99
100 pub fn fill_tables(
101 &self,
102 trace: &MerkleTreeTrace,
103 cs: &ConstraintSystem,
104 witness: &mut WitnessIndex,
105 ) -> anyhow::Result<()> {
106 let left_events = trace
108 .nodes
109 .iter()
110 .copied()
111 .filter(|event| event.flush_left && !event.flush_right)
112 .collect::<Vec<_>>();
113 let right_events = trace
114 .nodes
115 .iter()
116 .copied()
117 .filter(|event| !event.flush_left && event.flush_right)
118 .collect::<Vec<_>>();
119 let both_events = trace
120 .nodes
121 .iter()
122 .copied()
123 .filter(|event| event.flush_left && event.flush_right)
124 .collect::<Vec<_>>();
125
126 witness.fill_table_parallel(&self.merkle_path_table_left, &left_events)?;
128 witness.fill_table_parallel(&self.merkle_path_table_right, &right_events)?;
129 witness.fill_table_parallel(&self.merkle_path_table_both, &both_events)?;
130
131 witness.fill_table_parallel(
133 &self.root_table,
134 &trace.root.clone().into_iter().collect::<Vec<_>>(),
135 )?;
136
137 let lookup_counts = tally(cs, witness, &[], self.lookup_channel, &IncrIndexedLookup)?;
138
139 let sorted_counts = lookup_counts
141 .into_iter()
142 .enumerate()
143 .sorted_by_key(|(_, count)| Reverse(*count))
144 .collect::<Vec<_>>();
145 witness.fill_table_parallel(&self.incr_table, &sorted_counts)?;
146 witness.fill_constant_cols()?;
147 Ok(())
148 }
149
150 pub fn make_boundaries(&self, trace: &MerkleTreeTrace) -> Vec<Boundary<B128>> {
151 let mut boundaries = Vec::new();
152 for &NodeFlushToken {
154 root_id,
155 data,
156 depth,
157 index,
158 } in &trace.boundaries.leaf
159 {
160 let leaf_state = bytes_to_boundary(&data);
161 let values = vec![
162 B128::from(root_id as u128),
163 leaf_state[0],
164 leaf_state[1],
165 leaf_state[2],
166 leaf_state[3],
167 leaf_state[4],
168 leaf_state[5],
169 leaf_state[6],
170 leaf_state[7],
171 B128::from(depth as u128),
172 B128::from(index as u128),
173 ];
174 boundaries.push(Boundary {
175 values,
176 channel_id: self.nodes_channel,
177 direction: FlushDirection::Push,
178 multiplicity: 1,
179 });
180 }
181
182 for &RootFlushToken { root_id, data } in &trace.boundaries.root {
184 let state = bytes_to_boundary(&data);
185 let values = vec![
186 B128::new(root_id as u128),
187 state[0],
188 state[1],
189 state[2],
190 state[3],
191 state[4],
192 state[5],
193 state[6],
194 state[7],
195 ];
196 boundaries.push(Boundary {
197 values,
198 channel_id: self.roots_channel,
199 direction: FlushDirection::Push,
200 multiplicity: 1,
201 });
202 }
203 boundaries
204 }
205}
206
/// Selects which child token(s) a [`NodesTable`] pulls from the nodes channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MerklePathPullChild {
    /// Pull only the left child.
    Left,
    /// Pull only the right child.
    Right,
    /// Pull both the left and the right child.
    Both,
}
216
/// Table with one row per Merkle path step: it relates a parent node to its two children
/// through the permutation gadget and flushes the corresponding tokens on the nodes channel.
pub struct NodesTable {
    /// Id of the underlying table.
    id: TableId,
    /// Committed column `root_id`.
    root_id: Col<B8>,
    /// Permutation output columns after a logical right shift (see `NodesTable::new`).
    state_out_shifted: [Col<B8, 8>; 8],
    /// Four-value block selected from each `state_out_shifted` column.
    permutation_output_columns: [Col<B8, 4>; 8],
    /// Committed columns `left_columns` (the left child; packed and pulled in the `Left` and
    /// `Both` variants).
    left_columns: [Col<B8, 4>; 8],
    /// Committed columns `right_columns` (the right child; packed and pulled in the `Right`
    /// and `Both` variants).
    right_columns: [Col<B8, 4>; 8],
    /// Computed columns `permutation_output_columns[i] + right_columns[i]`; packed and pushed
    /// as the parent digest.
    parent_columns: [Col<B8, 4>; 8],
    /// Committed column `parent_depth`.
    pub parent_depth: Col<B8>,
    /// Output of the increment gadget applied to `parent_depth`.
    pub child_depth: Col<B8>,
    /// Committed column `parent_index`, 32 one-bit values per row.
    parent_index: Col<B1, 32>,
    /// `parent_index` shifted logically left by one position.
    left_index: Col<B1, 32>,
    /// Packed left index plus `B32::ONE`.
    right_index_packed: Col<B32>,
    /// Permutation gadget (variant `P`) applied to the `state_in` columns.
    permutation: Permutation,
    /// Increment gadget producing `child_depth` from `parent_depth`.
    increment: Incr,
    /// Which child token(s) this table pulls; stored but not read by the code in this file.
    pub _pull_child: MerklePathPullChild,
}
247
impl NodesTable {
    /// Registers a nodes table on `cs` and wires it to the nodes and lookup channels.
    ///
    /// The table is named `merkle_tree_nodes_left`, `merkle_tree_nodes_right` or
    /// `merkle_tree_nodes_both` depending on `pull_child`. Every row pushes the parent token
    /// `(root_id, parent digest, parent_depth, parent_index)` on `nodes_channel_id` and pulls
    /// the child token(s) selected by `pull_child` at `child_depth`.
    ///
    /// The order of the `table.add_*` calls below determines the order in which columns are
    /// registered, so statements must not be reordered casually.
    pub fn new(
        cs: &mut ConstraintSystem,
        pull_child: MerklePathPullChild,
        nodes_channel_id: ChannelId,
        lookup_chan: ChannelId,
    ) -> Self {
        // The table name carries a suffix identifying the variant.
        let mut table = cs.add_table(format!("merkle_tree_nodes_{}", {
            match pull_child {
                MerklePathPullChild::Left => "left",
                MerklePathPullChild::Right => "right",
                MerklePathPullChild::Both => "both",
            }
        }));

        let id = table.id();
        let root_id = table.add_committed("root_id");

        // Committed child digests and the committed permutation input.
        let left_columns: [Col<BinaryField8b, 4>; 8] = table.add_committed_multiple("left_columns");
        let right_columns: [Col<BinaryField8b, 4>; 8] =
            table.add_committed_multiple("right_columns");
        let state_in = table.add_committed_multiple("state_in");

        // Packed views: four bytes per child column, eight bytes per permutation input column.
        let left_packed: [Col<B32>; 8] =
            array::from_fn(|i| table.add_packed(format!("left_packed[{i}]"), left_columns[i]));
        let right_packed: [Col<B32>; 8] =
            array::from_fn(|i| table.add_packed(format!("right_packed[{i}]"), right_columns[i]));
        let state_in_packed: [Col<B64>; 8] =
            array::from_fn(|i| table.add_packed(format!("state_in_packed[{i}]"), state_in[i]));

        // Tie the permutation input to the children:
        // `state_in_packed[i] - left_packed[i] - right_packed[i] * B64::from(1 << 32)` must be
        // zero on every row.
        // NOTE(review): presumably the multiplication places the right child in the upper
        // 32 bits of the packed input — confirm against the field's tower representation.
        for i in 0..8 {
            table.assert_zero(
                format!("state_in_assert[{i}]"),
                state_in_packed[i]
                    - upcast_col(left_packed[i])
                    - upcast_col(right_packed[i]) * B64::from(1 << 32),
            );
        }
        let permutation = Permutation::new(
            &mut table.with_namespace("permutation"),
            PermutationVariant::P,
            state_in,
        );

        let state_out = permutation.state_out();

        // Logical right shift of every output column (arguments `3`, `4`). The witness filler
        // in this file copies value `4 + k` of each output row to position `k` of the shifted
        // row for `k` in `0..4`.
        let state_out_shifted: [Col<BinaryField8b, 8>; 8] = array::from_fn(|i| {
            table.add_shifted(
                format!("state_in_shifted[{i}]"),
                state_out[i],
                3,
                4,
                ShiftVariant::LogicalRight,
            )
        });

        // Four-value block selected from each shifted column.
        // NOTE(review): the meaning of the index argument `4` is defined by
        // `TableBuilder::add_selected_block`; the witness filler writes the same four values
        // here as it writes to positions `0..4` of the shifted column — confirm they agree.
        let permutation_output_columns: [Col<BinaryField8b, 4>; 8] = array::from_fn(|i| {
            table.add_selected_block(
                format!("permutation_output_columns[{i}]"),
                state_out_shifted[i],
                4,
            )
        });

        // Parent digest: selected permutation output plus the right child.
        let parent_columns: [Col<B8, 4>; 8] = array::from_fn(|i| {
            table.add_computed(
                format!("parent_columns[{i}]"),
                permutation_output_columns[i] + right_columns[i],
            )
        });

        let parent_packed: [Col<B32>; 8] =
            array::from_fn(|i| table.add_packed(format!("parent_packed[{i}]"), parent_columns[i]));

        let parent_depth = table.add_committed("parent_depth");

        let one = table.add_constant("one", [B1::ONE]);

        // The child depth is the output of the increment gadget fed with `parent_depth` and
        // the constant-one column.
        let increment = Incr::new(&mut table, lookup_chan, parent_depth, one);
        let child_depth = increment.output;

        // Child indices derived from the parent index: the left index is the parent index
        // shifted left by one position, the right index is the packed left index plus one.
        let parent_index: Col<B1, 32> = table.add_committed("parent_index");
        let left_index: Col<B1, 32> =
            table.add_shifted("left_index", parent_index, 5, 1, ShiftVariant::LogicalLeft);
        let left_index_packed = table.add_packed("left_index_packed", left_index);
        let right_index_packed =
            table.add_computed("right_index_packed", left_index_packed + B32::ONE);
        let parent_index_packed: Col<B32> = table.add_packed("parent_index_packed", parent_index);

        // Everything flushed on the nodes channel is a `Col<B32>`.
        let left_index_upcasted = upcast_col(left_index_packed);
        let right_index_upcasted = upcast_col(right_index_packed);
        let parent_index_upcasted = upcast_col(parent_index_packed);
        let parent_depth_upcasted = upcast_col(parent_depth);
        let child_depth_upcasted = upcast_col(child_depth);
        let root_id_upcasted = upcast_col(root_id);

        let mut nodes_channel = NodesChannel::new(&mut table, nodes_channel_id);

        // The parent token is always pushed.
        nodes_channel.push(
            root_id_upcasted,
            parent_packed,
            parent_depth_upcasted,
            parent_index_upcasted,
        );

        // The child token(s) pulled depend on the table variant.
        match pull_child {
            MerklePathPullChild::Left => nodes_channel.pull(
                root_id_upcasted,
                left_packed,
                child_depth_upcasted,
                left_index_upcasted,
            ),
            MerklePathPullChild::Right => nodes_channel.pull(
                root_id_upcasted,
                right_packed,
                child_depth_upcasted,
                right_index_upcasted,
            ),
            MerklePathPullChild::Both => {
                nodes_channel.pull(
                    root_id_upcasted,
                    left_packed,
                    child_depth_upcasted,
                    left_index_upcasted,
                );
                nodes_channel.pull(
                    root_id_upcasted,
                    right_packed,
                    child_depth_upcasted,
                    right_index_upcasted,
                );
            }
        }
        Self {
            id,
            root_id,
            state_out_shifted,
            permutation_output_columns,
            left_columns,
            right_columns,
            parent_columns,
            parent_depth,
            child_depth,
            parent_index,
            left_index,
            right_index_packed,
            _pull_child: pull_child,
            permutation,
            increment,
        }
    }
}
408
/// Pairs a table builder with a channel id so node tokens can be pushed and pulled with a
/// fixed `(root_id, digest, depth, index)` layout.
pub struct NodesChannel<'a> {
    /// Builder of the table on which the flushes are registered.
    table: &'a mut TableBuilder<'a>,
    /// Channel on which the node tokens are flushed.
    channel_id: ChannelId,
}
413
414impl<'a> NodesChannel<'a> {
415 pub fn new(table: &'a mut TableBuilder<'a>, channel_id: ChannelId) -> Self {
416 Self { table, channel_id }
417 }
418
419 pub fn push(
420 &mut self,
421 root_id: Col<B32>,
422 digest: [Col<B32>; 8],
423 depth: Col<B32>,
424 index: Col<B32>,
425 ) {
426 self.table
427 .push(self.channel_id, to_node_flush(upcast_col(root_id), digest, depth, index));
428 }
429
430 pub fn pull(
431 &mut self,
432 root_id: Col<B32>,
433 digest: [Col<B32>; 8],
434 depth: Col<B32>,
435 index: Col<B32>,
436 ) {
437 self.table
438 .pull(self.channel_id, to_node_flush(upcast_col(root_id), digest, depth, index));
439 }
440}
441
442fn to_node_flush(
444 root_id: Col<B32>,
445 digest: [Col<B32>; 8],
446 depth: Col<B32>,
447 index: Col<B32>,
448) -> [Col<B32>; 11] {
449 [
450 root_id, digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6],
451 digest[7], depth, index,
452 ]
453}
454
455fn to_root_flush(root_id: Col<B32>, digest: [Col<B32>; 8]) -> [Col<B32>; 9] {
456 [
457 root_id, digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6],
458 digest[7],
459 ]
460}
461
/// Table with one row per root: it pulls the root token from the roots channel and the
/// matching node token (depth and index taken from a zero constant column) from the nodes
/// channel.
pub struct RootTable {
    /// Id of the underlying table.
    pub id: TableId,
    /// Committed column `root_id`.
    pub root_id: Col<B8>,
    /// Committed columns `digest`, eight `B32` columns holding the root digest.
    pub digest: [Col<B32>; 8],
}
467
468impl RootTable {
469 pub fn new(
470 cs: &mut ConstraintSystem,
471 nodes_channel_id: ChannelId,
472 roots_channel_id: ChannelId,
473 ) -> Self {
474 let mut table = cs.add_table("merkle_tree_roots");
475 let id = table.id();
476 let root_id = table.add_committed("root_id");
477 let digest = table.add_committed_multiple("digest");
478
479 let zero = table.add_constant("zero", [B32::ZERO]);
480 let root_id_upcasted = upcast_col(root_id);
481 table.pull(roots_channel_id, to_root_flush(root_id_upcasted, digest));
482 let mut nodes_channel = NodesChannel::new(&mut table, nodes_channel_id);
483 nodes_channel.pull(root_id_upcasted, digest, zero, zero);
484 Self {
485 id,
486 root_id,
487 digest,
488 }
489 }
490}
491
492impl<P> TableFiller<P> for NodesTable
493where
494 P: PackedFieldIndexable<Scalar = B128>
495 + PackedExtension<B1>
496 + PackedExtension<B8>
497 + PackedExtension<B32>
498 + PackedExtension<B64>,
499 PackedSubfield<P, B8>:
500 PackedTransformationFactory<PackedSubfield<P, B8>> + PackedFieldIndexable,
501{
502 type Event = MerklePathEvent;
503
504 fn id(&self) -> TableId {
505 self.id
506 }
507
508 fn fill(
509 &self,
510 rows: &[Self::Event],
511 witness: &mut TableWitnessSegment<P>,
512 ) -> anyhow::Result<()> {
513 let state_ins = rows
514 .iter()
515 .map(|MerklePathEvent { left, right, .. }| {
516 let mut digest = [B8::ZERO; 64];
517 let (left_bytes, right_bytes) =
518 (B8::from_underliers_arr_ref(left), B8::from_underliers_arr_ref(right));
519
520 digest[..32].copy_from_slice(left_bytes);
521 digest[32..].copy_from_slice(right_bytes);
522 digest
523 })
524 .collect::<Vec<_>>();
525 self.permutation.populate_state_in(witness, &state_ins)?;
526 self.permutation.populate(witness)?;
527
528 let mut witness_root_id: RefMut<'_, [u8]> = witness.get_mut_as(self.root_id)?;
529 let mut witness_parent_depth: RefMut<'_, [u8]> = witness.get_mut_as(self.parent_depth)?;
530 let mut witness_child_depth: RefMut<'_, [u8]> = witness.get_mut_as(self.child_depth)?;
531 let mut witness_parent_index: RefMut<'_, [u32]> = witness.get_mut_as(self.parent_index)?;
532 let mut witness_left_index: RefMut<'_, [u32]> = witness.get_mut_as(self.left_index)?;
534 let mut witness_right_index_packed: RefMut<'_, [u32]> =
535 witness.get_mut_as(self.right_index_packed)?;
536
537 let witness_state_out: [RefMut<'_, [PackedBinaryField8x8b]>; 8] = self
538 .permutation
539 .state_out()
540 .try_map_ext(|col| witness.get_mut_as(col))?;
541
542 let mut witness_state_out_shifted: [RefMut<'_, [PackedBinaryField8x8b]>; 8] = self
543 .state_out_shifted
544 .try_map_ext(|col| witness.get_mut_as(col))?;
545
546 let mut witness_permutation_output_columns: [RefMut<'_, [PackedBinaryField4x8b]>; 8] = self
547 .permutation_output_columns
548 .try_map_ext(|col| witness.get_mut_as(col))?;
549
550 let mut left_columns: [RefMut<'_, [PackedBinaryField4x8b]>; 8] = self
551 .left_columns
552 .try_map_ext(|col| witness.get_mut_as(col))?;
553 let mut right_columns: [RefMut<'_, [PackedBinaryField4x8b]>; 8] = self
554 .right_columns
555 .try_map_ext(|col| witness.get_mut_as(col))?;
556 let mut parent_columns: [RefMut<'_, [PackedBinaryField4x8b]>; 8] = self
557 .parent_columns
558 .try_map_ext(|col| witness.get_mut_as(col))?;
559
560 let mut increment_merged: RefMut<'_, [u32]> = witness.get_mut_as(self.increment.merged)?;
561
562 {
563 for (i, event) in rows.iter().enumerate() {
564 let &MerklePathEvent {
565 root_id,
566 parent_depth,
567 parent_index,
568 left,
569 right,
570 parent,
571 ..
572 } = event;
573
574 witness_root_id[i] = root_id;
575 witness_parent_depth[i] = parent_depth
576 .try_into()
577 .expect("Parent depth must fit in u8");
578 witness_parent_index[i] = parent_index as u32;
579 witness_child_depth[i] = witness_parent_depth[i] + 1;
580 witness_left_index[i] = 2 * parent_index as u32;
581 witness_right_index_packed[i] = 2 * parent_index as u32 + 1;
582
583 increment_merged[i] =
584 merge_incr_vals(witness_parent_depth[i], true, witness_child_depth[i], false);
585 let left_bytes: [BinaryField8b; 32] = B8::from_underliers_arr(left);
586 let right_bytes: [BinaryField8b; 32] = B8::from_underliers_arr(right);
587 let parent_bytes: [BinaryField8b; 32] = B8::from_underliers_arr(parent);
588
589 for jk in 0..32 {
590 let j = jk % 8;
592 let k = jk / 8;
594
595 set_packed_slice(&mut left_columns[j], i * 4 + k, left_bytes[8 * k + j]);
597 set_packed_slice(&mut right_columns[j], i * 4 + k, right_bytes[8 * k + j]);
598 set_packed_slice(&mut parent_columns[j], i * 4 + k, parent_bytes[8 * k + j]);
599
600 let permutation_output = get_packed_slice(&witness_state_out[j], i * 8 + 4 + k);
602 set_packed_slice(
603 &mut witness_state_out_shifted[j],
604 i * 8 + k,
605 permutation_output,
606 );
607 set_packed_slice(
608 &mut witness_permutation_output_columns[j],
609 i * 4 + k,
610 permutation_output,
611 );
612 }
613 }
614 }
615 Ok(())
616 }
617}
618
619impl<P> TableFiller<P> for RootTable
620where
621 P: PackedFieldIndexable<Scalar = B128>
622 + PackedExtension<B1>
623 + PackedExtension<B8>
624 + PackedExtension<B32>
625 + PackedExtension<B64>,
626 PackedSubfield<P, B8>: PackedFieldIndexable,
627{
628 type Event = MerkleRootEvent;
629
630 fn id(&self) -> TableId {
631 self.id
632 }
633
634 fn fill(
635 &self,
636 rows: &[Self::Event],
637 witness: &mut TableWitnessSegment<P>,
638 ) -> anyhow::Result<()> {
639 let mut witness_root_id = witness.get_mut_as(self.root_id)?;
640 let mut witness_root_digest: Vec<RefMut<'_, [PackedBinaryField4x8b]>> = (0..8)
641 .map(|i| witness.get_mut_as(self.digest[i]))
642 .collect::<Result<Vec<_>, _>>()?;
643
644 for (i, event) in rows.iter().enumerate() {
645 let &MerkleRootEvent { root_id, digest } = event;
646 witness_root_id[i] = root_id;
647 let digest_as_field = B8::from_underliers_arr(digest);
648 for (jk, &byte) in digest_as_field.iter().enumerate() {
649 let j = jk % 8;
651 let k = jk / 8;
653 set_packed_slice(&mut witness_root_digest[j], i * 4 + k, byte);
654 }
655 }
656 Ok(())
657 }
658}
659
660fn bytes_to_boundary(bytes: &[u8; 32]) -> [B128; 8] {
661 let mut cols = [PackedBinaryField4x8b::zero(); 8];
662 for ij in 0..32 {
663 let i = ij % 8;
665 let j = ij / 8;
667
668 set_packed_slice(&mut cols, 4 * i + j, B8::from(bytes[8 * j + i]));
670 }
671 cols.map(|col| B128::from(col.to_underlier() as u128))
672}
673
#[cfg(test)]
mod tests {
    use binius_compute::cpu::alloc::CpuComputeAllocator;
    use binius_field::{arch::OptimalUnderlier, as_packed_field::PackedType};
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use trace::{MerklePath, MerkleTree};

    use super::*;
    use crate::builder::test_utils::validate_system_witness;

    /// Eight fixed leaves; leaf `i` is the byte `i` repeated 32 times.
    fn fixed_leaves() -> [[u8; 32]; 8] {
        [
            [0u8; 32], [1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32], [5u8; 32], [6u8; 32], [7u8; 32],
        ]
    }

    /// Trace of a single opening of leaf 0 in the tree built over `fixed_leaves`.
    fn first_leaf_trace() -> MerkleTreeTrace {
        let leaves = fixed_leaves();
        let tree = MerkleTree::new(&leaves);
        let nodes = tree.merkle_path(0);
        MerkleTreeTrace::generate(
            vec![tree.root()],
            &[MerklePath {
                root_id: 0,
                index: 0,
                leaf: leaves[0],
                nodes,
            }],
        )
    }

    /// A nodes table can be constructed and exposes eight columns per digest.
    #[test]
    fn test_nodes_table_constructor() {
        let mut cs = ConstraintSystem::new();
        let nodes_channel = cs.add_channel("nodes");
        let lookup_channel = cs.add_channel("lookup");
        let nodes_table =
            NodesTable::new(&mut cs, MerklePathPullChild::Left, nodes_channel, lookup_channel);
        assert_eq!(nodes_table.left_columns.len(), 8);
        assert_eq!(nodes_table.right_columns.len(), 8);
        assert_eq!(nodes_table.parent_columns.len(), 8);
    }

    /// A root table can be constructed and exposes eight digest columns.
    #[test]
    fn test_root_table_constructor() {
        let mut cs = ConstraintSystem::new();
        let nodes_channel = cs.add_channel("nodes");
        let roots_channel = cs.add_channel("roots");
        let root_table = RootTable::new(&mut cs, nodes_channel, roots_channel);
        assert_eq!(root_table.digest.len(), 8);
    }

    /// Filling a nodes table from a single-path trace succeeds.
    #[test]
    fn test_node_table_filling() {
        let mut cs = ConstraintSystem::new();
        let nodes_channel = cs.add_channel("nodes");
        let lookup_channel = cs.add_channel("lookup");
        let nodes_table =
            NodesTable::new(&mut cs, MerklePathPullChild::Left, nodes_channel, lookup_channel);

        let trace = first_leaf_trace();

        let mut allocator = CpuComputeAllocator::new(1 << 12);
        let allocator = allocator.into_bump_allocator();
        let mut witness = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);

        witness
            .fill_table_sequential(&nodes_table, &trace.nodes)
            .unwrap();
    }

    /// Filling the root table from a single-path trace succeeds.
    #[test]
    fn test_root_table_filling() {
        let mut cs = ConstraintSystem::new();
        let nodes_channel = cs.add_channel("nodes");
        let roots_channel = cs.add_channel("roots");
        let root_table = RootTable::new(&mut cs, nodes_channel, roots_channel);

        let trace = first_leaf_trace();

        let mut allocator = CpuComputeAllocator::new(1 << 12);
        let allocator = allocator.into_bump_allocator();
        let mut witness = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);

        let root_events = trace.root.into_iter().collect::<Vec<_>>();
        witness
            .fill_table_sequential(&root_table, &root_events)
            .unwrap();
    }

    /// `MerkleTreeCS::fill_tables` succeeds on a single-path trace.
    #[test]
    fn test_merkle_tree_cs_fill_tables() {
        let mut cs = ConstraintSystem::new();
        let merkle_tree_cs = MerkleTreeCS::new(&mut cs);

        let trace = first_leaf_trace();

        let mut allocator = CpuComputeAllocator::new(1 << 12);
        let allocator = allocator.into_bump_allocator();
        let mut witness = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);

        merkle_tree_cs
            .fill_tables(&trace, &cs, &mut witness)
            .unwrap();
    }

    /// Three random trees, one opening each at a shared random index, validated end to end.
    #[test]
    fn test_merkle_tree_cs_end_to_end() {
        const N_TREES: usize = 3;
        const N_LEAVES: usize = 1 << 10;

        let mut cs = ConstraintSystem::new();
        let merkle_tree_cs = MerkleTreeCS::new(&mut cs);

        // The index is drawn before the leaves so the random stream is consumed in the same
        // order on every run.
        let mut rng = StdRng::seed_from_u64(0);
        let index = rng.random_range(0..N_LEAVES);
        let leaves = (0..N_TREES)
            .map(|_| {
                (0..N_LEAVES)
                    .map(|_| rng.random::<[u8; 32]>())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        let trees = leaves
            .iter()
            .map(|tree_leaves| MerkleTree::new(tree_leaves))
            .collect::<Vec<_>>();
        let roots = trees.iter().map(|tree| tree.root()).collect::<Vec<_>>();
        let paths = trees
            .iter()
            .zip(&leaves)
            .enumerate()
            .map(|(i, (tree, tree_leaves))| MerklePath {
                root_id: i as u8,
                index,
                leaf: tree_leaves[index],
                nodes: tree.merkle_path(index),
            })
            .collect::<Vec<_>>();

        let trace = MerkleTreeTrace::generate(roots, &paths);

        let mut allocator = CpuComputeAllocator::new(1 << 14);
        let allocator = allocator.into_bump_allocator();
        let mut witness = WitnessIndex::new(&cs, &allocator);

        merkle_tree_cs
            .fill_tables(&trace, &cs, &mut witness)
            .unwrap();

        let boundaries = merkle_tree_cs.make_boundaries(&trace);

        validate_system_witness::<OptimalUnderlier>(&cs, witness, boundaries);
    }
}