用下标索引堆优化邻接表的 Prim 算法

  |  

摘要: 下标索引堆的应用。

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们以 1135. 最低成本联通所有城市 来看一下如何用下标索引堆对邻接表的 Prim 算法进行优化。邻接表 + 二叉堆的 Prim 算法时间复杂度 $O(E\log E)$,改成下标索引堆后,时间复杂度可以变为 $O(E\log V)$。

这主要是因为节点 $v$ 的优先级变好准备压入时,$v$ 可能已经在堆中了,只是优先级(以$v$为终点的边的长度)更差。而标准二叉堆是不支持直接改变指定节点的优先级或直接删除指定节点的。所以用了延迟删除的策略,这样堆中的元素个数最坏就可能是 $O(E)$。在 Dijkstra 算法中也有这个问题,在文章 迪杰斯特拉算法(Dijkstra) 中我们对这个问题做过更细的阐述。

改变堆中既有元素的优先级是很麻烦的事情,为了能够在 $O(\log V)$ 时间复杂度完成对堆中指定节点的删除或修改,在以下文章中,我们做过一些尝试:

本文我们看一个具体的例子:应用下标索引堆的方式,对 Prim 算法进行优化。

题目

想象一下你是个城市基建规划者,地图上有 n 座城市,它们按以 1 到 n 的次序编号。

给你整数 n 和一个数组 conections,其中 connections[i] = [xi, yi, costi] 表示将城市 xi 和城市 yi 连接所要的costi(连接是双向的)。

返回连接所有城市的最低成本,每对城市之间至少有一条路径。如果无法连接所有 n 个城市,返回 -1

该 最小成本 应该是所用全部连接成本的总和。

提示:

1
2
3
4
5
6
1 <= n <= 1e4
1 <= connections.length <= 1e4
connections[i].length == 3
1 <= xi, yi <= n
xi != yi
0 <= costi <= 1e5

示例 1:
输入:n = 3, conections = [[1,2,5],[1,3,6],[2,3,1]]
输出:6
解释:选出任意 2 条边都可以连接所有城市,我们从中选取成本最小的 2 条。

示例 2:
输入:n = 4, conections = [[1,2,3],[3,4,4]]
输出:-1
解释:即使连通所有的边,也无法连接所有城市。

算法:邻接表 Prim

标准的最小生成树算法,图是邻接表形式,时间复杂度 $O(E\log E)$,参考: 最小生成树。算法如下:

1
2
3
4
5
6
将 0 加入集合 T
while 尚有点没有进入集合 T
从剩下的点中,选出最短的边 `(u, v, w)`
其中 u 属于 T,v 不属于 T,即 uv 是横切边
cost += w
v 加入 T

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
int minimumCost(int N, vector<vector<int>>& connections) {
vector<vector<vector<int> > > g(N + 1);
for(const vector<int> &connection: connections)
{
int x = connection[0], y = connection[1], w = connection[2];
g[x].push_back(vector<int>({w, y}));
g[y].push_back(vector<int>({w, x}));
}

priority_queue<vector<int> > pq;
pq.push({0, 1});
vector<int> visited(N + 1, false);
visited[0] = true;
int cost = 0;
while(!pq.empty())
{
vector<int> cur = pq.top();
pq.pop();
if(visited[cur[1]])
continue;
visited[cur[1]] = true;
cost += -cur[0];
for(vector<int> son: g[cur[1]])
{
if(visited[son[1]])
continue;
pq.push({-son[0], son[1]});
}
}
for(bool v: visited)
if(!v)
return -1;
return cost;
}
};

$2 下标索引堆优化 Prim

1
2
3
4
5
6
将 0 加入集合 T
while 尚有点没有进入集合 T
从剩下的点中,选出最短的边 `(u, v, w)`
其中 u 属于 T,v 不属于 T,即 uv 是横切边
cost += w
v 加入 T

