树状数组优化DP

  |  

摘要: 权值树状数组优化DP

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


LIS

最长上升子序列LIS 我们通过动态规划解决了最长上升子序列问题,其状态转移方程如下:

在这个状态转移方程中,我们需要查询 $dp$ 数组在 $0 \leq j < i$ 上同时满足 $a[j] < a[j]$ 的最小值,在推导的过程中 $dp$ 数组会发生更新。这是一个待修改的区间最值查询问题。

同时这里的区间并不是简单的 $j \in [0, i)$,而是还有一个 $a[j] < a[i]$,所以我们维护的应该是 $a$ 数组的值域的某个取值范围上,$dp$ 的最值。这可以通过权值树状数组或权值线段树解决。

类似的状态转移方程经过分析发现是带修改的区间和或区间最值查询时,可以考虑用树状数组或线段树优化,具体用哪种结构,需要结合具体的问题分析。

本文我们以这个问题为例,看一下权值树状数组优化 DP 是怎么做的。


树状数组优化 DP

LIS 的转移方程如下:

朴素的做法,max 这一步要用一个循环来求。

但将这个过程抽象一下,相当于询问:在 [0..i-1] 上,值比 s[i] 小的这些位置,dp 的最大值是多少。

这里的区间包含的相当于是 [0..i-1] 上所有比 s[i] 小的值,数据就是对应的 dp 值,一个 s[j] 值对应一个 dp 值,如果有两个 j1, j2,它们的值 s[j1]s[j2] 相等,则 dp 值只记最大的。

在数据和区间的含义按照以上描述来理解时,问区间的最大值是多少。

维护这些 s 值及其对应的 dp 值的数组称为权值数组,权值数组的下标是 s 的各个值,权值数组的值就是 s 的值对应的 dp 值。

当用树状数组维护权值数组的最大值时,称为时,称为权值树状数组权值树状数组。参考文章 权值线段树、权值树状数组:元素排名区间的权值(个数)和

以上描述的原始数组 <-> 权值数组 <-> 权值树状数组的对应关系如图:

代码 (C++)

下面的代码中 BIT_RMQ 部分完全照搬模板,参考文章 树状数组维护区间最值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class BIT_RMQ;

class Solution {
public:
// 树状数组优化
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
for(int i : nums)
x.push_back(i);
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end()); // 把实际值离散化
m = x.size();
BIT_RMQ rmq(m);
int ans = 0;
for(int i = 0; i < n; ++i)
{
if(find(nums[i]) == 1)
{
ans = max(ans, 1);
rmq.update(1, 1);
continue;
}
int dp = rmq.query(1, find(nums[i]) - 1) + 1;
ans = max(ans, dp);
rmq.update(find(nums[i]), dp);
}
return ans;
}

private:
int m;
vector<int> x; // 此数组用于求 nums 中的值离散化之后的值

int find(int v) // 从 nums 的值找到对应的离散化之后的值
{
return lower_bound(x.begin(), x.end(), v) - x.begin() + 1;
}
};

代码优化1

可以不单独写成类 BIT_RMQ,而是直接操作权值数组和 bit 数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Solution {
public:
// 树状数组优化
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
for(int i : nums)
x.push_back(i);
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end()); // 把实际值离散化
m = x.size();
dp.assign(m + 1, 0);
bit.assign(m + 1, 0); // 树状数组初始化
int ans = 0;
for(int i = 0; i < n; ++i)
{
int w_idx = find(nums[i]);
int len = 1;
if(w_idx > 0)
len += query(1, w_idx);
ans = max(ans, len);
update(w_idx + 1, len);
}
return ans;
}

private:
vector<int> x; // 所有能值排序后去重
vector<int> dp; // 权值数组
int m; // 权值数组的长度

vector<int> bit; // 权值数组的树状数组

void update(int idx, int val)
{
dp[idx - 1] = val;
bit[idx] = val;
for(int i = idx; i <= m; i += lowbit(i))
{
for(int j = 1; j < lowbit(i); j <<= 1)
{
bit[i] = max(bit[i], bit[i - j]);
}
}
}

int query(int l, int r)
{
int ans = dp[r - 1];
while(true)
{
for(; r - lowbit(r) + 1 > l; r -= lowbit(r))
ans = max(ans, bit[r]);
if(r - lowbit(r) + 1 == l)
{
ans = max(ans, bit[r]);
return ans;
}
ans = max(ans, dp[r - 1]);
--r;
}
}

int lowbit(int v)
{
return v & (-v);
}

int find(int v)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

代码优化2

本题有一定的特殊性:

  1. 所有待查询区间都是从 0 开始的
  2. 所有的更新都把值(权值数组的值也就是dp值)变大了
  3. 查询的是最大值

在这些更强的条件下,可以把树状数组的更新和查询写的更简洁(见代码中的 updategetmax)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Solution {
public:
// 树状数组优化
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
for(int i : nums)
x.push_back(i);
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end()); // 把实际值离散化
m = x.size();
bit.assign(m + 1, 0); // 树状数组初始化
int ans = 0;
for(int i = 0; i < n; ++i)
{
int dp = getmax(find(nums[i]) - 1) + 1;
ans = max(ans, dp);
add(find(nums[i]), dp);
}
return ans;
}

private:
int m;
vector<int> x; // 此数组用于求 nums 中的值离散化之后的值

int find(int v) // 从 nums 的值找到对应的离散化之后的值
{
return lower_bound(x.begin(), x.end(), v) - x.begin() + 1;
}

vector<int> bit; // 树状数组
int lowbit(int x)
{
// 树状数组经典实现
return x & (-x);
}
int getmax(int x)
{
// 树状数组经典实现
int ma = 0;
for(int i = x; i; i -= lowbit(i))
ma = max(ma, bit[i]);
return ma;
}
void add(int x, int v)
{
// 树状数组经典实现
for(int i = x; i <= m; i += lowbit(i))
bit[i] = max(bit[i], v);
}
};

Share