k短路问题:AStar启发式搜索

  |  

摘要: AStar 算法经典问题:k短路

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


给定一个有向图,求 S 到 T 的最短路径长度是图算法中的基本问题,算法很多,参考文章 带权图最短路径算法与实现

如果要求第 K 短的路,就不那么简单了,通过 AStar 启发式搜索是一种解决方法。本题可以作为 AStar 启发式搜索的经典例子。

K 短路问题

给定一张N个点(编号1,2…N),M条边的有向图,求从起点S到终点T的第K短路的长度,路径允许重复经过点或边。

注意: 每条最短路中至少要包含一条边。

1
2
3
4
5
6
7
8
9
数据格式:
第一行包含两个整数N和M。
接下来M行,每行包含三个整数A,B和L,表示点A与点B之间存在有向边,且边长为L。
最后一行包含三个整数S,T和K,分别表示起点S,终点T和第K短路。

1 <= S,T <= N <= 1000
0 <= M <= 1e5
1 <= K <= 1000
1 <= L <= 100

算法:Dijkstra

源点 S 到所有点的单源最短路可以求出来,$d[u]$ 表示 $u$ 到 $s$ 的最短路径。这是图上的基本算法,参考以下文章:

这里是没有负权的单源最短路问题,适用 Dijkstra 算法,考虑 Dijkstra 算法的过程,优先级队列中的每个元素记录节点编号 $u$ 以及到源点的距离 $d$,优先级队列中状态按照 $d$ 从小到大排序。

当 $u$ 第一次出队的时候,弹出的信息中记录的到源点 $s$ 的距离即为 $u$ 到 $s$ 的最短路径,记为 $d[u]$。

如果优先级队列用的是普通的二叉堆,则当 $u$ 第一次出队后,队列中可能还有若干个 $u$ 节点,只是其对应的到源点 $s$ 的距离大于等于 $d[u]$。正因为队列中节点 $u$ 可能会有重复压入的情况($s$ 到 $u$ 有多条路径),堆中元素的规模就是 $O(E)$($E$ 为边数),而不是 $O(V)$,于是 Dijkstra 算法时间复杂度为 $O(E\log E)$。

此前我们也通过把普通二叉堆改为索引堆,将堆中元素个数从 $O(E)$ 降为 $O(V)$,总时间复杂度变为 $O(E\log V)$,参考文章:

如果我们保留堆中相同顶点 $u$ 对应的每个元素,这样 $u$ 弹出的顺序就是 $s$ 到 $u$ 的路径长度从小到大的顺序。那么当 $u$ 第 $k$ 次弹出的时候,就对应 $u$ 到 $s$ 的第 $k$ 短路径。

因此我们的基本算法就是在 Dijkstra 的基础上,取 $T$ 第 $k$ 次出队时对应的路径长度即可。后面我们会看到,直接跑这个算法会超时,在此基础上需要增加一些 AStar 启发式搜索。我们先写基本的 Dijkstra 算法,然后再增加启发式搜索 AStar 的策略。

状态设计

当节点 T 第 K 次出队的时候,得到 $S \rightarrow T$ 的 K 短路。

1
2
3
4
5
6
7
8
9
10
11
12
13
struct State
{
int u, d;
State(int u, int d):u(u),d(d){}
};

struct Cmp
{
bool operator()(const State& s1, const State& s2) const
{
return s1.d > s2.d;
}
};

一些细节

  • $S \rightarrow T$ 如果不可达,则直接不求解 k 短路就可以返回 -1 了。
  • 优先级队列迭代过程中,需要能感知到队列中尚存在可以到达 T 的点。如果队列仍在迭代但是已经没有可以到达 T 的点了,则 $S \rightarrow T$ 的下一条路不会来了,返回 -1。
  • 当扩展到的节点 y 已经被取出 K 次的时候,就没有必要再压进队了

以上三个问题可以一起处理:建图时候同时建立不带权的反图 rg,然后以 T 为源点,用哈希表记录所有 T 可以到达的点。

1
2
3
4
5
6
7
8
9
10
void check(const vector<vector<int>>& rg, const int T, unordered_set<int>& setting)
{
for(const int &son: rg[T])
{
if(setting.count(son))
continue;
setting.insert(son);
check(rg, son, setting);
}
}

对于第一个问题,如果 S 不在 setting 中,则 $S \rightarrow T$ 不可达:

