【DP难题】力扣1617-统计子树中城市之间最大距离

  |  

摘要: 力扣 1617,比较难的树形 DP

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


今天我们来看一个比较难的图论题,力扣 1617,也就是第 210 周赛 D 题。抽象之后是求树的直径的问题,树的直径问题在文章 【树形DP】树的直径 中研究过,本题在树的直径的基础上增加了一些复杂性,可以理解为要求所有子树的直径,直接把以下三个基础问题的解法串到一起即可解决:

  • 枚举子集
  • 判断连通性
  • 树形 DP 求树的直径

本题也可以直接用树形 DP 解决,但是状态的定义和状态转移过程都很复杂。此外本题还有组合数学的解法,当然要解决组合计数的难点:如何不重不漏地把所有可能得方案枚举到。


$1 题目

给你 n 个城市,编号为从 1 到 n 。同时给你一个大小为 n-1 的数组 edges ,其中 edges[i] = [ui, vi] 表示城市 ui 和 vi 之间有一条双向边。题目保证任意城市之间只有唯一的一条路径。换句话说,所有城市形成了一棵 树 。

一棵 子树 是城市的一个子集,且子集中任意城市之间可以通过子集中的其他城市和边到达。两个子树被认为不一样的条件是至少有一个城市在其中一棵子树中存在,但在另一棵子树中不存在。

对于 d 从 1 到 n-1 ,请你找到城市间 最大距离 恰好为 d 的所有子树数目。

请你返回一个大小为 n-1 的数组,其中第 d 个元素(下标从 1 开始)是城市间 最大距离 恰好等于 d 的子树数目。

请注意,两个城市间距离定义为它们之间需要经过的边的数目。

提示:

1
2
3
4
5
2 <= n <= 15
edges.length == n-1
edges[i].length == 2
1 <= ui, vi <= n
题目保证 (ui, vi) 所表示的边互不相同。

示例 1:
输入:n = 4, edges = [[1,2],[2,3],[2,4]]
输出:[3,4,0]
解释:
子树 {1,2}, {2,3} 和 {2,4} 最大距离都是 1 。
子树 {1,2,3}, {1,2,4}, {2,3,4} 和 {1,2,3,4} 最大距离都为 2 。
不存在城市间最大距离为 3 的子树。

示例 2:
输入:n = 2, edges = [[1,2]]
输出:[1]

示例 3:
输入:n = 3, edges = [[1,2],[2,3]]
输出:[2,1]

$2 题解

算法1: 枚举子集+判断连通性+求树的直径

枚举所有可能的子树,具体地就是先枚举顶点集的所有子集,然后判断点集是否连通,如果连通的话则该子集构成一棵子树,在此基础上我们在求该子树的直径。

代码 (C++)

$O(2^{n}n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class TreeDiameter {
public:
int solve(const vector<vector<int>>& g, int s) {
int ans = 0;
dfs(s, -1, g, ans);
return ans;
}

private:
int dfs(int x, int fa, const vector<vector<int> >& g, int& ans)
{
int max1 = 0, max2 = 0;
for(int v: g[x])
{
if(v != fa)
{
int t = dfs(v, x, g, ans) + 1;
if(max1 < t)
{
max2 = max1;
max1 = t;
}
else if(max2 < t)
max2 = t;
}
}
ans = max(ans, (max1 + max2));
return max1;
}
};

class Solution {
public:
vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
vector<vector<int>> chosen_edge;
TreeDiameter diameter_solver;
vector<int> result(n - 1);
for(int state = 1; state < (1 << n); ++state)
{
chosen_edge.clear();
int check_state = 0;
for(const vector<int> &e: edges)
{
if(((state >> (e[0] - 1)) & 1) && ((state >> (e[1] - 1)) & 1))
{
check_state |= (1 << (e[0] - 1));
check_state |= (1 << (e[1] - 1));
chosen_edge.push_back(e);
}
}
// state 标记的点都有边选上了
if(check_state != state)
continue;
vector<vector<int> > g(n + 1);
for(const auto &e: chosen_edge)
{
g[e[0]].push_back(e[1]);
g[e[1]].push_back(e[0]);
}
int s = chosen_edge[0][0];
if(!connect(g, state, s))
continue;
int d = diameter_solver.solve(g, s);
++result[d - 1];
}
return result;
}

private:
bool connect(const vector<vector<int>>& g, int state, int s)
{
dfs(g, s, -1, state);
return state == 0;
}

void dfs(const vector<vector<int>>& g, int u, int prev, int& state)
{
state &= ~(1 << (u - 1));
for(int v: g[u])
if(v != prev)
dfs(g, v, u, state);
}
};

