【树形DP】树的直径

  |  

摘要: 力扣 1245, 1522,树的直径,最经典的树形 DP 题目

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


各位好,我们继续研究力扣秋季赛的题目。在十一之前参加了力扣秋季赛个人赛,整体情况可以参考这篇文章 2022力扣秋季赛个人赛战报。 团队赛由于安排在了十一假期之间,所以没有参加。等个人赛的题目更新完之后考虑自己做一下看看,不过按照往年经验,除了第一题是中等之外,剩下的 5 道题应该都是 hard,只能有时间慢慢研究了。

个人赛的第 4 题是一道树形 DP 的题,比赛拿到题的时候,实际上直接就看出来了是树形 DP 的问题,只是想不出状态定义和状态转移具体是怎么样的。事后看了题解区,感觉这应该是我见过的状态最复杂的动态规划了。

本文我们先看一个树形 DP 的经典题:树的直径,体会一下树形 DP 的基本思想,本题也是一个一题多解的题,除了树形 DP 之外还有贪心的解法。


$1 题目

给你这棵「无向树」,请你测算并返回它的「直径」:这棵树上最长简单路径的 边数。

我们用一个由所有「边」组成的数组 edges 来表示一棵无向树,其中 edges[i] = [u, v] 表示节点 u 和 v 之间的双向边。

树上的节点都已经用 {0, 1, …, edges.length} 中的数做了标记,每个节点上的标记都是独一无二的。

提示:

1
2
3
4
0 <= edges.length < 10^4
edges[i][0] != edges[i][1]
0 <= edges[i][j] <= edges.length
edges 会形成一棵无向树

示例 1:
输入:edges = [[0,1],[0,2]]
输出:2
解释:
这棵树上最长的路径是 1 - 0 - 2,边数为 2。
示例 2:
输入:edges = [[0,1],[1,2],[2,3],[1,4],[4,5]]
输出:4
解释:
这棵树上最长的路径是 3 - 2 - 1 - 4 - 5,边数为 4。

$2 题解

算法1: 树形DP

在树形结构上求解问题,如果在某棵以 u 为根的树上的答案为 f(u),v 是 u 的子节点,以 v 为根的树,也就是 u 的子树上的答案为 f(v),如果 f(v) 构成 f(u) 的重复子问题的话,就可以用树形 DP 思路考虑了。

树形 DP 的状态转移方向就是从子节点到父节点,过程就是假设当前在节点 u,首先拿到各个以 u 的子节点 v 为根的子树的答案,然后进行整合,形成当前树的答案,然后继续向上传。

对于本题来说,树中的每个节点 u,我们都求一个经过 u 的最长链的长度,当树中的所有节点遍历完就可以得到整个树的最长链的长度了。

假设在遍历过程中,当前节点为 u,经过当前节点 u 的最长链可以被 u 分为两部分。

考虑这两部分与以 u 为根的子树的关系,有两种情况:

(1) 一部分为以 u 为根的子树从 u 走到子树的叶子的最长链,另一部分不在以 u 为根的子树上;
(2) 一部分为以 u 为根的子树从 u 走到子树的叶子的最长链,另一部分为以 u 为根的子树从 u 走到子树的叶子的次长链;

下面我们写出动态规划的状态定义和转移方程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
状态定义:
dp[u][0] := 以 u 为根且包含 u 的最长链长度
dp[u][1] := 以 u 为根且包含 u 的次长链长度

答案:
用 dp[u][0] + dp[u][1] 更新 ans

初始化:
dp[u][0] = 0 u 为叶子节点
dp[u][1] = 0 u 为叶子节点

状态转移:
dp[u][0] = max(dp[v][0]) + 1 v 为 u 的子节点
dp[u][1] = second_max(dp[v][0]) + 1 v 为 u 的子节点

代码 (C++)

注意到在节点 u 时,状态转移过程只需要 dp[v][0],而不需要 dp[v][1],因此 dfs 仅返回 dp[v][0] 即可。

代码中 max1 表示 dp[u][0], max2 表示 dp[u][1]。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
int treeDiameter(vector<vector<int>>& edges) {
int n = edges.size();
vector<vector<int> > g(n + 1);
for(const auto &e: edges)
{
g[e[0]].push_back(e[1]);
g[e[1]].push_back(e[0]);
}
int ans = 0;
dfs(0, -1, g, ans);
return ans;
}

