binius_utils/
graph.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// Copyright 2024-2025 Irreducible Inc.

use std::cmp::Ordering;

/// Finds the connected components of a graph described as a list of cliques.
///
/// Each input slice lists node indices that are all mutually connected (they form a
/// complete subgraph). The graph is taken to contain every node in `0..=max_id`, where
/// `max_id` is the largest index appearing in any slice. Indices in that range that never
/// appear end up as singleton components.
///
/// Returns a vector where each index corresponds to a graph node index and the value is the
/// identifier of the connected component (the minimal node in that component). If no node
/// index appears at all (`data` is empty, or every slice in it is empty), the result is empty.
///
/// Internally the cliques are merged with a union-find (disjoint-set) structure.
///
/// ```
/// use binius_utils::graph::connected_components;
/// assert_eq!(connected_components(&[]), vec![]);
/// assert_eq!(connected_components(&[&[]]), Vec::<usize>::new());
/// assert_eq!(connected_components(&[&[0], &[1]]), vec![0, 1]);
/// assert_eq!(connected_components(&[&[0, 1], &[1, 2], &[2, 3], &[4]]), vec![0, 0, 0, 0, 4]);
/// assert_eq!(
///     connected_components(&[&[0, 1, 2], &[5, 6, 7, 8], &[9], &[2, 3, 9]]),
///     vec![0, 0, 0, 0, 4, 5, 5, 5, 5, 0]
/// );
/// ```
///
pub fn connected_components(data: &[&[usize]]) -> Vec<usize> {
	// The node count is one past the largest index mentioned. If no index is mentioned at
	// all (including the case of only empty slices) there are no nodes, so return early
	// instead of fabricating a node 0.
	let Some(max_id) = data.iter().flat_map(|ids| ids.iter().copied()).max() else {
		return Vec::new();
	};

	let n = max_id + 1;
	let mut uf = UnionFind::new(n);

	// Instead of adding every edge of each clique, link every member to the clique's first
	// node. That is enough to put the whole clique into a single component. The union-find
	// tracks each component's minimum itself, so the choice of anchor does not affect the
	// result.
	for ids in data {
		if let Some((&anchor, rest)) = ids.split_first() {
			for &node in rest {
				uf.union(anchor, node);
			}
		}
	}

	// Map every node to the minimal element recorded at the root of its component.
	(0..n)
		.map(|node| {
			let root = uf.find(node);
			uf.min_element[root]
		})
		.collect()
}

/// Disjoint-set forest over the nodes `0..n`, using union by rank and path compression.
///
/// Besides the parent links, the root of every set also records the smallest node index in
/// that set, so a set's minimum can be read right after a `find`.
struct UnionFind {
	/// `parent[i]` is the parent link of node `i`; node `i` is a root iff `parent[i] == i`.
	parent: Vec<usize>,
	/// Union-by-rank counter; only consulted and updated for roots.
	rank: Vec<usize>,
	/// Smallest node index in the set rooted at `i`; only kept up to date for roots.
	min_element: Vec<usize>,
}

impl UnionFind {
	/// Creates `n` singleton sets, one for each node in `0..n`.
	fn new(n: usize) -> Self {
		let identity: Vec<usize> = (0..n).collect();
		Self {
			parent: identity.clone(),
			rank: vec![0; n],
			min_element: identity,
		}
	}

	/// Returns the root of the set containing `x`. Every node visited on the way is
	/// re-pointed directly at that root (path compression).
	///
	/// Panics if `x` is out of range.
	fn find(&mut self, x: usize) -> usize {
		// Pass 1: follow parent links until reaching a self-parented node.
		let mut root = x;
		while self.parent[root] != root {
			root = self.parent[root];
		}

		// Pass 2: walk the same path again, hanging each node directly off the root.
		let mut node = x;
		while self.parent[node] != root {
			let next = self.parent[node];
			self.parent[node] = root;
			node = next;
		}

		root
	}

	/// Merges the sets containing `x` and `y`; does nothing if they are already joined.
	///
	/// The lower-rank root is attached beneath the higher-rank one. On a tie, `x`'s root
	/// survives and its rank is bumped. The surviving root inherits the smaller of the two
	/// sets' minimum elements.
	fn union(&mut self, x: usize, y: usize) {
		let (rx, ry) = (self.find(x), self.find(y));
		if rx == ry {
			return;
		}

		let order = self.rank[rx].cmp(&self.rank[ry]);
		let (winner, loser) = match order {
			Ordering::Less => (ry, rx),
			Ordering::Greater | Ordering::Equal => (rx, ry),
		};
		let smallest = self.min_element[rx].min(self.min_element[ry]);

		self.parent[loser] = winner;
		self.min_element[winner] = smallest;
		if order == Ordering::Equal {
			self.rank[winner] += 1;
		}
	}
}