1
2
3
4
5
if(!setting.count(S))
{
cout << "-1" << endl;
return 0;
}

对于第二个问题和第三个问题,可以在节点压队之前同时处理:

  • 队列迭代时维护一个 cnts[i] 表示 i 被取出的次数。
  • 对于待入队的节点,只有属于可达 T 的节点集合,同时从队列中弹出的次数小于 K 的时候才压队。
1
2
3
4
5
for(const To &son: g[cur.u])
{
if(setting.count(son.v) && cnts[son.v] > 0)
pq.push(State(son.v, cur.d + son.w));
}
  • S = T 的情况,也是需要先出去再进来的过程的才能算是有路的(可以是自环)。
1
2
if(S == T)
++cnts[S];

代码 (C++)

注:超时。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <vector>
#include <iostream>
#include <fstream>
#include <queue>
#include <climits>
#include <unordered_set>

using namespace std;

struct State
{
int u, d;
State(int u, int d):u(u),d(d){}
};

struct Cmp
{
bool operator()(const State& s1, const State& s2) const
{
return s1.d > s2.d;
}
};

struct To
{
int v, w;
To(int v, int w):v(v),w(w){}
};

void check(const vector<vector<int>>& rg, const int T, unordered_set<int>& setting)
{
for(const int &son: rg[T])
{
if(setting.count(son))
continue;
setting.insert(son);
check(rg, son, setting);
}
}

int solve(const vector<vector<To>>& g, int S, int T, int K, const unordered_set<int>& setting)
{
priority_queue<State, vector<State>, Cmp> pq;
pq.push(State(S, 0));
int N = g.size() - 1;
vector<int> cnts(N + 1, K);
if(S == T)
++cnts[S];
while(!pq.empty())
{
State cur = pq.top();
pq.pop();
--cnts[cur.u];
if(cnts[T] == 0)
return cur.d;
for(const To &son: g[cur.u])
{
if(setting.count(son.v) && cnts[son.v] > 0)
pq.push(State(son.v, cur.d + son.w));
}
}
return -1;
}

int main()
{
int N, M;
cin >> N >> M;
vector<vector<To>> g(N + 1);
vector<vector<int>> rg(N + 1); // 用于求所有可以到达 T 的上游节点
for(int i = 1; i <= M; ++i)
{
int a, b, l;
cin >> a >> b >> l;
g[a].push_back(To(b, l));
rg[b].push_back(a);
}
int S, T, K;
cin >> S >> T >> K;

// 记录反图中 T 可以到达的所有节点集合 setting
unordered_set<int> setting;
check(rg, T, setting);
if(!setting.count(S))
{
cout << "-1" << endl;
return 0;
}
int ans = solve(g, S, T, K, setting);
if(ans == -1)
cout << "-1" << endl;
else
cout << ans << endl;
}

优化:Astar

关于 AStar 算法的思想与实现,可以参考文章 Astar 算法的原理:优先级队列BFS+估价函数

这里直接走流程:在优先级队列 BFS (Dijkstra) 基础上增加估价函数、并修改相应的状态设计。

估价函数

在反图 rg 上将终点 T 到各个点的最短距离预处理出来,记为 rd[u] 表示 u 到 T 的最短路径。

则估价函数可以定义如下:

1
2
3
4
int h(const int u)
{
return rd[u];
}

状态设计

增加 h 字段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
struct State
{
int u, d;
int h;
State(int u, int d, int h=0):u(u),d(d),h(h){}
};

struct Cmp
{
bool operator()(const State& s1, const State& s2) const
{
return s1.d + s1.h > s2.d + s2.h;
}
};

预处理 rd[u]

由于计算估价函数需要 rd[u] 信息,需要先预处理出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
vector<int> dijkstra(const vector<vector<To>>& rg, const int T)
{
priority_queue<State, vector<State>, Cmp> pq;
pq.push(State(T, 0));
int N = rg.size() - 1;
vector<int> d(N + 1, INT_MAX / 2);
while(!pq.empty())
{
State cur = pq.top();
pq.pop();
if(d[cur.u] < cur.d)
continue;
for(const To &son: rg[cur.u])
{
if(cur.d + son.w >= d[son.v])
continue;
d[son.v] = cur.d + son.w;
pq.push(State(son.v, cur.d + son.w));
}
}
return d;
}