点 u 是否在 T 中,依然用 visited 维护。查询最短的边对应的 v, w 之前是用堆维护的,现在用下标索引堆来代替。其中 v 相当于 idx,w 相当于 data[idx]。

下标索引堆的原理,代码模板参考文章:下标索引堆

(1) 对下标索引堆模板的修改

下标索引堆使用模板,其中根据需求增加以下几个接口:

1
2
3
top_idx() := 当前最小值所在的下标
get(idx) := 查询下标处 data 的值
check(idx) := 查询下标处在堆(indexes) 的位置,-1 表示未压过堆,-2 表示之前压过现在已经弹出
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int IndexMinHeap::top_idx()
{
if(empty())
return -1;
return indexes[1] - 1;
}

int IndexMinHeap::check(int idx)
{
++idx;
return mapping[idx];
}

int IndexMinHeap::get(int idx)
{
++idx;
return data[idx];
}

数据量 N 提前可以知道,因此构造时候可以之间把所需空间开出,但堆的 size 仍为 0。修改构造函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
IndexMinHeap(int n=0, int origin_val=INT_MAX/2)
{
_size = 0;
data.assign(n + 1, -1);
indexes.assign(n + 1, -1);
mapping.assign(n + 1, -1);
for(int i = 1; i <= n; ++i)
{
data[i] = origin_val;
indexes[i] = -1;
mapping[i] = -1;
}
}

除以上修改,下标索引堆其余部分复制模板。

(2) 代码:下标索引堆优化邻接表 Prim

将二叉堆改为下标索引堆后,堆中最多的边数从 $O(E)$ 变为 $O(V)$,整体时间复杂度变为 $O(E\log V)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution {
public:
int minimumCost(int N, vector<vector<int>>& connections) {
vector<vector<vector<int> > > g(N + 1);
for(const vector<int> &connection: connections)
{
int x = connection[0], y = connection[1], w = connection[2];
g[x].push_back(vector<int>({w, y}));
g[y].push_back(vector<int>({w, x}));
}

IndexMinHeap heap(N + 1);
heap.push(1, 0);
vector<bool> visited(N + 1, false);
visited[0] = true;
int cost = 0;
while(!heap.empty())
{
int cur_idx = heap.top_idx();
int cur_cost = heap.top();
heap.pop();
if(visited[cur_idx])
continue;
visited[cur_idx] = true;
cost += cur_cost;
for(vector<int> son: g[cur_idx])
{
if(visited[son[1]])
continue;
if(heap.check(son[1]) < 0)
heap.push(son[1], son[0]);
else if(heap.get(son[1]) > son[0])
heap.change(son[1], son[0]);
}
}
for(bool v: visited)
if(!v)
return -1;
return cost;
}
};

