权值线段树、权值树状数组:元素排名区间的权值(个数)和

  |  

摘要: 元素的计数,权值线段树、权值树状数组

【对数据分析、人工智能、金融科技、风控服务感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:潮汐朝夕
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


在文章 线段树和树状数组:单点修改、区间和查询 中,我们知道线段树和树状数组可以处理元素有修改时的区间和的查询。

在文章 线段树:区间修改(更改为定值)、区间最值查询树状数组:单点修改,区间最值查询 中,我们知道线段树和树状数组也可以处理区间最值查询的问题。

这种维护区间信息(例如区间和,区间最大/最小值)的线段树、树状数组,也叫区间线段树、区间树状数组。

对于元素计数的场景,我们经常想要知道给定的一组元素的计数的总数,或者计数最大的是哪个元素。如果元素之间可以比大小的话,那么就可以排序,之后就可以用类似于区间的形式来选一组元素考察其计数的和或计数的最值了。

权值线段树、权值树状数组就是处理这种情况的数据结构。权值线段树、权值树状数组维护元素值的计数,节点的位置代表元素值,节点的值代表元素的数量,其中叶子节点代表特定元素的数量,非叶子节点代表一个取值范围的元素的数量。

本文我们以 493. 翻转对 来看一下权值线段树和权值树状数组分别要怎么用。


$1 题目

493. 翻转对

给定一个数组 nums ,如果 $i < j$ 且 $nums[i] > 2 * nums[j]$ 我们就将 (i, j) 称作一个重要翻转对。

你需要返回给定数组中的重要翻转对的数量。

提示:

1
2
给定数组的长度不会超过50000。
输入数组中的所有数字都在32位整数的表示范围内。

示例 1:

输入: [1,3,2,3,1]
输出: 2
示例 2:

输入: [2,4,3,5,1]
输出: 3

$2 题解

本题有一个更常见的版本:面试题51. 数组中的逆序对,求数组中的逆序对个数(等价于冒泡排序的交换次数)。

此外 315. 计算右侧小于当前元素的个数327. 区间和的个数。都与本题一样,属于区间内的统计问题,并且都是有以下两种主流的做法,此外还有平衡树的做法:

  1. 离散化 + 权值线段树/权值树状数组。
  2. 基于时间的离线分治 (CDQ)。

线段树和树状数组内容参考 线段树和树状数组:单点修改、区间和查询

算法1: 离散化+权值线段树

维护区间信息的线段树,例如区间和,区间最大/最小值,也叫区间线段树。

权值线段树维护元素值的计数,节点的位置代表元素值,节点的值代表元素的数量,其中叶子节点代表特定元素的数量,非叶子节点代表一个取值范围的元素的数量。

每来一个新元素,递归地相应的叶子节点把节点值 $+1$,然后在回溯阶段把当前节点的值也 $+1$,从根到叶子的路径上的节点值就都被 $+1$ 了。

离散化

离散化的核心思想:将分布大却数量少(即稀疏)的数据进行集中化的处理,减少空间复杂度。

在统计元素的计数,以及求逆序时的过程中,不关心元素的实际值,只关心元素的大小关系。这样可以用排名代替原数组。流程:排序,去重,取原始数据的排名。

原始数组为 nums, ;排序去重后的数组为 x。则对原始数组中的数据 nums[i],它的从 0 开始计的排名是

1
upper_bound(x.begin(), x.end(), nums[i]) - x.begin();

例子:nums = [-7, 0, 4, 1e3+7, 1e7+7, 4, -1e5] 的离散化过程。离散化之后,大大缩小了数据(这里的权值)范围,并且可以用作数组下标。

关于离散化,还可以参考这篇文章:离散化

权值线段树的逻辑与区间线段树相同,只是更新时是对 index 的值(权值的计数) +1。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
struct STNode
{
int start, end;
int cnt;
STNode *left, *right;
STNode(int s, int e, int c, STNode* l=nullptr, STNode* r=nullptr)
:start(s),end(e),cnt(c),left(l),right(r){}
~STNode(){}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete_sub_tree(root);
}
}

