手撕平衡树-Treap

  |  

摘要: Treap 的原理、代码模板、例题

【对数据分析、人工智能、金融科技、风控服务感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:潮汐朝夕
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


此前,我们已经写过不少关于树的文章,包括二叉树、二叉查找树、平衡树。

满足 BST 性质,且中序遍历序列相同的二叉查找树不是唯一的。这些 BST 维护的是相同的一组数值。

因此我们可以通过改变二叉树的形态,使得树上每个节点的左右子树大小达到平衡。

我们要解决两个问题。一个是改变树的形态如何实现,另一个是以什么策略改变树的形态。

第一个问题我们此前已经解决过,就是左旋和右旋,具体可以参考 手撕平衡树-二叉查找树BST手撕平衡树-数组模拟链表实现二叉查找树

对于第二个问题,思路就有很多了,以前实现过的大小平衡树是其中一种,参考 手撕平衡树-大小平衡树SBT

本文我们看另一种非常常见的改变树的形态的策略,Treap。并给出基于指针和基于数组模拟链表两种实现方式的代码模板。


Treap

随机数据下,普通 BST 就是平衡的。Treap 的思想就是利用随机来创造平衡条件。

因为旋转过程中必须满足 BST 性质,所以 Treap 就把随机作用在堆性质上。

Treap 通过适当的单旋转,在维持节点关键码满足 BST 性质的同时,还使得每个节点上随机生成的额外权值满足大根堆性质。关于堆,可以参考这篇文章 二叉堆

节点定义

节点的值为 val,随机权重为 w。

假定数据会有重复,我们用 cnt 表示值为 val 的节点的个数。

对于小于 val 的元素个数这种查询,我们需要一个 size 字段记录子树的节点个数。

1
2
3
4
5
6
7
8
9
struct TreapNode
{
int val;
int size;
int cnt;
int w;
TreapNode *left, *right;
TreapNode():size(1),cnt(1),w(rand()),left(nullptr),right(nullptr){}
};

以上定义的各个字段,val 和 w 是常规必须的,cnt 和 size 是应付常见的需求而加的,这里我们实现的代码模板是基于后面介绍的模板题的,需要这两个字段。

根据情况,如果不需要的 cnt 和 size 的话,可以去掉,这样只需要在代码模板中修改对应的位置即可。

更新 size

用当前节点的 cnt 和左右子树的 size 更新当前子树的 size。

这是带 size 字段时需要的,在 insert、remove、zig、zag 操作中会用到。

1
2
3
4
5
6
7
8
void update_size(TreapNode* p)
{
p -> size = p -> cnt;
if(p -> left)
p -> size += p -> left -> size;
if(p -> right)
p -> size += p -> right -> size;
}

左旋和右旋

这里是含 size 字段的左旋和右旋。手撕平衡树-数组模拟链表实现二叉查找树 中有数组模拟链表时,含 size 字段的左旋和右旋代码。下面的模板题中也会给出数组模拟链表实现的代码模板。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void zig(TreapNode*& p)
{
// 右旋
TreapNode *tmp = p -> left;
p -> left = tmp -> right;
tmp -> right = p;
p = tmp;
update_size(p -> right);
update_size(p);
}

void zag(TreapNode*& p)
{
// 左旋
TreapNode *tmp = p -> right;
p -> right = tmp -> left;
tmp -> left = p;
p = tmp;
update_size(p -> left);
update_size(p);
}

插入

给该节点一个随机的权值,然后像二叉堆的插入过程一样,自底向上依次检查,当某个节点不满足大顶堆性质时,就执行单旋转,使得该点与父节点的关系发生对换。按照以下旋转策略,可以保持堆性质。

如果左子节点与当前节点不满足堆性质,则右旋。
如果右子节点与当前节点不满足堆性质,则左旋。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
void insert(int val)
{
insert(val, root);
}

