回溯法的搜索树规模的上界估计

  |  

摘要: 暴力算法怎样实现评估可不可行

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


此前我们系统了解过回溯法,并解决了很多问题,参考文章 回溯法的思想、设计与分析。回溯法其实就是一种暴力方法,因此在实际问题中使用回溯法一般要依赖剪枝。能不能直接用回溯法,就取决于对回溯法时间复杂度的分析。

要分析回溯法的时间复杂度,主要是分析状态空间树的形态及其规模。形态方面,在文章 回溯法三种常见的状态空间树:子集树、排列树、满m叉树 中我们了解到回溯法的状态空间树往往是排列树、子集树、满m叉树等。规模方面,一般要分析一下最坏情况下树会有多深,每个节点最多会有多少个子节点等等。

如果经过分析,发现状态空间树的规模可以推出一个上界,而基于这个上界的运行时间是满足要求的,就可以直接用回溯法了。本文我们看一个例子。

题目

给你一张 无向 图,图中有 n 个节点,节点编号从 0 到 n - 1 (都包括)。同时给你一个下标从 0 开始的整数数组 values ,其中 values[i] 是第 i 个节点的 价值 。同时给你一个下标从 0 开始的二维整数数组 edges ,其中 edges[j] = [uj, vj, timej] 表示节点 uj 和 vj 之间有一条需要 timej 秒才能通过的无向边。最后,给你一个整数 maxTime

合法路径 指的是图中任意一条从节点 0 开始,最终回到节点 0 ,且花费的总时间 不超过 maxTime 秒的一条路径。你可以访问一个节点任意次。一条合法路径的 价值 定义为路径中 不同节点 的价值 之和 (每个节点的价值 至多 算入价值总和中一次)。

请你返回一条合法路径的 最大 价值。

注意:每个节点 至多 有 四条 边与之相连。

提示:

1
2
3
4
5
6
7
8
9
10
n == values.length
1 <= n <= 1000
0 <= values[i] <= 1e8
0 <= edges.length <= 2000
edges[j].length == 3
0 <= uj < vj <= n - 1
10 <= timej, maxTime <= 100
[uj, vj] 所有节点对 互不相同 。
每个节点 至多有四条 边。
图可能不连通。

示例 1:

输入:values = [0,32,10,43], edges = [[0,1,10],[1,2,15],[0,3,10]], maxTime = 49
输出:75
解释:
一条可能的路径为:0 -> 1 -> 0 -> 3 -> 0 。总花费时间为 10 + 10 + 10 + 10 = 40 <= 49 。
访问过的节点为 0 ,1 和 3 ,最大路径价值为 0 + 32 + 43 = 75 。

示例 2:

输入:values = [5,10,15,20], edges = [[0,1,10],[1,2,10],[0,3,10]], maxTime = 30
输出:25
解释:
一条可能的路径为:0 -> 3 -> 0 。总花费时间为 10 + 10 = 20 <= 30 。
访问过的节点为 0 和 3 ,最大路径价值为 5 + 20 = 25 。

示例 3:

输入:values = [1,2,3,4], edges = [[0,1,10],[1,2,11],[2,3,12],[1,3,13]], maxTime = 50
输出:7
解释:
一条可能的路径为:0 -> 1 -> 3 -> 1 -> 0 。总花费时间为 10 + 13 + 13 + 10 = 46 <= 50 。
访问过的节点为 0 ,1 和 3 ,最大路径价值为 1 + 2 + 4 = 7 。

示例 4:

输入:values = [0,1,2], edges = [[1,2,10]], maxTime = 10
输出:0
解释:
唯一一条路径为 0 。总花费时间为 0 。
唯一访问过的节点为 0 ,最大路径价值为 0 。

题解

算法:回溯法

从 0 号顶点出发,由于顶点可以重复到达多次,因此每一步可以沿着任意一条边走,走到剩余时间耗尽时回溯,过程中每次经过起点即完成一条合法路径。

如果用回溯法的话,第一感觉是时间复杂度过高,不可行,因为图的顶点数是 $N \leq 1000$,很大。

但是提示中有两条最关键的数据限制:

