下标索引堆优化Prim

  |  

$0 1135. 最低成本联通所有城市

$1 邻接表 Prim

1
2
3
4
5
6
将 0 加入集合 T
while 尚有点没有进入集合 T
从剩下的点中,选出最短的边 `(u, v, w)`
其中 u 属于 T,v 不属于 T,即 uv 是横切边
cost += w
v 加入 T
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
int minimumCost(int N, vector<vector<int>>& connections) {
vector<vector<vector<int> > > g(N + 1);
for(const vector<int> &connection: connections)
{
int x = connection[0], y = connection[1], w = connection[2];
g[x].push_back(vector<int>({w, y}));
g[y].push_back(vector<int>({w, x}));
}

priority_queue<vector<int> > pq;
pq.push({0, 1});
vector<int> visited(N + 1, false);
visited[0] = true;
int cost = 0;
while(!pq.empty())
{
vector<int> cur = pq.top();
pq.pop();
if(visited[cur[1]])
continue;
visited[cur[1]] = true;
cost += -cur[0];
for(vector<int> son: g[cur[1]])
{
if(visited[son[1]])
continue;
pq.push({-son[0], son[1]});
}
}
for(bool v: visited)
if(!v)
return -1;
return cost;
}
};

$2 下标索引堆优化 Prim

1
2
3
4
5
6
将 0 加入集合 T
while 尚有点没有进入集合 T
从剩下的点中,选出最短的边 `(u, v, w)`
其中 u 属于 T,v 不属于 T,即 uv 是横切边
cost += w
v 加入 T

点 u 是否在 T 中,依然用 visited 维护。查询最短的边对应的 v, w 之前是用堆维护的,限制用下标索引堆来代替。其中 v 相当于 idx,w 相当于 data[idx]。

(1) 对下标索引堆模板的修改

下标索引堆使用模板,其中根据需求增加以下几个接口

1
2
3
top_idx() := 当前最小值所在的下标
get(idx) := 查询下标处 data 的值
check(idx) := 查询下标处在堆(indexes) 的位置,-1 表示未压过堆,-2 表示之前压过现在已经弹出
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int IndexMinHeap::top_idx()
{
if(empty())
return -1;
return indexes[1] - 1;
}

int IndexMinHeap::check(int idx)
{
++idx;
return mapping[idx];
}

int IndexMinHeap::get(int idx)
{
++idx;
return data[idx];
}

数据量 N 提前可以知道,因此构造时候可以之间把所需空间开出,但堆的 size 仍为 0。修改构造函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
IndexMinHeap(int n=0, int origin_val=INT_MAX/2)
{
_size = 0;
data.assign(n + 1, -1);
indexes.assign(n + 1, -1);
mapping.assign(n + 1, -1);
for(int i = 1; i <= n; ++i)
{
data[i] = origin_val;
indexes[i] = -1;
mapping[i] = -1;
}
}

除以上修改,下标索引堆其余部分复制模板

(2) 代码:下标索引堆优化邻接表 Prim

将二叉堆改为下标索引堆后,堆中最多的边数从 $O(E)$ 变为 $O(V)$,整体时间复杂度变为 $O(E\log V)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution {
public:
int minimumCost(int N, vector<vector<int>>& connections) {
vector<vector<vector<int> > > g(N + 1);
for(const vector<int> &connection: connections)
{
int x = connection[0], y = connection[1], w = connection[2];
g[x].push_back(vector<int>({w, y}));
g[y].push_back(vector<int>({w, x}));
}

IndexMinHeap heap(N + 1);
heap.push(1, 0);
vector<bool> visited(N + 1, false);
visited[0] = true;
int cost = 0;
while(!heap.empty())
{
int cur_idx = heap.top_idx();
int cur_cost = heap.top();
heap.pop();
if(visited[cur_idx])
continue;
visited[cur_idx] = true;
cost += cur_cost;
for(vector<int> son: g[cur_idx])
{
if(visited[son[1]])
continue;
if(heap.check(son[1]) < 0)
heap.push(son[1], son[0]);
else if(heap.get(son[1]) > son[0])
heap.change(son[1], son[0]);
}
}
for(bool v: visited)
if(!v)
return -1;
return cost;
}
};

