【状态压缩DP】N次操作后的最大分数和

  |  

摘要: 一道状态压缩DP的简单题

【对数据分析、人工智能、金融科技、风控服务感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:潮汐朝夕
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


在文章 状态压缩DP 中我们学习了状态压缩DP的知识,本文我们来看一个相关的经典问题。

题目

1799. N 次操作后的最大分数和

给你 nums ,它是一个大小为 2 * n 的正整数数组。你必须对这个数组执行 n 次操作。

在第 i 次操作时(操作编号从 1 开始),你需要:

选择两个元素 x 和 y 。
获得分数 i * gcd(x, y) 。
将 x 和 y 从 nums 中删除。
请你返回 n 次操作后你能获得的分数和最大为多少。

函数 gcd(x, y) 是 x 和 y 的最大公约数。

提示:

1
2
3
1 <= n <= 7
nums.length == 2 * n
1 <= nums[i] <= 1e6

示例 1:
输入:nums = [1,2]
输出:1
解释:最优操作是:
(1 * gcd(1, 2)) = 1

示例 2:
输入:nums = [3,4,6,8]
输出:11
解释:最优操作是:
(1 gcd(3, 6)) + (2 gcd(4, 8)) = 3 + 8 = 11

示例 3:
输入:nums = [1,2,3,4,5,6]
输出:14
解释:最优操作是:
(1 gcd(1, 5)) + (2 gcd(2, 4)) + (3 * gcd(3, 6)) = 1 + 4 + 9 = 14

题解

算法:状态压缩DP

一共有 2n 个数,每次从中取出 2 个数,共取 n 次。从动态规划的角度看的话,阶段是比较容易确定的,就是每次取数就是一个阶段,共 n 个阶段。

每次取数后,nums 中剩余的数构成了下一个阶段的子问题。这个子问题的答案受到 nums 中剩余的数的情况的影响,因此状态应该 nums 中的数字是否被取出。

按照前面的分析,可以定义 dp[i][s] 为第 i 次取数时,nums 中剩余的数的状态为 s 的情况下的最终的最大分数和。

但注意,由于每一阶段剩余数字个数减 2,而 s 中隐含了数字的个数,因此 i 可以省略掉。

1
2
3
4
5
6
7
8
9
10
11
12
13
状态定义:
dp[s] := nums 中各个数字是否被取出的情况为 s 时的最大分数和
其中 (s >> i) 为 1 表示 nums[i] 已经被取出;为 0 表示尚未被取出

答案:
dp[0]

初始化:
dp[(1 << n) - 1] = 0

状态转移:
dp[s] = max(r * gcd(nums[i], nums[j]) + dp[s | (1 << i) | (1 << j)])
其中 r 表示当前的步数

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
public:
int maxScore(vector<int>& nums) {
int n = nums.size();
dp = vector<int>(1 << n, -1);
return solve(0, nums, 1);
}

private:
vector<int> dp;
int solve(int state, const vector<int>& nums, int r)
{
if(dp[state] != -1)
return dp[state];
int n = nums.size();
int ans = 0;
for(int i = 0; i < n; ++i)
{
if(state >> i & 1)
continue;
for(int j = i + 1; j < n; ++j)
{
if(state >> j & 1)
continue;
ans = max(ans, r * (gcd<int>(nums[i], nums[j])) + solve(state | (1 << i) | (1 << j), nums, r + 1));
}
}
dp[state] = ans;
return ans;
}
};

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution:
def maxScore(self, nums: List[int]) -> int:
self.nums = nums
self.n = len(nums)
return self.solve(0, 1)

@lru_cache(int(1e7))
def solve(self, state: int, r: int) -> int:
ans = 0
for i in range(self.n):
if state >> i & 1:
continue
for j in range(i + 1, self.n):
if state >> j & 1:
continue
ans = max(ans, r * gcd(self.nums[i], self.nums[j]) + self.solve(state | (1 << i) | (1 << j), r + 1))
return ans

Share