binius_utils/graph.rs
1// Copyright 2024-2025 Irreducible Inc.
2
3use std::cmp::Ordering;
4
5/// Finds connected components using a Kruskal-like approach.
6/// Each input slice of usizes represents a set of nodes that form a complete subgraph
7/// (i.e., all of them are connected).
8///
9/// Returns a vector where each index corresponds to a graph node index and the value is the
10/// identifier of the connected component (the minimal node in that component).
11///
12/// ```
13/// use binius_utils::graph::connected_components;
14/// assert_eq!(connected_components(&[]), vec![]);
15/// assert_eq!(connected_components(&[&[0], &[1]]), vec![0, 1]);
16/// assert_eq!(connected_components(&[&[0, 1], &[1, 2], &[2, 3], &[4]]), vec![0, 0, 0, 0, 4]);
17/// assert_eq!(
18/// connected_components(&[&[0, 1, 2], &[5, 6, 7, 8], &[9], &[2, 3, 9]]),
19/// vec![0, 0, 0, 0, 4, 5, 5, 5, 5, 0]
20/// );
21/// ```
22///
23pub fn connected_components(data: &[&[usize]]) -> Vec<usize> {
24 if data.is_empty() {
25 return vec![];
26 }
27
28 // Determine the maximum node ID
29 let max_id = *data.iter().flat_map(|ids| ids.iter()).max().unwrap_or(&0);
30
31 let n = max_id + 1;
32 let mut uf = UnionFind::new(n);
33
34 // Convert input sets into edges. For each chunk, we connect all nodes together.
35 // To avoid adding a large number of redundant edges (fully connecting each subset),
36 // we can simply connect each node in the subset to the minimum node in that subset.
37 // This still ensures they all become part of one connected component.
38 for ids in data {
39 if ids.len() > 1 {
40 let &base = ids.iter().min().unwrap();
41 for &node in *ids {
42 if node != base {
43 uf.union(base, node);
44 }
45 }
46 }
47 }
48
49 // After union-find is complete, each node can be mapped to the minimum element
50 // of its component.
51 let mut result = vec![0; n];
52 for (i, res) in result.iter_mut().enumerate() {
53 let root = uf.find(i);
54 // Use the stored minimum element for the root
55 *res = uf.min_element[root];
56 }
57
58 result
59}
60
/// Disjoint-set forest over the nodes `0..n` that additionally remembers the smallest node of
/// every set.
struct UnionFind {
    /// `parent[i]` is the parent of node `i` in the forest; a root is its own parent.
    parent: Vec<usize>,
    /// Rank used to decide which root survives a merge. Only meaningful at roots.
    rank: Vec<usize>,
    /// `min_element[r]` is the smallest node of the set rooted at `r`. Only meaningful at roots.
    min_element: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per node in `0..n`.
    fn new(n: usize) -> Self {
        let identity: Vec<usize> = (0..n).collect();
        Self {
            parent: identity.clone(),
            rank: vec![0; n],
            min_element: identity,
        }
    }

    /// Returns the root of the set containing `x`, re-pointing every node visited on the way
    /// directly at that root.
    ///
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        // First pass: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: compress the path that was just walked.
        let mut cursor = x;
        while cursor != root {
            let next = self.parent[cursor];
            self.parent[cursor] = root;
            cursor = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`; does nothing if they are already the same set.
    ///
    /// The root with the higher rank survives (the root of `x` on a tie, which then gains one
    /// rank), and it inherits the smaller of the two stored minimum elements.
    fn union(&mut self, x: usize, y: usize) {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return;
        }

        let merged_min = self.min_element[root_x].min(self.min_element[root_y]);

        let (winner, loser) = match self.rank[root_x].cmp(&self.rank[root_y]) {
            Ordering::Less => (root_y, root_x),
            Ordering::Greater => (root_x, root_y),
            Ordering::Equal => {
                // Equal ranks: the merged tree gets one level deeper.
                self.rank[root_x] += 1;
                (root_x, root_y)
            }
        };

        self.parent[loser] = winner;
        self.min_element[winner] = merged_min;
    }
}