拓扑排序的方案数

  |  

摘要: 一个综合算法问题:树形DP + BST 建树 + 预处理组合数

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们通过力扣的一个题目,来看一下拓扑排序方案数问题怎么解决。这是第 204 场周赛 D 题,综合性很强,涉及到数组形式 BST 的存储和插入;树形DP;预处理组合数等算法。

$1 题目

给你一个数组 nums 表示 1 到 n 的一个排列。我们按照元素在 nums 中的顺序依次插入一个初始为空的二叉搜索树(BST)。请你统计将 nums 重新排序后,统计满足如下条件的方案数:重排后得到的二叉搜索树与 nums 原本数字顺序得到的二叉搜索树相同。

比方说,给你 nums = [2,1,3],我们得到一棵 2 为根,1 为左孩子,3 为右孩子的树。数组 [2,3,1] 也能得到相同的 BST,但 [3,2,1] 会得到一棵不同的 BST 。

请你返回重排 nums 后,与原数组 nums 得到相同二叉搜索树的方案数。

由于答案可能会很大,请将结果对 10^9 + 7 取余数。

提示:

1
2
3
1 <= nums.length <= 1000
1 <= nums[i] <= nums.length
nums 中所有数 互不相同 。

示例 1:
输入:nums = [2,1,3]
输出:1
解释:我们将 nums 重排, [2,3,1] 能得到相同的 BST 。没有其他得到相同 BST 的方案了。

示例 2:
输入:nums = [3,4,5,1,2]
输出:5
解释:下面 5 个数组会得到相同的 BST:
[3,1,2,4,5]
[3,1,4,2,5]
[3,1,4,5,2]
[3,4,1,2,5]
[3,4,1,5,2]

示例 3:
输入:nums = [1,2,3]
输出:0
解释:没有别的排列顺序能得到相同的 BST 。

$2 题解

算法: 树形 DP + BST 建树 + 预处理组合数

第1部分:数组形式 BST 的存储和插入

由给定的 nums 可以建一棵树。因为树形DP方程要用到节点所代表的子树的大小,因此给节点增加 size 字段表示该信息

1
2
3
4
5
6
7
struct Node
{
int id;
int left, right;
int size;
Node(int id, int left = 0, int right = 0, int size = 1):id(id),left(left),right(right),size(size){}
};

树以数组的形式存储:vector<Node> tree,当某个节点 node, tree[node].lefttree[node].right 均为 0 时,为叶子节点。

1
2
3
4
bool is_leaf(const vector<Node>& tree, int node)
{
return tree[node].left == 0 && tree[node].right == 0;
}

树的插入就是递归实现的 BST 插入过程,只是这里是数组形式。需要注意边界的处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
void insert(vector<Node>& tree, int root, int node)
{
if(node == root)
return;
++tree[root].size;
if(node < root)
{
if(tree[root].left == 0)
{
tree[root].left = node;
return;
}
insert(tree, tree[root].left, node);
}
if(node > root)
{
if(tree[root].right == 0)
{
tree[root].right = node;
return;
}
insert(tree, tree[root].right, node);
}
}

第2部分: 树形 DP

状态定义

1
2
size[node]:= 以 node 为根的子树的节点数目
dp[node] := 以 node 为根的子树,有多少重排的方案数(包括原排列本身)

状态转移

当 node 为叶子节点时候,dp[node] 当然为 1。

当 node 不为叶子节点的时候,即 node 代表的子树至少有两个节点。可以将 node 代表的子树分为三部分(左右子树最多有一个为空):

  1. 根: node
  2. 左子树: node.left 代表的子树
  3. 右子树: node.right 代表的子树

node 子树在数组上有 size[node] 个位置,第一个位置为 node 固定了,剩下的 size[node] - 1 个位置需要选出 size[node.left] 个位置放置比根小的节点。选中位置之后,左右子树内部的排列个数形成了子问题,即 dp[node.left]dp[node.right]

1
2
dp[node] = 1   node 为叶子节点
= C(size[node] - 1, size[node.left]) * dp[node.left] * dp[node.right] node 为非叶子节点。

第3部分:预处理组合数

在状态转移方程中有一步 C(size[node] - 1, size[node.left]),因此可以将 C(n, i) 0 < i < n 预处理出来然后直接查询。