void insert(int val, TreapNode*& p)
{
if(p == nullptr)
{
p = new TreapNode();
p -> val = val;
return;
}
if(p -> val == val)
++(p -> cnt);
else if(p -> val < val)
{
insert(val, p -> right);
if(p -> w < p -> right -> w)
zag(p);
}
else
{
insert(val, p -> left);
if(p -> w < p -> left -> w)
zig(p);
}
update_size(p);
}

删除

因为 Treap 支持旋转,我们可以直接找到需要删除的节点,并把它向下旋转成叶节点,最后直接删除。这样避免了采取普通 BST 删除方法时,导致的节点信息更新,堆性质维护等复杂问题。

注意 update_size 的调用位置。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
void remove(int val)
{
remove(val, root);
}

void remove(int val, TreapNode*& p)
{
if(p == nullptr)
return;
if(val == p -> val)
{
if(p -> cnt > 1)
{
--(p -> cnt);
update_size(p);
return;
}
if(p -> left || p -> right)
{
// 不是叶节点,向下旋转
if(!p -> right || (p -> left && p -> left -> w > p -> right -> w))
{
zig(p);
remove(val, p -> right);
}
else
{
zag(p);
remove(val, p -> left);
}
update_size(p);
}
else
{
// 叶节点
delete p;
p = nullptr;
}
return;
}
if(val < p -> val)
remove(val, p -> left);
else
remove(val, p -> right);
update_size(p);
}

查询

种类给出前驱和后继的查询。如果要精确查询的话,稍作修改即可。

前驱

val 的前驱,小于 val 的最大值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int get_precursor(int val)
{
// 小于 val 的最大的
TreapNode *ans = nullptr;
TreapNode *p = root;
while(p)
{
if(val == p -> val)
{
if(p -> left)
{
p = p -> left;
while(p -> right)
p = p -> right;
ans = p;
}
break;
}
if(p -> val < val && (!ans || p -> val > ans -> val))
ans = p;
if(val < p -> val)
p = p -> left;
else
p = p -> right;
}
if(!ans)
return -INF;
return ans -> val;
}

后继

val 的后继,大于 val 的最小值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int get_successor(int val)
{
// 大于 val 的最小的
TreapNode *ans = nullptr;
TreapNode *p = root;
while(p)
{
if(val == p -> val)
{
if(p -> right)
{
p = p -> right;
while(p -> left)
p = p -> left;
ans = p;
}
break;
}
if(p -> val > val && (!ans || p -> val < ans -> val))
ans = p;
if(val > p -> val)
p = p -> right;
else
p = p -> left;
}
if(!ans)
return INF;
return ans -> val;
}

小于 val 的元素个数

这是带 size 字段时的查询。注意节点中的 cnt 字段的处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
int lessthan(int val)
{
return lessthan(val, root);
}

int lessthan(int val, TreapNode* p)
{
if(p == nullptr)
return 0;
if(val == p -> val)
{
int ans = 0;
if(p -> left)
ans += p -> left -> size;
return ans;
}
if(val < p -> val)
return lessthan(val, p -> left);
// val > p -> val
int ans = p -> cnt;
ans += lessthan(val, p -> right);
if(p -> left)
ans += p -> left -> size;
return ans;
}

第 k 大元素

有了 size 字段后,第 k 大的查询也比较方便。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
int get_value(int rank)
{
return get_value(rank, root);
}

int get_value(int rank, TreapNode* p)
{
if(p == nullptr)
return INF;
if(p -> left && p -> left -> size >= rank)
return get_value(rank, p -> left);
int s = p -> cnt;
if(p -> left)
s += p -> left -> size;
if(s >= rank)
return p -> val;
// rank > s
return get_value(rank - s, p -> right);
}

模板题: 带 size 字段的 Treap

写一种数据结构,来维护一些数,其中需要提供以下操作:

插入数值 x。
删除数值 x(若有多个相同的数,应只删除一个)。
查询数值 x 的排名(若有多个相同的数,应输出最小的排名)。
查询排名为 x 的数值。
求数值 x 的前驱(前驱定义为小于 x 的最大的数)。
求数值 x 的后继(后继定义为大于 x 的最小的数)。

