// binius_m3/emulate.rs

// Copyright 2025 Irreducible Inc.

use std::{
	collections::{btree_map::Entry, BTreeMap},
	fmt::{Debug, Write},
};

/// A channel used to validate a high-level M3 trace.
///
/// The channel tracks the *net* multiplicity of every value. Each [`Channel::push`] adds one
/// and each [`Channel::pull`] subtracts one. The channel is balanced when every value has
/// been pulled exactly as many times as it was pushed.
#[derive(Debug)]
pub struct Channel<T> {
	/// Net `pushes - pulls` count per value.
	///
	/// Invariant: an entry is removed as soon as its count reaches zero, so the map only
	/// holds unbalanced values. This keeps the `Debug` output focused on the imbalance and
	/// makes [`Channel::is_balanced`] a simple emptiness check.
	net_multiplicities: BTreeMap<T, isize>,
}

// Implemented by hand because `#[derive(Default)]` would add an unnecessary `T: Default` bound.
impl<T> Default for Channel<T> {
	fn default() -> Self {
		Self {
			net_multiplicities: BTreeMap::new(),
		}
	}
}

impl<T: Ord> Channel<T> {
	/// Records one push of `val`.
	pub fn push(&mut self, val: T) {
		self.adjust(val, 1);
	}

	/// Records one pull of `val`.
	pub fn pull(&mut self, val: T) {
		self.adjust(val, -1);
	}

	/// Returns `true` if every value has been pulled exactly as many times as it was pushed.
	pub fn is_balanced(&self) -> bool {
		self.net_multiplicities.is_empty()
	}

	/// Adds `delta` to the net multiplicity of `val` and drops the entry when it reaches zero.
	///
	/// The entry API means each update costs a single map lookup. Callers only pass a non-zero
	/// `delta`, so a freshly inserted entry is never zero and the map invariant holds.
	fn adjust(&mut self, val: T, delta: isize) {
		match self.net_multiplicities.entry(val) {
			Entry::Vacant(entry) => {
				entry.insert(delta);
			}
			Entry::Occupied(mut entry) => {
				*entry.get_mut() += delta;
				// Remove the key if the multiplicity is zero, to improve Debug behavior.
				if *entry.get() == 0 {
					entry.remove();
				}
			}
		}
	}
}

impl<T: Debug + Ord> Channel<T> {
	/// Asserts that the channel is balanced.
	///
	/// # Panics
	///
	/// Panics if [`Channel::is_balanced`] returns `false`. The panic message lists:
	/// - each value with unmatched pushes, and
	/// - each value with unmatched pulls.
	///
	/// Every listed value is shown with its outstanding count.
	#[track_caller]
	pub fn assert_balanced(&self) {
		if self.is_balanced() {
			return;
		}

		// Zero counts are never stored, so every entry is strictly positive or strictly negative.
		let (pushes, pulls): (Vec<_>, Vec<_>) = self
			.net_multiplicities
			.iter()
			.partition(|&(_, &multiplicity)| multiplicity > 0);

		// `let _ =` below: writing into a `String` cannot fail.
		let mut output = String::from("Channel is not balanced: \n");
		if !pushes.is_empty() {
			output.push_str("  Unbalanced pushes:\n");
			for (value, balance) in pushes {
				let _ = writeln!(output, "    {balance}: {value:?}");
			}
		}
		if !pulls.is_empty() {
			output.push_str("  Unbalanced pulls:\n");
			for (value, balance) in pulls {
				// `unsigned_abs` cannot overflow, unlike `abs` on `isize::MIN`.
				let _ = writeln!(output, "    {}: {value:?}", balance.unsigned_abs());
			}
		}

		panic!("{output}");
	}
}