高维状态设计:凸连通块的状态表示

  |  

摘要: 凸连通块的状态表示

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们看一个高维线性动态规划的问题,主要难点在于对矩阵中的连通块的状态表示,其中阶段占了一个维度,附加信息占了 5 个维度。

276. I-区域

在 $N\times M$ 的矩阵中,每个格子有一个权值,要求寻找一个包含 $K$ 个格子的凸连通块(连通块中间没有空缺,并且轮廓是凸的),使这个连通块中的格子的权值和最大。

注意:凸连通块是指:连续的若干行,每行的左端点列号先递减、后递增,右端点列号先递增、后递减。

求出这个最大的权值和,并给出连通块的具体方案,输出任意一种方案即可。

1
2
3
4
5
6
7
8
9
10
11
12
输入格式
第一行包含三个整数 N,M 和 K。
接下来 N 行每行 M 个整数,表示 $N\times M$ 的矩阵上每个格子的权值(均不超过 1000)。

输出格式
第一行输出 Oil : X,其中 X 为最大权值和。

接下来 K 行每行两个整数 xi 和 yi,用来描述所有格子的具体位置,每个格子位于第 xi 行,第 yi 列。

数据范围
1<=N,M<=15,
0<=K<=N<=M

输入样例:
2 3 4
10 20 30
40 2 3
输出样例:
Oil : 100
1 1
1 2
1 3
2 1

高维线性 DP:对矩形连通块的状态设计

任意一个凸连通块,可以划分成连续的若干行,每行可以用左端点和右端点组成的区间 $l, r$ 表示。那么从上到下枚举各行,其左端点先递减,后递增,右端点先递增,后递减。

因此对于 $N \times M$ 矩阵,我可以一次考虑从矩阵的每一行中选择那些格子来构成凸连通块。过程中需要关注以下信息:

  • 已经处理完的行数
  • 已经选出的格子数
  • 当前行已选格子的左端位置:用于确定下一行左端点范围
  • 当前行已选格子的右端位置:用于确定下一行右端点范围
  • 当前行左侧轮廓是递增的还是递减的
  • 当前行右侧轮廓是递增的还是递减的

上述信息中,行数作为阶段。从当前行转移到下一行,这满足阶段线性增长的特点。其余信息作为附加信息,在状态转移时需要用到。综上,DP 状态表示如下。

状态表示

定义 $dp[i][j][l][r][x][y]$ 表示第 $0,\cdots,i$ 行选了 $j$ 个格子,其中第 $i$ 行选了 $[l, r]$ 范围内的格子(若不选则 l = r = -1),左边界单调性为 $x$,右边界单调性为 $y$(0 表示递增,1 表示递减)。这样的凸连通块可获得的最大权值和。

状态转移方程

当前行 $i$ 选择 $[l, r]$,权值和为 $\sum\limits_{k=l}\limits^{r}A[i][k]$,含 $r - l + 1$ 个格子。于是第 $0,1,\cdots,i-1$ 行选择的格子数就为 $j - (r - l + 1)$,因此状态转移方程大致如下:

其中 $mx = dp[i-1][j-(r-l+1)][p][q][x’][y’]$,于是状态转移方程的关键就在于后四维的信息 $p, q, x’, y’$ 是怎么转移的,一种特殊情况是 $j = r - l + 1$,此时前面的第 $0,1,\cdots,i-1$ 行选出的格子数为 0,$mx=0$。

$r - l + 1 < j$ 时的情况比较复杂,按照边界满足的单调性分别讨论,记 $s = \sum\limits_{k=l}\limits^{r}A[i][k]$:

(1) 左边单调递减,右边单调递增(两个边界都在扩张)

由于左端点要先递减后递增,所以第 $i$ 行单调性状态为 $(x, y) = (1, 0)$ 时,$i-1$ 行只能保持不变,即 $(x’, y’) = (1, 0)$。

第 $i - 1$ 行可以选择 $[l, r]$ 的某个子区间 $[p, q]$,这样如果左右端点单调性不变的话,就转移到了 $dp[i - 1][j - (r - l + 1)][p][q][1][0]$。

(2) 左边单调递增,右边单调递增(左边界收缩,右边界扩张)

第 $i$ 行为 $(x, y) = (0, 0)$ 时,前面的行 $x$ 可能变为 $1$,但 $y$ 不能变。因此 $(x’, y’) = (0, 0), (1, 0)$。于是:

(3) 左边单调递减,右边单调递减(左边界扩张,右边界收缩)

第 $i$ 行为 $(x, y) = (1, 1)$ 时,前面的行 $y$ 可能变为 $0$,但 $x$ 不能变。因此 $(x’, y’) = (1, 0), (1, 1)$。于是:

(4) 左边单调递增,右边单调递减(两个边界都在收缩)

