划分树:区间第k大

  |  

摘要: 划分树,区间第 k 大数

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们从 力扣315-索引数组,归并树 出发,引出一个维护区间第 k 大数的数据结构:划分树。

对于区间 $[i, j]$ 第 $k$ 大问题 query(i, j, k),有以下几种方法:

  1. 归并树的查询 query(i, j, x) 可以回答区间 [i, j] 中大于 x 的元素有多少个,因此可以二分地猜答案,然后用归并树的查询快的得到猜大了还是猜小了。
  2. 划分树将快速选择算法中,每一步分治的子问题结果保存下来。查询 query(i, j, k),时可以直接查找结果返回。
  3. 此外还有可持久化线段树基于值域的离线分治做法。

模板题

255. 第K小数

给定长度为 N 的整数序列 A,下标为 1 ∼ N。

现在要执行 M 次操作,其中第 i 次操作为给出三个整数 li,ri,ki,求 A[li],A[li+1],…,A[ri] 中第 ki 小的数是多少。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
输入格式
第一行包含两个整数 N 和 M。

第二行包含 N 个整数,表示整数序列 A。

接下来 M 行,每行包含三个整数 li,ri,ki,用以描述第 i 次操作。

输出格式
对于每次操作输出一个结果,表示在该次操作中,第 k 小的数的数值。每个结果占一行。

数据范围
N <= 1e5,
M <= 1e4,
|A[i]| <= 1e9

输入样例:
7 3
1 5 2 6 3 7 4
2 5 3
4 4 1
1 7 3
输出样例:
5
6
3

划分树

划分树和归并树都是用线段树作为辅助的,其中划分树的建树是模拟快排过程,自顶向下建树,归并树是模拟归并排序过程,自底向上建树,两种方法含有分治思想。

划分树的基本思想就是对于某个节点对应的区间 [start, end],把它划分成两个子区间,左边区间的数小于等于右边区间的数,左子区间对应划分后的 [start, mid],右子区间对应划分后的 [mid + 1, end],其中 mid = (left + right) / 2。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,直到区间长度变成1(start = end, 叶子节点),就找到了。

划分树的节点定义

1
2
3
4
5
start, end:节点对应的区间
nums[0..end-start]:节点持有的数据,一共 end - start + 1 个,是原始数组排好序之后在 [start, end] 范围的数字,但这里 nums 并没按排好序的顺序存放,而是按原数组的顺序放的。它具体含了哪些数字来自父节点划分的结果
toleft[0..end-start+2]:记录进入左子树的数字的个数
toleft[i] := [0 .. i - 1] 这 i 个数中,进入左子树的数字的个数,有前缀和的思想
toleft[0] = 0
1
2
3
4
5
6
7
8
9
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode(){}
};

划分树的建树

过程基本就是模拟快排过程,对于区间 [start, end] 取一个已经排过序的区间中位数,然后把小于中值的点放左边,大于的放右边,等于中位数的需要单独统计,使得进入左子树的数字个数为 [mid - left + 1],确保树是平衡的。划分树建树之前先求取并持有原数组排序后的数组 sorted,节点的中位数直接就取 sorted[mid]。

1
int median =  sorted[mid];

只要当前节点不是叶子节点(start = end),就先建好两个子节点:

1
2
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);

然后执行划分的算法流程,顺序枚举当前节点的所有数据,判断应该划分的方向,往对应的子节点塞进去。如果是塞进了左子树,toleft 加一。

[5, 2, 6, 1] 的建树过程

划分树的查询

query(i, j, k) 查询区间 [i, j] 中第 k 大的数

首先确定 [start ..i - 1], [start..j] 中去往左子树的数字个数 tli, tlj。其中 start = i 的情况需要特判。tlj - tli 是 [i..j] 中去往左子树的元素个数,记 tl。

1
2
3
4
5
int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;

