树状数组维护区间最值RMQ

  |  

考虑以下需求:
给定一个数组 a,支持区间和的查询, 其中 a 中的元素可变,每次变化 a 中的 1 个元素。
这个需求用带单点修改和区间查询的线段树或树状数组都可以解决,其中用树状数组的代码量少很多。

如果 a 不变,则是前缀和解决的问题。而对与 a 可变的情况,区间和查询是树状数组解决的最基本的问题,它在力扣上有对应的题目:

现在需求变一下,把支持区间和的查询改成支持区间最大值的查询,即 RMQ 问题。
如果 a 不变,则常见解法是用基于区间 DP 的稀疏表。如果 a 可变,有几种主流的方案: 线段树,树状数组,分块。

树状数组做带单点修改的区间最值的查询是可做的,只是比区间和的需求复杂一些。

2.jpg

回顾用树状数组做区间和的做法:

更新的接口:update(int idx, int delta), 数组元素 a[idx - 1] 的值加上 delta,然后更新树状数组 vec 的值(数组 a 的元素 idx - 1 对应树状数组 vec 的 idx), 更新 vec 中与 idx 相关的节点的代码

1
2
3
4
5
6
7
8
9
10
11
12
for(int i = idx; i < n; i += lowbit(i))
{
// n 是树状数组 vec 的长度
vec[i] += delta;
}

// 或者 while 循环
while(idx < n)
{
vec[idx] += delta;
idx += lowbit(idx);
}

查询的接口:query(idx), 返回前缀 [0..idx-1] 的和, 求区间 [l, r] 的和: query(r + 1) - query(l), 汇总与 idx 相关的节点值的代码

1
2
3
4
5
6
7
8
9
for(int i = idx; i > 0; i -= lowbit(i))
sum += vec[i];

// 或者 while 循环
while(i > 0)
{
sum += vec[i];
i -= lowbit(i);
}

树状数组做区间最值的修改和查询

修改的接口:update(int idx, int val), 数组元素 a[idx - 1] 的值改为 val。

改了 a[idx - 1] 之后,与区间和的情况相同,也是要修改所有与 vec[idx] 相关的节点的值。

树状数组的节点 vec[idx] 负责的是区间 [idx - lowbit(idx) + 1, idx],vec[idx] 修改之后,会对 idx 的所有父节点有影响,这点与区间和的情况一样。
在区间和中,通过 idx += lowbit(idx) 找到所有父节点之后,直接将节点值加上 delta 就可以了,但是区间最值不行。

以最大值为例,某次修改将 a[idx - 1] 变了,待修改节点是 vec[idx] 以及其父亲链上的节点,但是 a[idx - 1] 不一定会影响到这些节点值。对于任意一个待修改节点 i,
其所有子节点的值,以及 i 对应的数据 a[i - 1] 本身的值共同决定了最大值 vec[i]。因此要考察其所有子节点,记录最大值

1
2
3
// 当前待修改节点为 i, i - 2^k 是其各个子节点(2^k < lowbit(i))
for(int j = 1; j < lowbit(i); j <<= 1)
vec[i] = max(vec[i], vec[i - j])

因此完整的更新过程如下, 总时间复杂度 $O(\log^{2}N)$

1
2
3
for(int i = idx; i < n; i += lowbit(i))
for(int j = 1; j < lowbit(i); j <<= 1)
vec[i] = max(vec[i], vec[i - j]);

例如若要更新 vec[8], 需要查看 vec7,vec6, vec4

查询的接口: query(int l, int r), 查询时,直接对区间 [l, r] 统计答案,而不是分别求两次前缀的答案再左差,这一点与区间和的情况区别较大。