第 $i$ 行为 $(x, y) = (0, 1)$ 时,前面的行 $x$ 可能变为 1,$y$ 也不能变为 0。因此 $(x’, y’) = (0, 0), (0, 1), (1, 0), (1, 1)$。于是:

初始值和答案

对于 $dp[i][j][l][r][x][y]$:

边界情况是当 $r - l + 1 = j$ 时,$dp[i][j][l][r][x][y] = s = \sum\limits_{k=l}\limits^{r}A[i][k]$。

此外还有不可能的情况,也就是 $j < r - l + 1$ 时,返回 $dp[i][j][l][r][x][y] = -1$,表示这种情况无法达到。

还有一种特殊情况是在 $i=0$ 时,若 $j > r - l + 1$,则也是无法达到的 $dp[i][j][l][r][x][y] = -1$。

答案为 $max(dp[i][K][l][r][x][y])$。

时间复杂度为 $O(NM^{4}K)$。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>

using namespace std;

const int INF = 0x3f3f3f3f;
const int MAXN = 15;
const int MAXM = 15;
const int MAXK = MAXN * MAXM;
int dp[MAXN][MAXK][MAXM][MAXM][2][2];
int A[MAXN][MAXM];
int sums[MAXN][MAXM + 1];
int N, M, K;

int solve(int i, int j, int l, int r, int x, int y)
{
if(dp[i][j][l][r][x][y] != INF)
return dp[i][j][l][r][x][y];

int w = r - l + 1;
int s = sums[i][r + 1] - sums[i][l];

if(j < w)
return -1;
if(j == w)
return s;
if(i == 0 && j > w)
return -1;

// i > 0, j > w
int ans = -1;
if(x == 1 && y == 0)
{
int x_ = 1;
int y_ = 0;
for(int p = l; p <= r; ++p)
for(int q = p; q <= r; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res != -1 && s + res > ans)
ans = s + res;
}
}
if(x == 0 && y == 0)
{
int y_ = 0;
for(int x_ = 0; x_ <= 1; ++x_)
for(int p = 0; p <= l; ++p)
for(int q = l; q <= r; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res != -1 && s + res > ans)
ans = s + res;
}
}
if(x == 1 && y == 1)
{
int x_ = 1;
for(int y_ = 0; y_ <= 1; ++y_)
for(int p = l; p <= r; ++p)
for(int q = r; q < M; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res != -1 && s + res > ans)
ans = s + res;
}
}
if(x == 0 && y == 1)
{
for(int x_ = 0; x_ <= 1; ++x_)
for(int y_ = 0; y_ <= 1; ++y_)
for(int p = 0; p <= l; ++p)
for(int q = r; q < M; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res != -1 && s + res > ans)
ans = s + res;
}
}
return dp[i][j][l][r][x][y] = ans;
}

void get_best_decisions(int i, int j, int l, int r, int x, int y, vector<vector<int>>& ans_best_decisions, int ans)
{
ans_best_decisions.push_back({i, j, l, r, x, y});

int w = r - l + 1;
int s = sums[i][r + 1] - sums[i][l];

if(j == w)
{
// 已做完最后一个决策
return;
}

// i > 0, j > w
if(x == 1 && y == 0)
{
int x_ = 1;
int y_ = 0;
for(int p = l; p <= r; ++p)
for(int q = p; q <= r; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res == ans - s)
{
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans_best_decisions, ans - s);
return;
}
}
}
if(x == 0 && y == 0)
{
int y_ = 0;
for(int x_ = 0; x_ <= 1; ++x_)
for(int p = 0; p <= l; ++p)
for(int q = l; q <= r; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res == ans - s)
{
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans_best_decisions, ans - s);
return;
}
}
}
if(x == 1 && y == 1)
{
int x_ = 1;
for(int y_ = 0; y_ <= 1; ++y_)
for(int p = l; p <= r; ++p)
for(int q = r; q < M; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res == ans - s)
{
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans_best_decisions, ans - s);
return;
}
}
}
if(x == 0 && y == 1)
{
for(int x_ = 0; x_ <= 1; ++x_)
for(int y_ = 0; y_ <= 1; ++y_)
for(int p = 0; p <= l; ++p)
for(int q = r; q < M; ++q)
{
int res = solve(i - 1, j - w, p, q, x_, y_);
if(res == ans - s)
{
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans_best_decisions, ans - s);
return;
}
}
}
}