private:
int dfs(int u, int fa, const vector<vector<int> >& g, int& ans)
{
int max1 = 0, max2 = 0;
for(int v: g[u])
{
if(v != fa)
{
int t = dfs(v, u, g, ans) + 1;
if(max1 < t)
{
max2 = max1;
max1 = t;
}
else if(max2 < t)
max2 = t;
}
}
ans = max(ans, (max1 + max2));
return max1;
}
};

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:
def treeDiameter(self, edges: List[List[int]]) -> int:
n = len(edges)
print(n)
self.g = [[] for i in range(n + 1)]
for e in edges:
self.g[e[0]].append(e[1])
self.g[e[1]].append(e[0])
self.ans = 0
self.dfs(0, -1)
return self.ans

def dfs(self, u, fa):
max1 = max2 = 0
for v in self.g[u]:
if v != fa:
t = self.dfs(v, u) + 1
if max1 < t:
max2 = max1
max1 = t
elif max2 < t:
max2 = t
self.ans = max(self.ans, (max1 + max2))
return max1

代码(Java)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution {
private int ans;

public int treeDiameter(int[][] edges) {
int n = edges.length;
ArrayList<Integer>[] g = new ArrayList[n + 1];
for(int i = 0; i <= n; i++) {
g[i] = new ArrayList<>();
}
for(int[] edge: edges) {
g[edge[0]].add(edge[1]);
g[edge[1]].add(edge[0]);
}
ans = 0;
dfs(0, -1, g);
return ans;
}
private int dfs(int u, int fa, ArrayList<Integer>[] g) {
int max1 = 0;
int max2 = 0;
for(int v: g[u]) {
if(v != fa) {
int t = dfs(v, u, g) + 1;
if(max1 < t) {
max2 = max1;
max1 = t;
} else if (max2 < t){
max2 = t;
}
}
}
ans = Math.max(ans, (max1 + max2));
return max1;
}
}

算法2: 贪心 (两次搜索)

贪心也是本题的一种做法,这里简要看一下,算法如下。

以任意一个点 u 为起点,BFS 或 DFS 找到距离 u 最远的点 x,然后再以 x 为起点,找到距离 x 最远的点 y。xy 就是树的直径。

代码(java)

这里用 dfs 实现,两次 dfs 可以用一个函数,每次返回最远的点以及最远的点 x 对应的距离 max_len。第一次调用返回的最远的点 x 作为第二次调用的起点,第二次调用返回的最远的距离 max_len 为结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Solution {
private int x; // 最远的点 x
private int max_len; // 最长距离

public int treeDiameter(int[][] edges) {
int n = edges.length;
ArrayList<Integer>[] g = new ArrayList[n + 1];
for(int i = 0; i <= n; i++) {
g[i] = new ArrayList<>();
}
for(int[] edge: edges) {
g[edge[0]].add(edge[1]);
g[edge[1]].add(edge[0]);
}
max_len = 0;
x = 0;
dfs(0, -1, 0, g);
dfs(x, -1, 0, g);
return max_len;
}
private void dfs(int u, int fa, int d, ArrayList<Integer>[] g) {
if(d > max_len) {
max_len = d;
x = u;
}
for(int v: g[u]) {
if(v != fa) {
dfs(v, u, d + 1, g);
}
}
}
}

一个类似的题目

给定一棵 N 叉树的根节点 root ,计算这棵树的直径长度。

N 叉树的直径指的是树中任意两个节点间路径中 最长 路径的长度。这条路径可能经过根节点,也可能不经过根节点。

(N 叉树的输入序列以层序遍历的形式给出,每组子节点用 null 分隔)

提示:

1
2
N 叉树的深度小于或等于 1000 。
节点的总个数在 [0, 10^4] 间。

算法: 树形DP

上面的是以邻接表的形式给出的无根树上求直径。如果是以指针形式给出的有根树上求树的直径,算法还是树形 DP,只是在实现上以后序遍历的方式来完成状态的计算和转移。

贪心也是可以的,但是需要先遍历一遍建立邻接表,然后在完成贪心的过程,稍微麻烦一步。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Solution {
public:
int diameter(Node* root) {
int ans = 0;
_postOrder(root, ans);
return ans;
}

private:
int _postOrder(Node *node, int& ans)
{
int h1 = -1, max2 = -1;
for(Node *child: node -> children)
{
int h = _postOrder(child, ans);
if(max1 == -1 || h > max1)
{
max2 = max1;
max1 = h;
}
else if(max2 == -1 || h > max2)
{
max2 = h;
}
}
if(max1 == -1)
return 0;
else if(max2 == -1)
{
ans = max(ans, max1 + 1);
return max1 + 1;
}
else
{
ans = max(ans, max1 + max2 + 2);
return max1 + 1;
}
}
};

Share