第一个是 10 <= timej, maxTime <= 100,这样的话假设可用的总时间为 100 而每条边的长度均为 10,那么最多走 10 步就会把剩余时间耗尽。也就是说回溯法的状态空间树最多有 11 层。

第二个是每个顶点至多有四条边,也就是说回溯法的状态空间树中,每个节点至多有 4 个子节点。这样最坏情况下第一层 1 个节点,第二层 4 个节点,直至第 11 层 $4^{10}$ 个节点。

有了以上两条,回溯法的状态空间树中节点数目的上界是 $4^{0} + 4^{1} + 4^{2} + \cdots + 4^{10} = \frac{4}{3}(4^{10} - 1) = 1398100$,是 $1e6$ 级别,满足时间要求。

因此可以直接通过回溯法暴力搜索所有可能的路径,从编号为 0 的顶点出发,沿着所有可能的边,按 DFS 去走,当剩余时间耗尽时回溯。每次重新回到编号为 0 的顶点,都是一个合法路径,用此时的路径价值更新答案即可。

过程中我们要记录历史路径中经过的顶点的总价值,用变量 path_val,由于每个顶点只能计算一次价值,因此还需要记录顶点的访问次数,用哈希映射 mapping 即可。此外还需要一个变量 t 记录剩余的时间。

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution:
def maximalPathQuality(self, values: List[int], edges: List[List[int]], maxTime: int) -> int:
n = len(values)
g = [[] for _ in range(n)]
for u, v, w in edges:
g[u].append((v, w))
g[v].append((u, w))

def dfs(u: int, t: int, path_val: int) -> None:
nonlocal ans

if mapping[u] == 0:
path_val += values[u]
mapping[u] += 1

if u == 0:
ans = max(ans, path_val)

for (v, w) in g[u]:
if w > t:
continue
dfs(v, t - w, path_val)

mapping[u] -= 1
if mapping[u] == 0:
path_val -= values[u]

mapping = [0 for _ in range(n)]
ans = 0
path_val = 0

dfs(0, maxTime, path_val)

return ans

剪枝:使用最短路径信息

在前面的回溯法中,只要剩余时间没有耗尽,并且下一条边的时间不大于剩余时间,那么就沿着下一条边继续走。

但是由于合法路径必须以编号为 0 的点为终点,因此在回溯法的过程中,在当前这一步在节点 $u$,如果剩余时间 $t$ 已经无法支付从 $u$ 到 $0$ 的最短路径,那么实际上可以直接回溯了,即使下一条边 $uv$ 的时间比剩余时间 $t$ 还小。因为继续走下去不可能走到合法路径了。

综上,我们可以用最短路径的信息进行剪枝。首先用 Dijkstra 算法与处理处每个顶点 $u$ 到 $0$ 的最短路径 d[u],然后在 DFS 过程中,每一步都先做一步判断,看剩余时间是否满足 t >= d[u],若不满足则直接回溯。

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Solution:
def maximalPathQuality(self, values: List[int], edges: List[List[int]], maxTime: int) -> int:
INF = int(1e9)
n = len(values)
g = [[] for _ in range(n)]
for u, v, w in edges:
g[u].append((v, w))
g[v].append((u, w))

def dijkstra(s: int) -> List[int]:
d = [INF] * n
d[s] = 0
heap_data = []
heapq.heappush(heap_data, (0, s))

while heap_data:
min_d, u = heapq.heappop(heap_data)
if d[u] < min_d:
continue
for v, w in g[u]:
if d[v] <= d[u] + w:
continue
d[v] = d[u] + w
heapq.heappush(heap_data, (d[v], v))

return d

def dfs(u: int, t: int, path_val: int) -> None:
if t < d[u]:
return

nonlocal ans

if mapping[u] == 0:
path_val += values[u]
mapping[u] += 1

if u == 0:
ans = max(ans, path_val)

for (v, w) in g[u]:
if w > t:
continue
dfs(v, t - w, path_val)

mapping[u] -= 1
if mapping[u] == 0:
path_val -= values[u]

mapping = [0 for _ in range(n)]
ans = 0
path_val = 0
d = dijkstra(0)
dfs(0, maxTime, path_val)

return ans

Share