int main()
{
memset(dp, INF, sizeof(dp));
memset(A, -1, sizeof(A));
memset(sums, -1, sizeof(sums));

cin >> N >> M >> K;
for(int i = 0; i < N; ++i)
for(int j = 0; j < M; ++j)
cin >> A[i][j];

for(int i = 0; i < N; ++i)
for(int j = 1; j <= M; ++j)
sums[i][j] = sums[i][j - 1] + A[i][j - 1];

int ans = -1;
vector<int> start;
for(int i = 0; i < N; ++i)
for(int l = 0; l < M; ++l)
for(int r = l; r < M; ++r)
for(int x = 0; x <= 1; ++x)
for(int y = 0; y <= 1; ++y)
{
int res = solve(i, K, l, r, x, y);
if(res != -1 && res > ans)
{
ans = res;
start = {i, K, l, r, x, y};
}
}

if(ans == -1)
{
// 没有可行决策
cout << "Oil : " << 0 << endl;
}
else
{
cout << "Oil : " << ans << endl;

vector<vector<int>> ans_best_decisions;
get_best_decisions(start[0], start[1], start[2], start[3], start[4], start[5], ans_best_decisions, ans);
reverse(ans_best_decisions.begin(), ans_best_decisions.end());
for(vector<int>& decison: ans_best_decisions)
{
int i = decison[0];
int l = decison[2];
int r = decison[3];
for(int j = l; j <= r; j++)
cout << i + 1 << " " << j + 1 << endl;
}
}
}

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import functools

def main():
@functools.lru_cache(int(1e8))
def solve(i: int, j: int, l: int, r: int, x: int, y: int) -> int:
# N * K * M * M * 2 * 2
w = r - l + 1
s = sums[i][r + 1] - sums[i][l]

if j < w:
return -1
if j == w:
return s
if i == 0 and j > w:
return -1

# i > 0, j > w
ans = -1

if (x, y) == (1, 0):
x_, y_ = 1, 0
for p in range(l, r + 1):
for q in range(p, r + 1):
# M * M
res = solve(i - 1, j - w, p, q, x_, y_)
if res != -1 and s + res > ans:
ans = s + res

if (x, y) == (0, 0):
y_ = 0
for x_ in range(2):
for p in range(0, l + 1):
for q in range(l, r + 1):
res = solve(i - 1, j - w, p, q, x_, y_)
if res != -1 and s + res > ans:
ans = s + res

if (x, y) == (1, 1):
x_ = 1
for y_ in range(2):
for p in range(l, r + 1):
for q in range(r, M):
res = solve(i - 1, j - w, p, q, x_, y_)
if res != -1 and s + res > ans:
ans = s + res

if (x, y) == (0, 1):
for x_ in range(2):
for y_ in range(2):
for p in range(0, l + 1):
for q in range(r, M):
res = solve(i - 1, j - w, p, q, x_, y_)
if res != -1 and s + res > ans:
ans = s + res

return ans

def get_best_decisions(i: int, j: int, l: int, r: int, x: int, y: int, ans: int) -> None:
ans_best_decisions.append((i, j, l, r, x, y))

w = r - l + 1
s = sums[i][r + 1] - sums[i][l]

if j == w:
return

# i > 0, j > w
if (x, y) == (1, 0):
x_, y_ = 1, 0
for p in range(l, r + 1):
for q in range(p, r + 1):
# M * M
res = solve(i - 1, j - w, p, q, x_, y_)
if res == ans - s:
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans - s)
return

if (x, y) == (0, 0):
y_ = 0
for x_ in range(2):
for p in range(0, l + 1):
for q in range(l, r + 1):
res = solve(i - 1, j - w, p, q, x_, y_)
if res == ans - s:
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans - s)
return

if (x, y) == (1, 1):
x_ = 1
for y_ in range(2):
for p in range(l, r + 1):
for q in range(r, M):
res = solve(i - 1, j - w, p, q, x_, y_)
if res == ans - s:
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans - s)
return

if (x, y) == (0, 1):
for x_ in range(2):
for y_ in range(2):
for p in range(0, l + 1):
for q in range(r, M):
res = solve(i - 1, j - w, p, q, x_, y_)
if res == ans - s:
get_best_decisions(i - 1, j - w, p, q, x_, y_, ans - s)
return

N, M, K = list(map(int, input().split()))
A = [[] for _ in range(N)]
for i in range(N):
A[i] = list(map(int, input().split()))


sums = [[0 for _ in range(M + 1)] for _ in range(N)]
for i in range(N):
for j in range(1, M + 1):
sums[i][j] = sums[i][j - 1] + A[i][j - 1]

ans = -1
start = []
for i in range(N):
for l in range(M):
for r in range(l, M):
for (x, y) in [(0, 0), (0, 1), (1, 0), (1, 1)]:
res = solve(i, K, l, r, x, y)
if res != -1 and res > ans:
ans = res
start = [i, K, l, r, x, y]

if ans == -1:
print("Oil : {}".format(0))
return

print("Oil : {}".format(ans))

ans_best_decisions = []
get_best_decisions(*start, ans)
ans_best_decisions.reverse()
for decision in ans_best_decisions:
i = decision[0]
l = decision[2]
r = decision[3]
for j in range(l, r + 1):
print("{} {}".format(i + 1, j + 1))


if __name__ == "__main__":
main()

Share