看 k 和 tl 的关系决定子查询,类似于快速选择算法中考察 l 和 left + K 的关系。

1
2
3
tl >= k: 第 k 大在左子树,答案是 query(root -> left, new_i, new_j, k)
tl < k: 第 k 大在右子树,答案是 query(root -> right, new_i, new_j, k - tl)
new_i, new_j 需要画图推导

完整代码 (C++,模板)

query(i, j, k) 返回区间 $[i,j]$ 的第 K 大元素。(从小到大排序后的第 K 个)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode(){}
};

class PartitionTree
{
public:
PartitionTree()
{
root = nullptr;
sorted = vector<int>();
}

~PartitionTree()
{
if(root)
{
delete_sub_tree(root);
}
}

void delete_sub_tree(PTNode* node)
{
if(node -> left)
delete_sub_tree(node -> left);
if(node -> right)
delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

void build(int start, int end, const vector<int>& nums)
{
sorted = nums;
sort(sorted.begin(), sorted.end());
root = new PTNode(start, end);
root -> nums = nums;
_build(root);
}

int query(int i, int j, int k)
{
if(i > j) return 0;
return _query(root, i, j, k);
}

private:
PTNode *root;
vector<int> sorted;

int _query(PTNode* root, int i, int j, int k)
{
if(root -> start == root -> end) return (root -> nums)[0];

int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;
int new_i, new_j;
if(tl >= k)
{
// 第 k 大在左子
new_i = root -> start + tli;
new_j = new_i + tl - 1;
return _query(root -> left, new_i, new_j, k);
}
else
{
// 第 k 大在右子
int mid = root -> start + (root -> end - root -> start) / 2;
new_i = mid + 1 + i - (root -> start) - tli;
new_j = new_i + j - i - tl;
return _query(root -> right, new_i, new_j, k - tl);
}
}

void _build(PTNode* root)
{
if(root -> start == root -> end)
return;
int mid = root -> start + (root -> end - root -> start) / 2;
int median = sorted[mid];
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);
int n = (root -> nums).size();
int median_to_left = mid - root -> start + 1;
for(int i = 0; i < n; ++i)
{
if((root -> nums)[i] < median)
--median_to_left;
}
// 出循环后 median_to_left 为去往左子树中等于中位数的个数
int to_left = 0; // 去往左子树的个数
int idx_left = 0, idx_right = 0;
for(int i = 0; i < n; ++i)
{
int cur = (root -> nums)[i];
if(cur < median || ((cur == median) && median_to_left > 0))
{
(left -> nums)[idx_left] = cur;
++idx_left;
++to_left;
if(cur == median)
--median_to_left;
}
else
{
(right -> nums)[idx_right] = cur;
++idx_right;
}
(root -> toleft)[i + 1] = to_left;
}
_build(left);
_build(right);
root -> left = left;
root -> right = right;
}
};

代码测试

但是跑的效果很差,因为查询只有1次,需要查询次数较多时才能体现划分树的优势。

查询次数为 M 时,时间复杂度 $O(N\log N + M\log N)$,比 M 次快速选择的 $O(MN)$ 和归并树的 $O(N\log N + M\log^{3}N)$ 都好。

1
2
3
4
5
6
7
8
9
10
11
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1)
return nums[0];
PartitionTree *partitiontree = new PartitionTree();
partitiontree -> build(0, n - 1, nums);
return partitiontree -> query(0, n - 1, n + 1 - k);
}
};

模板题代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main()
{
int N, M;
cin >> N >> M;
vector<int> nums(N);
for(int i = 0; i < N; ++i)
cin >> nums[i];

PartitionTree *partitiontree = new PartitionTree();
partitiontree -> build(0, N - 1, nums);

for(int i = 0; i < M; ++i)
{
int l, r, k;
cin >> l >> r >> k;
l--;
r--;
int ans = partitiontree -> query(l, r, k);
cout << ans << endl;
}
}

Share