void delete_sub_tree(STNode* node)
{
if(node -> left)
delete_sub_tree(node -> left);
if(node -> right)
delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

void build(int start, int end)
{
if(start <= end)
root = _build(start, end);
}

void add(int index)
{
_add(root, index);
}

int query(int i, int j)
{
if(i > j) return 0;
return _query(root, i, j);
}


private:
STNode *root;

STNode* _build(int start, int end)
{
if(start == end)
return new STNode(start, end, 0);
int mid = start + (end - start) / 2;
STNode *left = _build(start, mid);
STNode *right = _build(mid + 1, end);
return new STNode(start, end, 0, left, right);
}

void _add(STNode* root,int index)
{
if(!root) return;
if(root -> start == root -> end && root -> start == index)
{
++(root -> cnt);
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(index > mid)
_add(root -> right, index);
else
_add(root -> left, index);
++(root -> cnt);
}

int _query(STNode* root, int i, int j)
{
if(!root) return 0;
if(root -> start == i && root -> end == j)
return root -> cnt;
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
return _query(root -> left, i, j);
else if(i > mid)
return _query(root -> right, i, j);
else
return _query(root -> left, i, mid) + _query(root -> right, mid + 1, j);
}
};

class Solution {
public:
int reversePairs(vector<int>& nums) {
int n = nums.size();
if(n <= 1) return 0;
vector<ll> x; // 离散化后的值
for(int i: nums)
{
x.push_back(i);
x.push_back(2ll * i);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int m = x.size();
SegmentTree segmenttree;
segmenttree.build(0, m - 1);
segmenttree.add(_find(nums[0], x));
int result = 0;
for(int i = 1; i < n; ++i)
{
result += segmenttree.query(_find(2ll * nums[i], x) + 1, m - 1);
segmenttree.add(_find(nums[i], x));
}
return result;
}

private:
using ll = long long;

int _find(ll v, const vector<ll>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

算法2: 离散化+权值树状数组

权值树状数组的思想与权值线段树一样,数组下标是离散化后的元素值,数组的值是特定范围内元素的计数。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class BIT {
public:
BIT():cnts(1, 0){}
BIT(int n):cnts(n + 1, 0){}

void update(int index, int delta)
{
int n = cnts.size();
while(index < n)
{
cnts[index] += delta;
index += _lowbit(index);
}
}

int query(int i)
{
int cnt = 0;
while(i > 0)
{
cnt += cnts[i];
i -= _lowbit(i);
}
return cnt;
}

private:
vector<int> cnts;

int _lowbit(int x)
{
return x & (-x);
}
};

class Solution {
public:
int reversePairs(vector<int>& nums) {
if(nums.empty()) return 0;
int n = nums.size();

vector<ll> x; // 离散化后的值
for(int i: nums)
{
x.push_back(i);
x.push_back(2ll * i);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int m = x.size();
BIT bit(m); // bit 内部对 m 做了 + 1
bit.update(_find(nums[0], x), 1);
int ans = 0;
for(int i = 1; i < n; ++i)
{
ans += i - bit.query(_find(2ll * nums[i], x));
bit.update(_find(nums[i], x), 1);
}
return ans;
}

private:
using ll = long long;

int _find(ll v, const vector<ll>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin() + 1;
}
};

算法3: 基于时间的离线分治 (CDQ)

参考 离线分治:基于时间 (CDQ分治)

题目2

315. 计算右侧小于当前元素的个数

给你一个整数数组 nums ,按要求返回一个新数组 counts 。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

提示:

1
2
1 <= nums.length <= 1e5
-1e4 <= nums[i] <= 1e4

示例 1:
输入:nums = [5,2,6,1]
输出:[2,1,1,0]
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素

示例 2:
输入:nums = [-1]
输出:[0]

示例 3:
输入:nums = [-1,-1]
输出:[0,0]

算法1:离散化+权值线段树

493. 翻转对 思路一样。

代码(c++)

线段树的代码模板参考 线段树和树状数组:单点修改、区间和查询。根据具体问题有少量微调。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
struct STNode
{
int start, end;
int sum;
STNode *left, *right;
STNode(int s, int e, int v, STNode* l=nullptr, STNode* r=nullptr)
:start(s), end(e), sum(v), left(l), right(r) {}
~STNode(){}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete_sub_tree(root);
}
}

void build(int start, int end)
{
if(start <= end)
root = _build(start, end);
}

void add(int index)
{
_add(root, index);
}

int query(int i, int j)
{
if(i > j)
return 0;
return _query(root, i, j);
}

private:
STNode *root;

void delete_sub_tree(STNode* node)
{
if(node -> left)
delete_sub_tree(node -> left);
if(node -> right)
delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

int _query(STNode* root, int i, int j)
{
if(!root) return 0;
if(root -> start == i && root -> end == j)
return root -> sum;
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
return _query(root -> left, i, j);
if(mid < i)
return _query(root -> right, i, j);
return _query(root -> left, i, mid) + _query(root -> right, mid + 1, j);
}

void _add(STNode* root, int index)
{
if(!root) return;
if(root -> start == index && root -> end == index)
{
++(root -> sum);
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(mid >= index)
_add(root -> left, index);
else
_add(root -> right, index);
++(root -> sum);
}

STNode* _build(int start, int end)
{
if(start == end)
return new STNode(start, end, 0);
int mid = start + (end - start) / 2;
STNode *left = _build(start, mid);
STNode *right = _build(mid + 1, end);
return new STNode(start, end, 0, left, right);
}
};

class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return vector<int>();
int n = nums.size();
vector<int> x(nums.begin(), nums.end()); // 离散化后的值
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());
int m = x.size();

SegmentTree segmenttree;
segmenttree.build(0, m - 1);
vector<int> result(n);
result[n - 1] = 0;
segmenttree.add(_find(nums[n - 1], x));
for(int i = n - 2; i >= 0; --i)
{
result[i] = segmenttree.query(0, _find(nums[i], x) - 1);
cout << _find(nums[i], x) - 1 << endl;
segmenttree.add(_find(nums[i], x));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

算法2:平衡树

参考 手撕平衡树-Treap手撕平衡树-大小平衡树SBT

算法3:索引数组+CDQ分治

参考 线性索引表


Share