vec[r] 表示的是区间 [r - lowbit(r) + 1, r] 的最大值,记 i = r - lowbit(r) + 1

  • (1) 如果 i = l, 则可以返回 vec[r], 即为当前 [l, r] 的最大值。
  • (2) 如果 i < l, 则 vec[r] 记录的 [i, r] 的最值可能来自 [i..l-1] 这一部分(对应 a[i-1..l]),因此只能从原数据 a 中取出对应值 a[r - 1] 并统计答案(这里要用到原数组,而区间和的需求不需要), 然后 —r 之后继续统计。
  • (3) 如果 i > l, 则 vec[r] 记录的 [i, r] 的最值完整地影响待查询区间,因此直接统计答案即可,之后 r -= lowbit(r) 更新 r 到所负责区间左端点的前一个点,继续统计

完整过程如下:

1
2
3
4
5
6
7
8
9
int ans = a[r - 1];
while(true)
{
ans = max(ans, a[r - 1]); // (2)
if(l == r) break; // (1)
--r;
for(; l <= r - lowbit(r); r -= lowbit(r))
ans = max(ans, vec[r]) // (3)
}

以下代码更直观,但比较冗长

1
2
3
4
5
6
7
8
9
10
11
12
while(true)
{
for(; r - _lowbit(r) + 1 > l; r -= _lowbit(r)) // (3)
ans = max(ans, vec[r]);
if(r - _lowbit(r) + 1 == l) // (1)
{
ans = max(ans, vec[r]);
break;
}
ans = max(ans, a[r - 1]); // (2)
--r;
}

将以上分析过程综合起来,就得到用带单点修改的树状数组做区间最值查询的代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class BIT_RMQ
{
public:
BIT_RMQ():vec(1,0),a(1,0){}
BIT_RMQ(int n):vec(n + 1, 0),a(n,0){}

void update(int idx, int x)
{
// vec[idx] 管的是 [idx-lowbit[idx] + 1..idx] 这个区间
// a[idx - 1] 改为 x
// vec[idx]
a[idx - 1] = x;
vec[idx] = x;
int n = a.size();
for(int i = idx; i <= n; i += _lowbit(i))
{
for(int j = 1; j < _lowbit(i); j <<= 1)
{
// j < _lowbit(i) <= j - i < _lowbit(i) - i <= i - j > i - _lowbit(i)
// i = 8,即改 vec[8]
// 要看 vec[7] = i - 1
// vec[6] = i - 2
// vec[4] = i - 4
vec[i] = max(vec[i], vec[i - j]);
}
}
}

int query(int l, int r)
{
// 直接看 vec[r] 不行
// vec[r] 对应 [r - lowbit[r] + 1, r]
int ans = a[r - 1];
while(true)
{
ans = max(ans, a[r - 1]);
if(l == r)
break;
--r;
for(; r - _lowbit(r) >= l; r -= _lowbit(r))
ans = max(ans, vec[r]);
}
return ans;
}

void view()
{
int n = a.size();
for(int i = 0; i < n; ++i)
cout << a[i] << " ";
cout << endl;
for(int i = 1; i <= n; ++i)
cout << vec[i] << " ";
cout << endl;
}

private:
vector<int> vec;
vector<int> a;

int _lowbit(int x)
{
return x & (-x);
}
};

应用以及模板的测试

LIS 问题的树状数组优化。

LIS 的转移方程如下

1
2
dp[i] := [0..i] 上的 LIS 长度
dp[i] = 1 + max(dp[j]) 其中 j < i 且 s[j] < s[i]

朴素的做法,max 这一步要用一个循环来求。

但将这个过程抽象一下,相当于询问:在 [0..i-1] 上,值比 s[i] 小的这些位置,dp 的最大值是多少。

这里的区间包含的相当于是 [0..i-1] 上所有比 s[i] 小的值,数据就是对应的 dp 值,一个 s[j] 值对应一个 dp 值,如果有两个 j1, j2,它们的值 s[j1] 与 s[j2] 相等,则 dp 值只记最大的。
在数据和区间的含义按照以上描述来理解时,问区间的最大值是多少。

维护这些 s 值及其对应的 dp 值的数组称为权值数组,权值数组的下标是 s 的各个值,权值数组的值就是 s 的值对应的 dp 值。
当用树状数组维护权值数组的最大值时,称为时,称为权值树状数组权值树状数组