算法2: 树形DP

如下图,u 是 v 的父节点,图中的两个三角形分别代表两个子树集合:

其中:

  • 根为 u 的三角形代表 u 为根节点,深度为 ju,直径为 ku 的子树集合,集合中的子树个数就是方案数,记为 dp[u][ju][ku]
  • 根为 v 的三角形代表 v 为根节点,深度为 jv,直径为 kv 的子树集合,集合中的子树个数就是方案数,记为 dp[v][jv][kv]

如果我们知道以 u 为根节点,深度为 ju,直径为 ku 的方案数(dp[u][ju][ku])、以及以 v 为根节点,深度为 jv,直径为 kv 的方案数(dp[v][jv][kv]),则合并地考虑图中两个三角形的并集代表的子树子树集合。该集合中的子树,其深度为 max(ju, jv + 1);其直径为 max(ku, kv, ju + jv + 1),这两个结论可以结合前面的图来理解。又乘法原理,有:

1
dp[u][max(ju, jv + 1)][max(ku, kv, ju + jv + 1)] = dp[u][ju][ku] * dp[v][jv][kv]

于是我们可以写出树形 DP 算法,如下:

1
2
3
4
5
6
7
8
9
10
11
12
状态定义:
dp[u][j][k] := 以 u 为根,深度为 j,直径为 k 时的方案数

答案:
result[d - 1] = sum(dp[u][j][d])

初始化:
dp[u][0][0] = 1 u 为所有节点

状态转移:
u 为 v 的父节点
dp[u][max(ju, jv + 1)][max(kv, ku, ju + jv + 1)] = dp[v][jv][kv] * dp[u][ju][ku]

用 dp[v] 对 dp[u] 更新时需要用到 dp[u] 本身,此时 dp[u] 中不含来自 v 的贡献,计算过程中也要保持 dp[u] 中不含来自 v 的贡献。也就是在 (u, v) 返回结果后用 dp[v][...][...] 更新 dp[u][...][...] 时,不能直接在 dp[u][jf][kf] 上更新,因为计算过程中 dp[u][jf][kf] 中不能含有 v 对其的贡献。

需要将 dp[u][jf][kf] 中来自 (u, v) 的贡献,即对 v 的求解返回后用 dp[v][j][k]dp[u][jf][kf] 更新的那部分贡献缓存起来 tmp[jf][kf],最后一起加到 dp[u][jf][kf] 上。

代码 (C++)

$O(N^{4})$,证明比较复杂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Solution {
public:
vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
dp = vector<vector<vector<int>>>(n, vector<vector<int>>(n, vector<int>(n, 0)));
vector<vector<int>> g(n); // 建立有根树
for(const vector<int>& e: edges)
{
g[e[0] - 1].push_back(e[1] - 1);
g[e[1] - 1].push_back(e[0] - 1);
}
result.assign(n - 1, 0);
dfs(0, -1, g);
return result;
}

private:
vector<vector<vector<int>>> dp;
vector<int> result;
vector<vector<int>> tmp;

void dfs(int u, int fa, const vector<vector<int>>& g)
{
int n = g.size();
dp[u][0][0] = 1;
for(int v: g[u])
if(v != fa)
{
dfs(v, u, g);
tmp.assign(n, vector<int>(n, 0));
for(int jf = 0; jf < n; ++jf)
for(int kf = jf; kf < n; ++kf)
{
for(int j = 0; j + kf + 1 < n; ++j)
for(int k = j; k < n; ++k)
{
int deep = max(jf, j + 1);
int d = max(max(k, kf), jf + j + 1);
tmp[deep][d] += dp[v][j][k] * dp[u][jf][kf];
}
}
// 将 (u, v) 对 dp[u][jf][kf] 的贡献一起加过来
for(int deep = 1; deep < n; ++deep)
for(int d = deep; d < n; ++d)
{
dp[u][deep][d] += tmp[deep][d];
result[d - 1] += tmp[deep][d];
}
}
}
};

算法3: 枚举端点对 + 组合数学