本题顶点是 1000 级别,边是 100000 级别,因此预处理 rd[u] 应该用数组实现 dijkstra

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
vector<int> dijkstra_array(const vector<vector<To>>& rg, const int T)
{
int N = rg.size() - 1;
vector<int> d(N + 1, INT_MAX / 2);
for(const To &son: rg[T])
d[son.v] = son.w;
vector<bool> got(N + 1, false);
for(int cnt = 1; cnt <= N - 1; ++cnt)
{
int minx = INT_MAX / 2;
int u = -1;
for(int i = 1; i <= N; ++i)
{
if(!got[i] && d[i] < minx)
{
minx = d[i];
u = i;
}
}
if(u == -1)
return d;
got[u] = true;
for(const To &son: rg[u])
{
d[son.v] = min(d[son.v], d[u] + son.w);
}
}
return d;
}

可达 T 的节点集合

预处理出 rd[u] 之后,所有 rd[u] != INT_MAX / 2 的都是,判断节点是否属于可达 T 的集合时,直接用预处理出的 rd 就可以了。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <vector>
#include <iostream>
#include <fstream>
#include <queue>
#include <climits>
#include <unordered_set>

using namespace std;

struct State
{
int u, d;
int h;
State(int u, int d, int h=0):u(u),d(d),h(h){}
};

struct Cmp
{
bool operator()(const State& s1, const State& s2) const
{
return s1.d + s1.h > s2.d + s2.h;
}
};

struct To
{
int v, w;
To(int v, int w):v(v),w(w){}
};

vector<int> dijkstra(const vector<vector<To>>& rg, const int T)
{
priority_queue<State, vector<State>, Cmp> pq;
pq.push(State(T, 0));
int N = rg.size() - 1;
vector<int> d(N + 1, INT_MAX / 2);
while(!pq.empty())
{
State cur = pq.top();
pq.pop();
if(d[cur.u] < cur.d)
continue;
for(const To &son: rg[cur.u])
{
if(cur.d + son.w >= d[son.v])
continue;
d[son.v] = cur.d + son.w;
pq.push(State(son.v, cur.d + son.w));
}
}
return d;
}

vector<int> dijkstra_array(const vector<vector<To>>& rg, const int T)
{
int N = rg.size() - 1;
vector<int> d(N + 1, INT_MAX / 2);
for(const To &son: rg[T])
d[son.v] = son.w;
vector<bool> got(N + 1, false);
for(int cnt = 1; cnt <= N - 1; ++cnt)
{
int minx = INT_MAX / 2;
int u = -1;
for(int i = 1; i <= N; ++i)
{
if(!got[i] && d[i] < minx)
{
minx = d[i];
u = i;
}
}
if(u == -1)
return d;
got[u] = true;
for(const To &son: rg[u])
{
d[son.v] = min(d[son.v], d[u] + son.w);
}
}
return d;
}

int h(const int s, const vector<int>& rd)
{
return rd[s];
}

int solve(const vector<vector<To>>& g, int S, int T, int K, const vector<int>& rd)
{
priority_queue<State, vector<State>, Cmp> pq;
pq.push(State(S, 0, h(S, rd)));
int N = g.size() - 1;
vector<int> cnts(N + 1, K);
if(S == T)
++cnts[T];
while(!pq.empty())
{
State cur = pq.top();
pq.pop();
--cnts[cur.u];
if(cnts[T] == 0)
return cur.d;
for(const To &son: g[cur.u])
{
if(rd[son.v] < INT_MAX / 2 && cnts[son.v] > 0)
pq.push(State(son.v, cur.d + son.w, h(son.v, rd)));
}
}
return -1;
}

int main()
{
int N, M;
cin >> N >> M;
vector<vector<To>> g(N + 1);
vector<vector<To>> rg(N + 1); // 用于求所有可以到达 T 的上游节点
for(int i = 1; i <= M; ++i)
{
int a, b, l;
cin >> a >> b >> l;
g[a].push_back(To(b, l));
rg[b].push_back(To(a, l));
}
int S, T, K;
cin >> S >> T >> K;

vector<int> rd = dijkstra_array(rg, T);
// vector<int> rd = dijkstra(rg, T);
if(rd[T] == INT_MAX / 2)
{
cout << "-1" << endl;
return 0;
}
int ans = solve(g, S, T, K, rd);
if(ans == -1)
cout << "-1" << endl;
else
cout << ans << endl;
}

Share