算法: Treap

数据中可能有相同的值,可以增加一个 cnt 字段,记录节点的副本数。初始为 1。当减为 0 时,删除该节点。

size 字段用于查询排名。插入和删除时,需要自底向上更新 size 信息。旋转时,也要同时修改 size 信息。

代码(C++) 使用指针的实现方式

节点定义以及各种操作前面都有介绍过,这里直接给出完整代码(可以作为模板使用)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#include <iostream>

using namespace std;

const int INF = 1e9;

struct TreapNode
{
int val;
int size;
int cnt;
int w;
TreapNode *left, *right;
TreapNode():size(1),cnt(1),w(rand()),left(nullptr),right(nullptr){}
};

class Treap
{
public:
Treap():root(nullptr){}
~Treap()
{
delete_sub_tree(root);
}

void insert(int val)
{
insert(val, root);
}

void remove(int val)
{
remove(val, root);
}

int lessthan(int val)
{
return lessthan(val, root);
}

int get_value(int rank)
{
return get_value(rank, root);
}

int get_successor(int val)
{
// 大于 val 的最小的
TreapNode *ans = nullptr;
TreapNode *p = root;
while(p)
{
if(val == p -> val)
{
if(p -> right)
{
p = p -> right;
while(p -> left)
p = p -> left;
ans = p;
}
break;
}
if(p -> val > val && (!ans || p -> val < ans -> val))
ans = p;
if(val > p -> val)
p = p -> right;
else
p = p -> left;
}
if(!ans)
return INF;
return ans -> val;
}

int get_precursor(int val)
{
// 小于 val 的最大的
TreapNode *ans = nullptr;
TreapNode *p = root;
while(p)
{
if(val == p -> val)
{
if(p -> left)
{
p = p -> left;
while(p -> right)
p = p -> right;
ans = p;
}
break;
}
if(p -> val < val && (!ans || p -> val > ans -> val))
ans = p;
if(val < p -> val)
p = p -> left;
else
p = p -> right;
}
if(!ans)
return -INF;
return ans -> val;
}

private:
TreapNode *root;

void update_size(TreapNode* p)
{
// 用当前节点的 cnt 和左右子树的 size 更新当前子树的 size
p -> size = p -> cnt;
if(p -> left)
p -> size += p -> left -> size;
if(p -> right)
p -> size += p -> right -> size;
}

void zig(TreapNode*& p)
{
// 右旋
TreapNode *tmp = p -> left;
p -> left = tmp -> right;
tmp -> right = p;
p = tmp;
update_size(p -> right);
update_size(p);
}

void zag(TreapNode*& p)
{
// 左旋
TreapNode *tmp = p -> right;
p -> right = tmp -> left;
tmp -> left = p;
p = tmp;
update_size(p -> left);
update_size(p);
}

void insert(int val, TreapNode*& p)
{
if(p == nullptr)
{
p = new TreapNode();
p -> val = val;
return;
}
if(p -> val == val)
++(p -> cnt);
else if(p -> val < val)
{
insert(val, p -> right);
if(p -> w < p -> right -> w)
zag(p);
}
else
{
insert(val, p -> left);
if(p -> w < p -> left -> w)
zig(p);
}
update_size(p);
}

void remove(int val, TreapNode*& p)
{
if(p == nullptr)
return;
if(val == p -> val)
{
if(p -> cnt > 1)
{
--(p -> cnt);
update_size(p);
return;
}
if(p -> left || p -> right)
{
// 不是叶节点,向下旋转
if(!p -> right || (p -> left && p -> left -> w > p -> right -> w))
{
zig(p);
remove(val, p -> right);
}
else
{
zag(p);
remove(val, p -> left);
}
update_size(p);
}
else
{
// 叶节点
delete p;
p = nullptr;
}
return;
}
if(val < p -> val)
remove(val, p -> left);
else
remove(val, p -> right);
update_size(p);
}

int lessthan(int val, TreapNode* p)
{
if(p == nullptr)
return 0;
if(val == p -> val)
{
int ans = 0;
if(p -> left)
ans += p -> left -> size;
return ans;
}
if(val < p -> val)
return lessthan(val, p -> left);
// val > p -> val
int ans = p -> cnt;
ans += lessthan(val, p -> right);
if(p -> left)
ans += p -> left -> size;
return ans;
}

int get_value(int rank, TreapNode* p)
{
if(p == nullptr)
return INF;
if(p -> left && p -> left -> size >= rank)
return get_value(rank, p -> left);
int s = p -> cnt;
if(p -> left)
s += p -> left -> size;
if(s >= rank)
return p -> val;
// rank > s
return get_value(rank - s, p -> right);
}

void delete_sub_tree(TreapNode* p)
{
if(p -> left)
delete_sub_tree(p -> left);
if(p -> right)
delete_sub_tree(p -> right);
delete p;
p = nullptr;
}
};