(3) 完整代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class IndexMinHeap
{
public:
IndexMinHeap(int n=0, int origin_val=INT_MAX/2)
{
_size = 0;
data.assign(n + 1, -1);
indexes.assign(n + 1, -1);
mapping.assign(n + 1, -1);
for(int i = 1; i <= n; ++i)
{
data[i] = origin_val;
indexes[i] = -1;
mapping[i] = -1;
}
}

void build(const vector<int>& nums)
{
data.clear();
indexes.clear();
mapping.clear();
int n = nums.size();
data.assign(n + 1, -1);
indexes.assign(n + 1, -1);
mapping.assign(n + 1, -1);
for(int i = 1; i <= n; ++i)
{
data[i] = nums[i - 1];
indexes[i] = i;
mapping[i] = i;
}
_size = n;
for(int i = n; i >= 1; --i)
push_down(i);
}

int top()
{
if(empty())
return -1;
return data[indexes[1]];
}

int pop()
{
if(empty())
return -1;
int ans = data[indexes[1]];
_remove(1);
return ans;
}

void push(int idx, const int key)
{
++idx;
// 覆盖 data[idx] 的数据, idx 同时也是 key 的索引
// mapping[idx] := key 在堆中的逻辑位置,即 indexes 中的位置
if(idx >= (int)data.size())
dilatation();
++_size;
data[idx] = key;
mapping[idx] = _size;
indexes[_size] = idx;
push_up(_size);
}

void remove(int idx)
{
++idx;
if(mapping[idx] < 0)
return;
int i = mapping[idx];
_remove(i);
}

void change(int idx, const int new_key)
{
++idx;
if(mapping[idx] < 0)
return;
if(data[idx] == new_key)
return;
data[idx] = new_key;
int i = mapping[idx];
push_up(i);
push_down(i);
}

int size()
{
return _size;
}

bool empty()
{
return _size == 0;
}

private:
vector<int> data; // keys
vector<int> indexes;
int _size;
vector<int> mapping; // id -> idx, -1: 未插入过,-2 已经被删

void dilatation()
{
int new_capacity = (int)data.size() * 2 + 1;
vector<int> tmp_data(new_capacity, -1);
vector<int> tmp_indexes(new_capacity, -1);
vector<int> tmp_mapping(new_capacity, -1);
for(int i = 0; i < (int)data.size(); ++i)
{
tmp_data[i] = data[i];
tmp_indexes[i] = indexes[i];
tmp_mapping[i] = mapping[i];
}
data.swap(tmp_data);
indexes.swap(tmp_indexes);
mapping.swap(tmp_mapping);
}

void _remove(int i)
{
if(i > _size)
return;
if(i == _size)
{
mapping[indexes[_size]] = -2;
--_size;
return;
}
int idx_i = indexes[i];
int idx_j = indexes[_size];
swap(mapping[idx_i], mapping[idx_j]);
indexes[i] = indexes[_size--];
mapping[idx_i] = -2;
push_up(i);
push_down(i);
}

void push_up(int i)
{
if(i > _size) return;
while(i / 2 > 0 && data[indexes[i / 2]] > data[indexes[i]])
{
swap(mapping[indexes[i]], mapping[indexes[i / 2]]);
swap(indexes[i], indexes[i / 2]);
i /= 2;
}
}

void push_down(int i)
{
int ori = i, left = i * 2, right = i * 2 + 1;
if(left <= _size && data[indexes[left]] < data[indexes[ori]])
ori = left;
if(right <= _size && data[indexes[right]] < data[indexes[ori]])
ori = right;
if(ori != i)
{
swap(mapping[indexes[i]], mapping[indexes[ori]]);
swap(indexes[i], indexes[ori]);
push_down(ori);
}
}

public:
int top_idx()
{
if(empty())
return -1;
return indexes[1] - 1;
}

int check(int idx)
{
++idx;
return mapping[idx];
}

int get(int idx)
{
++idx;
return data[idx];
}
};

class Solution {
public:
int minimumCost(int N, vector<vector<int>>& connections) {
vector<vector<vector<int> > > g(N + 1);
for(const vector<int> &connection: connections)
{
int x = connection[0], y = connection[1], w = connection[2];
g[x].push_back(vector<int>({w, y}));
g[y].push_back(vector<int>({w, x}));
}

IndexMinHeap heap(N + 1);
heap.push(1, 0);
vector<bool> visited(N + 1, false);
visited[0] = true;
int cost = 0;
while(!heap.empty())
{
int cur_idx = heap.top_idx();
int cur_cost = heap.top();
heap.pop();
if(visited[cur_idx])
continue;
visited[cur_idx] = true;
cost += cur_cost;
for(vector<int> son: g[cur_idx])
{
if(visited[son[1]])
continue;
if(heap.check(son[1]) < 0)
heap.push(son[1], son[0]);
else if(heap.get(son[1]) > son[0])
heap.change(son[1], son[0]);
}
}
for(bool v: visited)
if(!v)
return -1;
return cost;
}
};

Share