LeetCode 684. Redundant Connection (Union-Find)

Difficulty: Medium

I solved this problem three ways: with DFS, with union-find backed by sets, and with union-find backed by an array.

Depth First Search

The intuition behind my DFS approach is to first build an adjacency list so that we can traverse the graph. Then, for each edge [a, b] in reverse order (since we want the extra edge with the greatest index), we run DFS from every neighbor of a except b. If we can still find a path to b, the edge lies on the cycle, which means it is the extra edge.

This approach has a worst-case time complexity of O(n^2) and a space complexity of O(n). In practice it usually does better, since we often find the edge before going through all the edges. Keep in mind that the number of nodes = number of edges = n, which did confuse me at first, and it is why we can just use n in our complexities.

I am aware of a one-pass DFS solution with O(n) time complexity, where the main idea is to find the cycle, collect all the edges on it, and return the one that appears last in edges, but I chose not to implement it.

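from typing import List


class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        # Adjacency list over nodes 1..n (there are n nodes and n edges).
        adjList = {node: [] for node in range(1, len(edges) + 1)}
        for n1, n2 in edges:
            adjList[n1].append(n2)
            adjList[n2].append(n1)

        # n2Check[0]: whether the target was reached; n2Check[1]: the target node.
        n2Check = [False, None]
        seen = set()  # nodes on the current DFS path

        def dfs(node):
            if node == n2Check[1]:
                n2Check[0] = True
            if node not in seen:
                seen.add(node)
                for n in adjList[node]:
                    dfs(n)
                seen.remove(node)

        # Try the edges from last to first so the first hit has the greatest index.
        for i in range(len(edges) - 1, -1, -1):
            edge = edges[i]
            n1, n2 = edge
            n2Check[1] = n2
            seen.add(n1)  # block any path that goes back through n1
            # Search from every neighbor of n1 except n2, i.e. without using this edge.
            for node in adjList[n1]:
                if node != n2:
                    dfs(node)
                    if n2Check[0]:
                        return edge
            seen.remove(n1)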
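
For reference, the one-pass idea mentioned earlier could look roughly like the sketch below. It is not the solution above, just one possible way to do it, and it assumes the graph is connected with exactly one cycle, as the problem guarantees. The names find_cycle_nodes and find_redundant_connection_one_pass are made up for illustration, and the traversal uses an explicit stack instead of recursion to stay clear of Python's recursion limit.

```python
from typing import Dict, List, Set


def find_cycle_nodes(adj: Dict[int, List[int]], start: int) -> Set[int]:
    """Return the nodes on the single cycle of a connected graph with one extra edge."""
    parent = {start: 0}  # 0 is a sentinel for "no parent" (nodes are numbered from 1)
    stack = [start]
    while stack:
        node = stack.pop()
        for nxt in adj[node]:
            if nxt == parent[node]:
                continue  # the tree edge we arrived by
            if nxt not in parent:
                parent[nxt] = node
                stack.append(nxt)
                continue
            # node-nxt is the one non-tree edge. The cycle is this edge plus the
            # two tree paths from node and nxt up to their lowest common ancestor.
            ancestors = []
            v = node
            while v != 0:
                ancestors.append(v)
                v = parent[v]
            ancestor_set = set(ancestors)
            cycle = set()
            v = nxt
            while v not in ancestor_set:
                cycle.add(v)
                v = parent[v]
            lca = v
            for v in ancestors:
                cycle.add(v)
                if v == lca:
                    break
            return cycle
    return set()


def find_redundant_connection_one_pass(edges: List[List[int]]) -> List[int]:
    adj: Dict[int, List[int]] = {node: [] for node in range(1, len(edges) + 1)}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    cycle = find_cycle_nodes(adj, 1)
    # With a single cycle there are no chords, so an edge whose endpoints are
    # both on the cycle is a cycle edge. Scan from the back for the last one.
    for a, b in reversed(edges):
        if a in cycle and b in cycle:
            return [a, b]
    return []
```

Each node and edge is handled a constant number of times, so this runs in O(n) time with O(n) extra space.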

Union-find with sets 

After understanding the intuition behind union-find, I tried to implement union-find myself, and naturally, I used sets. Here, the idea is to have a hash map where each key is a node and the value is the set of nodes that are connected together (the key node is also in the set). Initially, each node starts in a set containing only itself.

We iterate through each edge, and for each of its two nodes, we check which set it belongs to. If both belong to the same set, they are already connected, so that edge is the extra edge and we return it. Since there is only one extra edge, every other edge of the cycle has already been processed by then, so we can be sure that this edge is the one with the greatest index. If they belong to different sets, we remove one of the sets and merge it into the other.

This costs approximately O(n^2) time, since for every edge we go through the sets in the hash map to find the ones containing its two nodes, and may then merge two sets.

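from typing import List


class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        # Maps a representative node to the set of nodes connected to it.
        unionSets = {}
        for edge in edges:
            node1, node2 = edge
            node1Set = self.getSet(node1, unionSets)
            node2Set = self.getSet(node2, unionSets)

            # Same set: the nodes are already connected, so this edge closes the cycle.
            if node1Set == node2Set:
                return edge
            # Otherwise merge node2's set into node1's and drop node2's key.
            unionSets[node1Set] = unionSets[node1Set] | unionSets.pop(node2Set)

    def getSet(self, node, unionSets):
        # Linear scan for the key whose set contains the node.
        for k in unionSets:
            if node in unionSets[k]:
                return k
        # First time we see this node: it starts in a set of its own.
        unionSets[node] = {node}
        return node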
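
Most of the O(n^2) cost comes from scanning the hash map to find which set a node is in. One possible variation, sketched below, keeps a dictionary from every node straight to the set object it lives in and always merges the smaller set into the larger one. This is a different layout from the code above, not a drop-in fix, and the names set_of and find_redundant_connection_sets are made up for illustration.

```python
from typing import Dict, List, Set


def find_redundant_connection_sets(edges: List[List[int]]) -> List[int]:
    # node -> the set object shared by every node in its component
    set_of: Dict[int, Set[int]] = {}
    for a, b in edges:
        set_a = set_of.setdefault(a, {a})
        set_b = set_of.setdefault(b, {b})
        if set_a is set_b:
            return [a, b]  # already connected: this edge closes the cycle
        # Merge the smaller set into the larger one.
        if len(set_a) < len(set_b):
            set_a, set_b = set_b, set_a
        for node in set_b:
            set_of[node] = set_a
        set_a |= set_b  # in-place update, so every node keeps sharing one object
    return []
```

Lookups become O(1), and because a node only moves when its set at least doubles in size, each node moves at most about log n times, for O(n log n) total work.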

Union-Find with Array with Ranking and Path Compression

This is the most efficient of the three approaches, since it makes sure we only store the information we need and not whole sets. The intuition behind this approach is that we only care about the root of each node while performing find. I also use the same array to store the rank and the parent: a negative entry marks a root and holds its rank (here, the size of its tree, negated), while a non-negative entry is the index of the node's parent.

We first initialize an array of size n (the length of edges) with the default value -1. This represents that each node has a weight of 1 and is also a root node. For each edge, we check whether its two nodes have the same root by following each one up to its root with the find function. If they do, the edge is the extra edge. If they do not, we perform a union: we first check which root has the larger rank (the more negative number), add the other root's rank to it, and then point the root with the smaller rank at the root with the larger rank.
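
To make the encoding concrete, here is a small sketch that records the array after each edge for the input [[1,2],[1,3],[2,3]]. The helper name trace_parent_array is made up for illustration, and its find skips path compression so that the snapshots show only the effect of the union step.

```python
from typing import List


def trace_parent_array(edges: List[List[int]]) -> List[List[int]]:
    """Return a snapshot of the parent array after each edge is processed."""
    parent = [-1] * len(edges)  # negative entry = root, magnitude = tree size

    def find(node: int) -> int:
        while parent[node] >= 0:
            node = parent[node]
        return node

    snapshots = []
    for a, b in edges:
        ra, rb = find(a - 1), find(b - 1)  # nodes are 1-indexed
        if ra != rb:
            if parent[ra] > parent[rb]:  # make ra the root of the larger tree
                ra, rb = rb, ra
            parent[ra] += parent[rb]
            parent[rb] = ra
        snapshots.append(parent.copy())
    return snapshots


example = [[1, 2], [1, 3], [2, 3]]
for edge, snapshot in zip(example, trace_parent_array(example)):
    print(edge, snapshot)
# [1, 2] [-2, 0, -1]
# [1, 3] [-3, 0, 0]
# [2, 3] [-3, 0, 0]
```

The last snapshot is unchanged because nodes 2 and 3 already share root index 0 (node 1), which is exactly how the redundant edge is detected.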

Without path compression, finding the root of a node might mean recursing through a long path. Path compression fixes this by pointing every node visited during find directly at its root. As a result, find runs in near-constant amortized time (O(α(n)), where α is the inverse Ackermann function). So the overall time complexity is effectively O(n), which is just iterating through every edge. The space complexity is O(n) for the array. Of course, our find with path compression does use recursive stack space, but because union always attaches the smaller tree under the larger one, the tree height, and with it the recursion depth, is at most O(log n).

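from typing import List


class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        # parent[i] < 0: i is a root and -parent[i] is the size of its tree.
        # parent[i] >= 0: index of i's parent.
        parent = [-1] * len(edges)

        def find(node):
            if parent[node] < 0:
                return node
            # Path compression: point the node straight at its root.
            parent[node] = find(parent[node])
            return parent[node]

        def union(n1, n2):
            n1, n2 = n1 - 1, n2 - 1  # nodes are 1-indexed, the array is 0-indexed
            r1, r2 = find(n1), find(n2)
            if r1 == r2:
                return False  # already connected

            # Attach the smaller tree under the larger one (more negative = larger).
            if parent[r1] <= parent[r2]:
                parent[r1] += parent[r2]
                parent[r2] = r1
            else:
                parent[r2] += parent[r1]
                parent[r1] = r2
            return True

        for n1, n2 in edges:
            if not union(n1, n2):
                return [n1, n2]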
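
If the recursion in find is a concern, path compression can also be written with two loops and no stack at all. The sketch below is one way to do that under the same array encoding. The function name find_redundant_connection_iterative is made up for illustration.

```python
from typing import List


def find_redundant_connection_iterative(edges: List[List[int]]) -> List[int]:
    parent = [-1] * len(edges)  # negative = root (negated size), else parent index

    def find(node: int) -> int:
        # First pass: walk up to the root.
        root = node
        while parent[root] >= 0:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        while node != root:
            next_node = parent[node]
            parent[node] = root
            node = next_node
        return root

    for a, b in edges:
        ra, rb = find(a - 1), find(b - 1)  # nodes are 1-indexed
        if ra == rb:
            return [a, b]  # already connected: this edge closes the cycle
        if parent[ra] > parent[rb]:  # make ra the root of the larger tree
            ra, rb = rb, ra
        parent[ra] += parent[rb]
        parent[rb] = ra
    return []
```

The behavior matches the recursive version. Only the way the path is compressed changes.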