int main()
{
int n;
cin >> n;
srand((unsigned)time(0));
Treap treap = Treap();
for(int i = 0; i < n; ++i)
{
int opt, x;
cin >> opt >> x;
switch(opt)
{
case 1:
treap.insert(x);
break;
case 2:
treap.remove(x);
break;
case 3:
cout << treap.lessthan(x) + 1 << endl;
break;
case 4:
cout << treap.get_value(x) << endl;
break;
case 5:
cout << treap.get_precursor(x) << endl;
break;
case 6:
cout << treap.get_successor(x) << endl;
break;
}
}
}

代码(C++) 数组模拟链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#include <iostream>
#include <cstdlib>

using namespace std;

const int SIZE = 100005;

struct TreapNode
{
int l, r;
int val;
int w; // 权值
int cnt; // 副本数
int size; // 子树大小
};

TreapNode a[SIZE];

int tot, root, n, INF = 0x7fffffff;

int New(int val)
{
a[++tot].val = val;
a[tot].w = rand();
a[tot].cnt = 1;
a[tot].size = 1;
return tot;
}

void update_size(int p)
{
// 用当前节点的 cnt 和左右子树的 size 更新当前子树的 size
a[p].size = a[a[p].l].size + a[a[p].r].size + a[p].cnt;
}

void init()
{
New(-INF);
New(INF);
root = 1;
a[1].r = 2;
update_size(root);
}

void zig(int& p)
{
// 调用方保证 a[p].l 不为 0
int tmp = a[p].l;
a[p].l = a[tmp].r;
a[tmp].r = p;
p = tmp;
update_size(a[p].r);
update_size(p);
}

void zag(int& p)
{
// 调用方保证 a[p].r 不为 0
int tmp = a[p].r;
a[p].r = a[tmp].l;
a[tmp].l = p;
p = tmp;
update_size(a[p].l);
update_size(p);
}

void insert(int& p, int val)
{
if(p == 0)
{
p = New(val);
return;
}
if(val == a[p].val)
++a[p].cnt;
else if(val < a[p].val)
{
// 往左子树插入
insert(a[p].l, val);
if(a[p].w < a[a[p].l].w)
zig(p); // p 的左子节点违反堆性质 右旋
}
else
{
// 往右子树插入
insert(a[p].r, val);
if(a[p].w < a[a[p].r].w)
zag(p); // p 的右子节点违反堆性质 左旋
}
update_size(p);
}

void insert(int val)
{
insert(root, val);
}

void remove(int& p, int val)
{
if(p == 0)
return;
if(val == a[p].val)
{
if(a[p].cnt > 1)
{
--a[p].cnt;
update_size(p);
return;
}
if(a[p].l > 0 || a[p].r > 0)
{
// 不是叶子节点,向下旋转
if(a[p].r == 0 || a[a[p].l].w > a[a[p].r].w)
{
zig(p);
remove(a[p].r, val);
}
else
{
zag(p);
remove(a[p].l, val);
}
update_size(p);
}
else
{
// 叶子节点,删除
p = 0;
}
return;
}
if(val < a[p].val)
remove(a[p].l, val);
else
remove(a[p].r, val);
update_size(p);
}