(3) 完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class IndexMinHeap
{
public:
IndexMinHeap(int n=0, int origin_val=INT_MAX/2)
{
_size = 0;
data.assign(n + 1, -1);
indexes.assign(n + 1, -1);
mapping.assign(n + 1, -1);
for(int i = 1; i <= n; ++i)
{
data[i] = origin_val;
indexes[i] = -1;
mapping[i] = -1;
}
}

void build(const vector<int>& nums)
{
data.clear();
indexes.clear();
mapping.clear();
int n = nums.size();
data.assign(n + 1, -1);
indexes.assign(n + 1, -1);
mapping.assign(n + 1, -1);
for(int i = 1; i <= n; ++i)
{
data[i] = nums[i - 1];
indexes[i] = i;
mapping[i] = i;
}
_size = n;
for(int i = n; i >= 1; --i)
push_down(i);
}

int top()
{
if(empty())
return -1;
return data[indexes[1]];
}

int pop()
{
if(empty())
return -1;
int ans = data[indexes[1]];
_remove(1);
return ans;
}

void push(int idx, const int key)
{
++idx;
// 覆盖 data[idx] 的数据, idx 同时也是 key 的索引
// mapping[idx] := key 在堆中的逻辑位置,即 indexes 中的位置
if(idx >= (int)data.size())
dilatation();
++_size;
data[idx] = key;
mapping[idx] = _size;
indexes[_size] = idx;
push_up(_size);
}

void remove(int idx)
{
++idx;
if(mapping[idx] < 0)
return;
int i = mapping[idx];
_remove(i);
}

void change(int idx, const int new_key)
{
++idx;
if(mapping[idx] < 0)
return;
if(data[idx] == new_key)
return;
data[idx] = new_key;
int i = mapping[idx];
push_up(i);
push_down(i);
}

int size()
{
return _size;
}

bool empty()
{
return _size == 0;
}

private:
vector<int> data; // keys
vector<int> indexes;
int _size;
vector<int> mapping; // id -> idx, -1: 未插入过,-2 已经被删

void dilatation()
{
int new_capacity = (int)data.size() * 2 + 1;
vector<int> tmp_data(new_capacity, -1);
vector<int> tmp_indexes(new_capacity, -1);
vector<int> tmp_mapping(new_capacity, -1);
for(int i = 0; i < (int)data.size(); ++i)
{
tmp_data[i] = data[i];
tmp_indexes[i] = indexes[i];
tmp_mapping[i] = mapping[i];
}
data.swap(tmp_data);
indexes.swap(tmp_indexes);
mapping.swap(tmp_mapping);
}

void _remove(int i)
{
if(i > _size)
return;
if(i == _size)
{
mapping[indexes[_size]] = -2;
--_size;
return;
}
int idx_i = indexes[i];
int idx_j = indexes[_size];
swap(mapping[idx_i], mapping[idx_j]);
indexes[i] = indexes[_size--];
mapping[idx_i] = -2;
push_up(i);
push_down(i);
}

void push_up(int i)
{
if(i > _size) return;
while(i / 2 > 0 && data[indexes[i / 2]] > data[indexes[i]])
{
swap(mapping[indexes[i]], mapping[indexes[i / 2]]);
swap(indexes[i], indexes[i / 2]);
i /= 2;
}
}

void push_down(int i)
{
int ori = i, left = i * 2, right = i * 2 + 1;
if(left <= _size && data[indexes[left]] < data[indexes[ori]])
ori = left;
if(right <= _size && data[indexes[right]] < data[indexes[ori]])
ori = right;
if(ori != i)
{
swap(mapping[indexes[i]], mapping[indexes[ori]]);
swap(indexes[i], indexes[ori]);
push_down(ori);
}
}

public:
int top_idx()
{
if(empty())
return -1;
return indexes[1] - 1;
}

int check(int idx)
{
++idx;
return mapping[idx];
}

int get(int idx)
{
++idx;
return data[idx];
}
};

class Solution {
public:
int minimumCost(int N, vector<vector<int>>& connections) {
vector<vector<vector<int> > > g(N + 1);
for(const vector<int> &connection: connections)
{
int x = connection[0], y = connection[1], w = connection[2];
g[x].push_back(vector<int>({w, y}));
g[y].push_back(vector<int>({w, x}));
}

IndexMinHeap heap(N + 1);
heap.push(1, 0);
vector<bool> visited(N + 1, false);
visited[0] = true;
int cost = 0;
while(!heap.empty())
{
int cur_idx = heap.top_idx();
int cur_cost = heap.top();
heap.pop();
if(visited[cur_idx])
continue;
visited[cur_idx] = true;
cost += cur_cost;
for(vector<int> son: g[cur_idx])
{
if(visited[son[1]])
continue;
if(heap.check(son[1]) < 0)
heap.push(son[1], son[0]);
else if(heap.get(son[1]) > son[0])
heap.change(son[1], son[0]);
}
}
for(bool v: visited)
if(!v)
return -1;
return cost;
}
};

Share