以上描述的原始数组 - 权值数组 - 权值树状数组的对应关系如图

上述算法在力扣上有对应题目:300. 最长上升子序列

下面的代码中 BIT_RMQ 部分完全照搬模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class BIT_RMQ;

class Solution_5 {
public:
// 树状数组优化
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
for(int i : nums)
x.push_back(i);
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end()); // 把实际值离散化
m = x.size();
BIT_RMQ rmq(m);
int ans = 0;
for(int i = 0; i < n; ++i)
{
if(find(nums[i]) == 1)
{
ans = max(ans, 1);
rmq.update(1, 1);
continue;
}
int dp = rmq.query(1, find(nums[i]) - 1) + 1;
ans = max(ans, dp);
rmq.update(find(nums[i]), dp);
}
return ans;
}

private:
int m;
vector<int> x; // 此数组用于求 nums 中的值离散化之后的值

int find(int v) // 从 nums 的值找到对应的离散化之后的值
{
return lower_bound(x.begin(), x.end(), v) - x.begin() + 1;
}
};

也可以不单独写成类 BIT_RMQ,而是直接操作权值数组和 bit 数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Solution_6 {
public:
// 树状数组优化
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
for(int i : nums)
x.push_back(i);
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end()); // 把实际值离散化
m = x.size();
dp.assign(m + 1, 0);
bit.assign(m + 1, 0); // 树状数组初始化
int ans = 0;
for(int i = 0; i < n; ++i)
{
int w_idx = find(nums[i]);
int len = 1;
if(w_idx > 0)
len += query(1, w_idx);
ans = max(ans, len);
update(w_idx + 1, len);
}
return ans;
}

private:
vector<int> x; // 所有能值排序后去重
vector<int> dp; // 权值数组
int m; // 权值数组的长度

vector<int> bit; // 权值数组的树状数组

void update(int idx, int val)
{
dp[idx - 1] = val;
bit[idx] = val;
for(int i = idx; i <= m; i += lowbit(i))
{
for(int j = 1; j < lowbit(i); j <<= 1)
{
bit[i] = max(bit[i], bit[i - j]);
}
}
}

int query(int l, int r)
{
int ans = dp[r - 1];
while(true)
{
for(; r - lowbit(r) + 1 > l; r -= lowbit(r))
ans = max(ans, bit[r]);
if(r - lowbit(r) + 1 == l)
{
ans = max(ans, bit[r]);
return ans;
}
ans = max(ans, dp[r - 1]);
--r;
}
}

int lowbit(int v)
{
return v & (-v);
}

int find(int v)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

本题有一定的特殊性:

  1. 所有待查询区间都是从 0 开始的
  2. 所有的更新都把值(权值数组的值也就是dp值)变大了
  3. 查询的是最大值

在这些更强的条件下,可以把树状数组的更新和查询写的更简洁(见代码中的 updategetmax)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Solution_3 {
public:
// 树状数组优化
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
for(int i : nums)
x.push_back(i);
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end()); // 把实际值离散化
m = x.size();
bit.assign(m + 1, 0); // 树状数组初始化
int ans = 0;
for(int i = 0; i < n; ++i)
{
int dp = getmax(find(nums[i]) - 1) + 1;
ans = max(ans, dp);
update(find(nums[i]), dp);
}
return ans;
}

private:
int m;
vector<int> x; // 此数组用于求 nums 中的值离散化之后的值

int find(int v) // 从 nums 的值找到对应的离散化之后的值
{
return lower_bound(x.begin(), x.end(), v) - x.begin() + 1;
}

vector<int> bit; // 树状数组
int lowbit(int x)
{
return x & (-x);
}

int getmax(int x)
{
int ma = 0;
for(int i = x; i; i -= lowbit(i))
ma = max(ma, bit[i]);
return ma;
}

void update(int x, int v)
{
for(int i = x; i <= m; i += lowbit(i))
bit[i] = max(bit[i], v);
}
};

Share