在算法1中,我们考虑的是枚举所有的子树,求出该子树对应的直径 d,然后将 d 对应的数目加 1。

枚举所有的端点对,然后求出以该端点对作为直径的子树的数目 k,以及直径 d,然后将 d 对应的数目加 k。

下面的问题就是选中端点对 i, j 后,如何找到以 i, j 为直径的树。

如果 i, j 之间不连通,则 i, j 这个端点对可以直接过,如果连通,由于题目保证任意城市之间只有唯一的一条路径。换句话说,所有城市形成了一棵树。所以 i 和 j 之间的路径就是直径,直径为 d。

以 i, j 为直径的树,直径形成的路径上的点必选,在直径路线上的点已经选中后,问题变成了还能够额外选择哪些点,使得额外选择的点与直径连通也就是构成一棵新的树,且新的树的直径不大于 d。

定义 dist[i][j] 为 i 到 j 的距离,i, j 的路径上的点必选,如图中红色阴影的点。对于 i, j 路径以外的点 x

  • dist[i][x] > ddist[j][x] > d 时,x 不可选。如图中打叉的点。
  • dist[i][x] < ddist[j][x] < d 时,x 可以选,如图中红色实线框中的点。
  • dist[i][x] = ddist[j][x] = d 时,要看情况。因为可能有重复计数的情况。

枚举端点对 i, j 的过程如下:

1
2
for(int i = 0; i < n - 1; ++i)
for(int j = i + 1; j < n; ++j)

因此对于端点对 2,4,直径为 d,此时如果 dist[2][3] = d,那么包含 2, 3, 4 且直径为 d 的情况在端点对为 2, 3 时已经统计过了,因此这里 3 不选。

如果 dist[2][3] != ddist[4][3] = d,3 可以选,但注意后续对于端点对 3, 4 时,2 就不能加进去了。

总结一下,当出现 dist[i][x] = ddist[j][x] = d 时的选法如下:

dist[i][x] = dist[i][j] = d 的情况,如果 x < j,则 x 不选。
dist[j][x] = dist[i][j] = d 的情况,如果 x < i,则 x 不选。

所有可选点构成了若干棵树(i, j 路径上必选),一棵树,不妨将 i 作为该树的根。对于这棵树上的每个节点 x:

  • 如果选 x,则选法为 x 的若干子树的选法数的乘积。
  • 如果 x 不在 i, j 路径上,也就是 dist[i][x] + dist[j][x] > d,则还可以不选 x,将选法数加 1 即可。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class Solution {
public:
vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
vector<vector<int> > g(n + 1);
for(const auto &e: edges)
{
g[e[0]].push_back(e[1]);
g[e[1]].push_back(e[0]);
}
vector<vector<int>> dist(n + 1, vector<int>(n + 1, INT_MAX / 2));
for(int s = 1; s <= n; ++s)
bfs(g, s, dist);
vector<int> result(n - 1);
for(int i = 1; i < n; ++i)
for(int j = i + 1; j <= n; ++j)
{
int d = dist[i][j];
if(d == INT_MAX / 2)
continue;
result[d - 1] += dfs(g, dist, i, j, i, -1);
}
return result;
}

private:
struct State {
int v;
int d;
State(){}
State(int v, int d):v(v),d(d){}
};

vector<int> visited;
queue<State> q;

int dfs(const vector<vector<int>>& g, const vector<vector<int>>& dist, const int i, const int j, int x, int fa)
{
int d = dist[i][j];
int c = 1; // 选 x
for(int y: g[x])
{
// continue 的情况是 y 不可选的情况
if(y == fa)
continue;
if(dist[i][y] > d || dist[j][y] > d)
continue;
if(dist[i][y] == d && y < j)
continue;
if(dist[j][y] == d && y < i)
continue;
c *= dfs(g, dist, i, j, y, x);
}
if(dist[i][x] + dist[j][x] > d)
c++;
return c;
}

void bfs(const vector<vector<int>>& g, int s, vector<vector<int>>& dist)
{
int n = g.size() - 1;
visited.assign(n + 1, 0);
q.push(State(s, 0));
visited[s] = 1;
while(!q.empty())
{
State cur = q.front();
q.pop();
int u = cur.v;
int d = cur.d;
dist[s][u] = d;
for(int v: g[u])
{
if(visited[v] == 1)
continue;
visited[v] = 1;
q.push(State(v, d + 1));
}
}
}
};

Share