理解高维状态空间线性DP:已处理部分在状态空间中的轮廓

  |  

摘要: 状态计算的方向

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们通过照相排列问题 (k 个串上的线性 DP),来理解一个动态规划中的一个抽象的概念:已处理部分在高维状态空间中的轮廓。可以大致做以下理解:

  • 一个轮廓代表了一个子问题;轮廓内代表已经解决的更小的子问题;轮廓外是尚未解决的更大的子问题。
  • 轮廓上的状态转移,相当于轮廓沿着高维状态空间的某一维度扩张,进入下一阶段。

因此在分析高维状态空间的问题的时候,可以从已处理部分的轮廓的角度来发现最优子结构和重复子问题。

照相排列问题

271. 杨老师的照相排列

有 N 个学生合影,站成左端对齐的 k 排,每排分别有 N1,N2,…,Nk 个人。 (N1 >= N2 >= … >= Nk)

第 1 排站在最后边,第 k 排站在最前边。

学生的身高互不相同,把他们从高到底依次标记为 1,2,…,N。

在合影时要求每一排从左到右身高递减,每一列从后到前身高也递减。

问一共有多少种安排合影位置的方案?

下面的一排三角矩阵给出了当 N=6,k=3,N1=3,N2=2,N3=1 时的全部 16 种合影方案。注意身高最高的是 1,最低的是 6。

1
2
3
123 123 124 124 125 125 126 126 134 134 135 135 136 136 145 146
45 46 35 36 34 36 34 35 25 26 24 26 24 25 26 25
6 5 6 5 6 4 5 4 6 5 6 4 5 4 3 3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
输入格式
输入包含多组测试数据。

每组数据两行,第一行包含一个整数 k 表示总排数。

第二行包含 k 个整数,表示从后向前每排的具体人数。

当输入 k=0 的数据时,表示输入终止,且该数据无需处理。

输出格式
每组测试数据输出一个答案,表示不同安排的数量。
每个答案占一行。

数据范围
1<=k<=5,
学生总人数不超过 30 人。

输入样例:
1
30
5
1 1 1 1 1
3
3 2 1
4
5 3 3 1
5
6 5 4 3 2
2
15 15
0
输出样例:
1
1
16
4158
141892608
9694845

算法:动态规划

每行和每列的身高都是单调的。随意可以从高到低依次考虑标记为 1, 2, …, N (从高到低考虑)的学生站的位置。

这样在任意时刻,已经安排好位置的学生在每一行占据的一定是从左开始连续若干个位置。

用一个 k 元组 $(a_{1}, a_{2}, \cdots, a_{k})$ 表示每一行已经安排的学生人数,即可描绘出“已经处理的部分”的轮廓。

当安排一名新学生时,考虑所有满足以下条件的行号 $i$:

  • $a_{i} < N_{i}$
  • $i = 1$ 或 $a_{i-1} > a_{i}$

新学生只要安排在满足以上条件的行,每列的单调性也能满足。因此我们不用关心已经安排好的 $\sum\limits_{i=1}\limits^{k}a_{i}$ 名学生的具体方案。$(a_{1}, a_{2}, \cdots, a_{k})$ 描绘的轮廓内的方案数构成一个子问题。因此可以把 $(a_{1}, a_{2}, \cdots, a_{k})$ 作为阶段,当安排一名新学生时,$a_{1}, \cdots, a_{k}$ 之一会增加 1,从而转移到后续阶段

综上,设计动态规划算法如下:

  • 状态定义:$dp(\vec{a})$,其中$\vec{a}=(a_{1}, \cdots, a_{k})$,表示考虑前 $\sum\limits_{i=1}\limits^{k}a_{i}$ 名学生,将其按照 $\vec{a}$ 表示的各行人数的排列方法数。
  • 边界:$dp(\vec{0})$
  • 目标:$dp(\vec{N})$,其中 $\vec{N}=(N_{1}, \cdots, N_{k})$
  • 状态转移方程(如何通过已有状态计算当前状态):

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import functools

def main():
@functools.lru_cache()
def solve(a):
if list(a) == [0] * len(a):
return 1
ans = 0
for i in range(k):
if a[i] > 0 and a[i] <= N[i] and (i == k - 1 or a[i] > a[i + 1]):
nxt_a = list(a)
nxt_a[i] -= 1
ans += solve(tuple(nxt_a))
return ans

while True:
k_str = input()
if k_str == "0":
break
k = int(k_str)
N_str = input()
N_str = N_str.split()
N = [int(s) for s in N_str]

ans = solve(tuple(N))
print(ans)


if __name__ == "__main__":
main()

Share