带权并查集:需要多个权值以及权值为复杂结构的情况

  |  

摘要: 并查集中的集合权值很复杂的情况

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


在文章 并查集 中我们学习了并查集的原理和代码模板。我们知道并查集的结构是一个森林,其中的每一棵树表示一个集合或者一个连通分量,树根元素为该集合的代表元,树中的各个节点分别代表各个元素。

有时我们需要在合并的过程中维护连通分量的某种属性,比如元素个数、最值等等,称为集合级的权。维护这种集合级的权,只需要在基础的并查集上增加很少代码即可实现,在文章 含集合级信息的并查集 中我们讨论了带集合级信息的并查集的原理与代码模板,并给出了丰富的例题。

有时需要维护的信息很复杂,一个代表元下仅仅维护元素个数还不够,需要再加上其它额外信息,甚至这种额外信息是另一种数据结构。本文我们就来解决一个这样的问题。

题目

给定一个由 n 个节点组成的网络,用 n x n 个邻接矩阵 graph 表示。在节点网络中,只有当 graph[i][j] = 1 时,节点 i 能够直接连接到另一个节点 j。

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

我们可以从 initial 中删除一个节点,并完全移除该节点以及从该节点到任何其他节点的任何连接。

请返回移除后能够使 M(initial) 最小化的节点。如果有多个节点满足条件,返回索引 最小的节点 。

提示:

1
2
3
4
5
6
7
8
9
n == graph.length
n == graph[i].length
2 <= n <= 300
graph[i][j] 是 0 或 1.
graph[i][j] == graph[j][i]
graph[i][i] == 1
1 <= initial.length < n
0 <= initial[i] <= n - 1
initial 中每个整数都不同

示例 1:
输入:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输出:0

示例 2:
输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
输出:1

示例 3:
输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1]
输出:1

题解

算法:带权并查集

用并查集维护这 N 个节点,其中包含两种集合级的信息,一个是连通分量元素个数,另一个是连通分量中的病毒个数。

枚举所有的节点对 $(i, j)$,如果是图中的边,且 $i, j$ 均不是初始时的污染节点,则在并查集中 merge(i, j)

然后枚举每个初始时的污染节点 $i$,然后枚举图中 $i$ 的每条边 $(i, j)$,$j$ 所在连通分量都会被病毒 $i$ 影响,我们希望记录每个连通分量受到多少个病毒节点影响。这需要一个额外的权值来记录。

由于一个代表元下有多个节点,病毒节点可能会与其中的不止一个点连边,这样在枚举完病毒节点 $i$ 的每条边后,有的连通分量就被病毒节点 $i$ 影响了多次。下图是一个例子,[6, 7] 是一个连通分量,而病毒 4 通过边 [4, 7] 和 [4, 6] 影响了该连通分量两次:

因此代表连通分量中的病毒个数的那个权值,就不能用一个简单的数字,而需要用一个哈希表来记录与连通分量相连的具体病毒节点,这样就可以处理重复的问题了。

前面的处理完成后,枚举每个初始时的病毒节点 $i$,该节点可能与并查集中的多个集合相连。考察与 $i$ 相连的每个集合,我们只需要代表病毒个数的权值为 1 的连通分量,原因在于只能删除一个节点,病毒个数若大于 1,则删除 $i$ 也不能让该集合变为无病毒影响。这些病毒个数为 1 的连通分量中的元素是将 $i$ 删除后可以减少感染的节点。

但注意,这里同一个病毒还是有可能会连接同一个连通分量两次,这样在计数的时候该连通分量就会重复记,这里还是需要一个哈希表维护与当前病毒相连的代表元,解决去重问题。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <vector>
#include <unordered_set>

using namespace std;

class UnionFindSet
{
public:
UnionFindSet(int n)
{
_father.assign(n, -1);
for(int i = 0; i < n; ++i)
_father[i] = i;
_rank.assign(n, 0);
_weight.assign(n, 1);
_label = vector<unordered_set<int>>(n);
}

int get_weight(int x)
{
return _weight[_find(x)];
}

void set_label(int x, int y)
{
_label[_find(x)].insert(y);
}

int get_label(int x)
{
return _label[_find(x)].size();
}

int get_root(int x)
{
return _find(x);
}

void merge(int x, int y)
{
x = _find(x);
y = _find(y);
if(x == y)
return;

if(_rank[x] < _rank[y])
{
_father[x] = y;
_weight[y] += _weight[x];
}
else
{
_father[y] = x;
_weight[x] += _weight[y];
if(_rank[x] == _rank[y])
++_rank[x];
}
}

private:
vector<int> _father;
vector<int> _rank;
vector<int> _weight;
vector<unordered_set<int>> _label;

int _find(int x)
{
if(x == _father[x])
return x;
return _father[x] = _find(_father[x]);
}
};

class Solution {
public:
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
int n = graph.size();
unordered_set<int> setting(initial.begin(), initial.end());
UnionFindSet unionfindset(n);
for(int i = 0; i < n - 1; ++i)
{
for(int j = i + 1; j < n; ++j)
{
if(graph[i][j] == 1 && setting.count(i) == 0 && setting.count(j) == 0)
{
unionfindset.merge(i, j);
}
}
}
for(int i: initial)
{
for(int j = 0; j < n; ++j)
{
if(i == j) continue;
if(setting.count(j) > 0)
continue;
if(graph[i][j] == 0)
continue;
unionfindset.set_label(j, i);
}
}

int ans = -1;
int max_cand = -1;
for(int i: initial)
{
int cand = 0;
unordered_set<int> roots;
for(int j = 0; j < n; ++j)
{
if(i == j) continue;
if(setting.count(j) > 0) continue;
if(graph[i][j] == 0)
continue;
if(unionfindset.get_label(j) == 1)
{
roots.insert(unionfindset.get_root(j));
}

}
for(int r: roots)
{
int cc = unionfindset.get_weight(r);
cand += cc;
}
if(cand > max_cand)
{
max_cand = cand;
ans = i;
}
else if(cand == max_cand && i < ans)
ans = i;
}
return ans;
}
};

Share