void remove(int val)
{
remove(root, val);
}

int get_successor(int val)
{
// 后继: 大于 val 的最小的数
int ans = 2; // a[2].val = INF
int p = root;
while(p != 0)
{
if(val == a[p].val)
{
if(a[p].r > 0)
{
p = a[p].r;
while(a[p].l > 0)
p = a[p].l;
ans = p;
}
break;
}
if(a[p].val > val && a[p].val < a[ans].val)
ans = p;
if(a[p].val < val)
p = a[p].r;
else
p = a[p].l;
}
return a[ans].val;
}

int get_precursor(int val)
{
// 前驱: 小于 val 的最大的数
int ans = 1; // a[1].val = -INF
int p = root;
while(p != 0)
{
if(val == a[p].val)
{
if(a[p].l > 0)
{
p = a[p].l;
while(a[p].r > 0)
p = a[p].r;
ans = p;
}
break;
}
if(a[p].val < val && a[p].val > a[ans].val)
ans = p;
if(val < a[p].val)
p = a[p].l;
else
p = a[p].r;
}
return a[ans].val;
}

int lessthan(int p, int val)
{
if(p == 0)
return 0;
if(val == a[p].val)
return a[a[p].l].size;
if(val < a[p].val)
return lessthan(a[p].l, val);
return lessthan(a[p].r, val) + a[a[p].l].size + a[p].cnt;
}

int lessthan(int val)
{
// 小于 val 的个数, -INF 去掉
return lessthan(root, val) - 1;
}

int get_value(int p, int rank)
{
if(p == 0)
return INF;
if(a[a[p].l].size >= rank)
return get_value(a[p].l, rank);
if(a[a[p].l].size + a[p].cnt >= rank)
return a[p].val;
return get_value(a[p].r, rank - a[a[p].l].size - a[p].cnt);
}

int get_value(int rank)
{
// 根据 rank 返回值, 要把 -INF 占的名额去掉
return get_value(root, rank + 1);
}

void delete_sub_tree(int p)
{
if(a[p].l > 0)
delete_sub_tree(a[p].l);
if(a[p].r > 0)
delete_sub_tree(a[p].r);
a[p].l = 0;
a[p].r = 0;
}

void delete_tree()
{
delete_sub_tree(root);
tot = 0;
}

int main()
{
init();
int n;
cin >> n;
for(int i = 0; i < n; ++i)
{
int opt, x;
cin >> opt >> x;
switch(opt)
{
case 1:
insert(x);
break;
case 2:
remove(x);
break;
case 3:
cout << lessthan(x) + 1 << endl;
break;
case 4:
cout << get_value(x) << endl;
break;
case 5:
cout << get_precursor(x) << endl;
break;
case 6:
cout << get_successor(x) << endl;
break;
}
}
}

例子

以下是利用这棵带 size 字段的 Treap 解力扣315 题,315. 计算右侧小于当前元素的个数,本题只用到了插入,和 lessthen 的查询。

  • 指针的实现方式
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return {};
int n = nums.size();
vector<int> result(n);
Treap treap;
result[n - 1] = 0;
treap.insert(nums[n - 1]);
for(int i = n - 2; i >= 0; --i)
{
result[i] = treap.lessthan(nums[i]);
treap.insert(nums[i]);
}
return result;
}
};
  • 数组模拟链表的实现方式
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return {};
int n = nums.size();
vector<int> result(n);
init();
result[n - 1] = 0;
insert(nums[n - 1]);
for(int i = n - 2; i >= 0; --i)
{
result[i] = lessthan(nums[i]);
insert(nums[i]);
}
delete_tree();
return result;
}
};

Share