实现如下。使用的是预处理阶乘和逆元的方式。理论基础参考 组合数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class C
{
public:
C(){}
C(ll n, ll p)
{
preprocess(n, p);
}

ll operator()(ll n, ll m, ll p) const
{
// 先求 $n! \mod p$
ll ans1 = jc[n];
// 再求 $m!(n-m)! \mod p$ 的逆元
ll ans2 = (jc_inv[m] * jc_inv[n - m]) % p;
// 再乘起来
return (ans1 * ans2) % p;
}

private:
vector<ll> jc, jc_inv;

void preprocess(ll n, ll p)
{
// jc[i] = i! % p
jc = vector<ll>(n + 1, 1);
for(ll i = 1; i <= n; ++i)
jc[i] = (jc[i - 1] * i) % p;
jc_inv = vector<ll>(n + 1, 1);
// 0 不存在逆元
for(ll i = 1; i <= n; ++i)
jc_inv[i] = inv(jc[i], p);
}

// 扩展欧几里得求 ax + by = gcd(a, b)
ll exgcd(ll a, ll b, ll& x, ll& y)
{
// 求出 ax + by = gcd(a, b) 的一组特接并返回 a,b 的最大公约数 d。
if(b == 0)
{
x = 1;
y = 0;
return a;
}
ll d = exgcd(b, a % b, x, y);
int z = x;
x = y;
y = z - y * (a / b);
return d;
}

// 求 b 模 p 的乘法逆元 (b 与 p 互质)
ll inv(ll b, ll p)
{
// 解方程 bx 与 1 模 p 同余
// 扩展欧几里得求 bx0 + py0 = 1
ll x0 = 0, y0 = 0;
ll d = exgcd(b, p, x0, y0);
if(d != 1) // b 和 p 不互质的情况不不存在逆元
return -1;
return (x0 % p + p) % p;
}
};

代码(C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
using ll = long long;

class C
{
public:
C(){}
C(ll n, ll p)
{
preprocess(n, p);
}

ll operator()(ll n, ll m, ll p) const
{
// 先求 $n! \mod p$
ll ans1 = jc[n];
// 再求 $m!(n-m)! \mod p$ 的逆元
ll ans2 = (jc_inv[m] * jc_inv[n - m]) % p;
// 再乘起来
return (ans1 * ans2) % p;
}

private:
vector<ll> jc, jc_inv;

void preprocess(ll n, ll p)
{
// jc[i] = i! % p
jc = vector<ll>(n + 1, 1);
for(ll i = 1; i <= n; ++i)
jc[i] = (jc[i - 1] * i) % p;
jc_inv = vector<ll>(n + 1, 1);
// 0 不存在逆元
for(ll i = 1; i <= n; ++i)
jc_inv[i] = inv(jc[i], p);
}

// 扩展欧几里得求 ax + by = gcd(a, b)
ll exgcd(ll a, ll b, ll& x, ll& y)
{
// 求出 ax + by = gcd(a, b) 的一组特接并返回 a,b 的最大公约数 d。
if(b == 0)
{
x = 1;
y = 0;
return a;
}
ll d = exgcd(b, a % b, x, y);
int z = x;
x = y;
y = z - y * (a / b);
return d;
}

// 求 b 模 p 的乘法逆元 (b 与 p 互质)
ll inv(ll b, ll p)
{
// 解方程 bx 与 1 模 p 同余
// 扩展欧几里得求 bx0 + py0 = 1
ll x0 = 0, y0 = 0;
ll d = exgcd(b, p, x0, y0);
if(d != 1) // b 和 p 不互质的情况不不存在逆元
return -1;
return (x0 % p + p) % p;
}
};

struct Node
{
int id;
int left, right;
int size;
Node(int id, int left = 0, int right = 0, int size = 1):id(id),left(left),right(right),size(size){}
};

class Solution {
public:
int numOfWays(vector<int>& nums) {
int n = nums.size();
vector<Node> tree(n + 1, Node(0));
for(int i = 1; i <= n; ++i)
tree[i].id = i;
comb = C(n, MOD);
int root = nums[0];
for(int i = 0; i < n; ++i)
{
int node = nums[i];
insert(tree, root, node);
}
int ans = dfs(tree, root);
ans = (ans - 1 + MOD) % MOD;
return ans;
}

private:
const int MOD = 1e9 + 7;
C comb;

int dfs(const vector<Node>& tree, int node)
{
if(is_leaf(tree, node))
return 1;
int left_size = tree[node].left == 0 ? 0 : tree[tree[node].left].size;
int c = comb(tree[node].size - 1, left_size, MOD); // C(n, m) 返回 C(n, m)%p
int ans = c;
if(tree[node].left != 0)
{
int left_cnt = dfs(tree, tree[node].left);
ans = ((ll)ans * left_cnt) % MOD;
}
if(tree[node].right != 0)
{
int right_cnt = dfs(tree, tree[node].right);
ans = ((ll)ans * right_cnt) % MOD;
}
return ans;
}

bool is_leaf(const vector<Node>& tree, int node)
{
return tree[node].left == 0 && tree[node].right == 0;
}

void insert(vector<Node>& tree, int root, int node)
{
if(node == root)
return;
++tree[root].size;
if(node < root)
{
if(tree[root].left == 0)
{
tree[root].left = node;
return;
}
insert(tree, tree[root].left, node);
}
if(node > root)
{
if(tree[root].right == 0)
{
tree[root].right = node;
return;
}
insert(tree, tree[root].right, node);
}
}
};

Share