手撕平衡树-大小平衡树SBT

  |  

摘要: 大小平衡树的设计与实现

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


我们知道二叉查找树在多次插入删除之后会变得不平衡,使得树的深度逐渐大于 $O(\log N)$,使得查询的效率变低。如果能引入一些机制使得树在插入删除的过程中保持平衡,则查询的效率就可以保证了,这种在二叉查找树的基础上加上平衡机制的树,称为平衡树。

不同的平衡机制,可以形成不同的平衡树,大小平衡树就是其中一种。大小平衡树中的大小的含义是节点个数,在插入删除的过程中,保持左子树的节点个数和右子树的节点个数相等或相差 1,以这种方式保持平衡。本文我们就详细拆解大小平衡树的设计与实现,并解决几个力扣上的题目。主要内容如下:

  • 大小平衡树SBT(Size Balanced Tree) 的背景和定义
    • 节点定义
    • 保持 SBT 的平衡性需要节点中 size 字段满足的性质
  • SBT 的实现
    • maintain 操作及其优化,用到 BST 节点的旋转
    • insert 和 remove 操作相比于带 size 字段的 BST 中的修改
  • SBT 的完整代码
  • 最适合用 SBT 的题 面试题 10.10. 数字流的秩
  • 用 SBT 解 315. 计算右侧小于当前元素的个数
  • 带 size 字段的 BST 以及用带 size 字段的 BST 解力扣315, BST 节点的旋转,参考 手撕平衡树-二叉查找树BST
  • SBT 原始论文见置底

$1 SBT 的定义

大小平衡树(Size Balanced Tree),是一棵通过大小(Size)字段来维持平衡的二叉搜索树。

SBT不仅支持典型的二叉搜索树操作,而且也支持以下两个操作:

  • Select(k) 返回第 k 大元素对应的指针,取最大最小当然也可以做
  • Rank(x) 比 x 小的元素个数

以下内容涉及到带 size 字段的 BST,以及BST 节点的旋转,这两部分内容参考 手撕平衡树-二叉查找树BST

节点定义

如下,与带 size 字段的 BST 的节点定义完全一样。

1
2
3
4
5
6
7
8
9
struct SBTNode
{
int data;
SBTNode *left;
SBTNode *right;
int size;
SBTNode():left(nullptr),right(nullptr){}
SBTNode(const int& x, SBTNode* p=nullptr, SBTNode* q=nullptr, int s=1):data(x),left(p),right(q),size(s){}
};

SBT 对 size 的约束

对任意节点 t,必须满足以下两条性质

  • 性质1:
1
2
t -> right -> size >= t -> left -> left -> size;
t -> right -> size >= t -> left -> right -> size;
  • 性质2:
1
2
t -> left -> size >= t -> right -> right -> size;
t -> left -> size >= t -> right -> left -> size;

基本图

如图: 结点 L 和 R 分别是结点 T 的左右儿子。子树 A、B、C 和 D 分别是结点 L 和 R 各自的左右子树。

需要 L 的 size 既大于等于 C 的 size,又大于等于 D 的 size;R 的 size 既大于等于 A 的 size,又大于等于 B 的 size

一颗带 size 字段的 BST,如果其所有节点的 size 字段满足以上两条性质,则该树为 SBT。

$2 SBT 的实现

SBT 的接口定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class SizeBalancedTree
{
public:
SizeBalancedTree():root(nullptr){}
~SizeBalancedTree()
{
if(root)
_delete_sub_tree(root);
}

bool find(const int& x) const
{
return find(x, root);
}

void insert(const int& x)
{
insert(x, root);
}

void remove(const int& x)
{
remove(x, root);
}

int lessthan(const int& x) const
{
return lessthan(x, root);
}

private:
SBTNode *root;

void _delete_sub_tree(SBTNode* node)
{
if(node -> left)
_delete_sub_tree(node -> left);
if(node -> right)
_delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

int lessthan(const int& x, SBTNode* t) const;
void insert(const int& x, SBTNode*& t);
void remove(const int& x, SBTNode*& t);
SBTNode* find(const int& x, SBTNode* t) const;
void maintain(SBTNode*& t);
void maintain(SBTNode*& t, bool flag);

void left_rotate(SBTNode*& t)
{
// 调用方保证 t -> right 存在
SBTNode *tmp = t -> right;
t -> right = tmp -> left;
tmp -> left = t;
tmp -> size = t -> size;
t -> size = 1;
if(t -> right)
t -> size += t -> right -> size;
if(t -> left)
t -> size += t -> left -> size;
t = tmp;
}

void right_rotate(SBTNode*& t)
{
// 调用方保证 t -> left 存在
SBTNode *tmp = t -> left;
t -> left = tmp -> right;
tmp -> right = t;
tmp -> size = t -> size;
t -> size = 1;
if(t -> right)
t -> size += t -> right -> size;
if(t -> left)
t -> size += t -> left -> size;
t = tmp;
}
};

与带 size 字段的 BST 相比,findlessthan 不变,insertremove 中增加一步 maintain 操作(后续经过画图分析,remove 中的 maintain 可以省略掉),

maintain 是主要增加的部分,当一个 SBT 插入或删除一个元素后,某些节点的 size 可能不满足两条性质了,此时 maintain(t) 的作用就是使得 t 的 size 重新满足两条性质。

left_rotateright_rotate 是 BST 节点的旋转,在 maintain 中要用到,但他们是 BST 的通用操作,不专属于 SBT。

maintain 操作

在执行完简单的插入之后,性质1 或性质2 可能就不满足了,于是我们需要调整 SBT。

一个节点 T,在 insert(x, T) 完成后,maintain(T) 用于调整以 T 为根的 SBT。这里 T 的子树已经是 SBT,如图。

基本图

由于性质1 和性质2 是对称的, 所以仅详细的讨论性质1。

对应 T 节点的性质 1 被破坏,有两种情况:

  • 情况1:t -> left -> left -> size > t -> right -> size

即基本图中 A 的 size 大于 R 的 size。

此时先对 T 执行一次右旋,变为下图

基本图

之后,这棵树还不一定是 SBT,因为 B 的性质不一定满足,即 C 和 D 的 size 与 B 的 size 的大小关系是不确定的,因此对上图中的 T 再做一次 maintain

之后,这可以依然还不一定是 SBT,因为 A 的性质还不一定满足,因此对上图中的 L 再做一次 maintain

在实际代码中,注意指针为空的判断。

1
2
3
4
5
6
7
if((t -> left && t -> left -> left) && (!t -> right || t -> left -> left -> size > t -> right -> size))
{
right_rotate(t);
maintain(t -> right);
maintain(t);
return;
}
  • 情况2:t -> left -> right -> size > t -> right -> size

即基本图图中 B 的 size 大于 R 的 size。

情况2 比情况1 稍微复杂一些,可以先把基本图再细化以下,形成下图:

除了 E、B、F 以外,其他结点都和基本图中的定义一样。E、F 是结点 B 的子树。

首先对 L 执行一次左旋,形成下图:

然后对 T 执行一次右旋,形成下图:

以上两步旋转之后,这棵以 B 为根的树的结构好像不太可控,但是根据各个字母的原始定义,子树 A、E、F 和 R 仍就是 SBT, 只是 B, L, T 仍然不确定是 BST。

所以我们可以调用 Maintain(L)Maintain(T) 来修复结点 B 的子树。

现在 L, T 子树都已经是 SBT 了, 但是在结点 B 上还可能不满足性质 1 或性质 2, 因此需要再一次调用 Maintain(B)

在实际代码中,还是要主要指针为空的判断。

1
2
3
4
5
6
7
8
9
if((t -> left && t -> left -> right) && (!t -> right || t -> left -> right -> size > t -> right -> size))
{
left_rotate(t -> left);
right_rotate(t);
maintain(t -> left);
maintain(t -> right);
maintain(t);
return;
}

对于性质2被破坏,也有两种情况,分析过程与性质1一样,可以类似地画图分析。

  • maintain 的完整代码如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
void SizeBalancedTree::maintain(SBTNode*& t)
{
if((t -> left && t -> left -> left) && (!t -> right || t -> left -> left -> size > t -> right -> size))
{
right_rotate(t);
maintain(t -> right);
maintain(t);
return;
}
if((t -> left && t -> left -> right) && (!t -> right || t -> left -> right -> size > t -> right -> size))
{
left_rotate(t -> left);
right_rotate(t);
maintain(t -> left);
maintain(t -> right);
maintain(t);
return;
}
if((t -> right && t -> right -> right) && (!t -> left || t -> right -> right -> size > t -> left -> size))
{
left_rotate(t);
maintain(t -> left);
maintain(t);
return;
}
if((t -> right && t -> right -> left) && (!t -> left || t -> right -> left -> size > t -> left -> size))
{
right_rotate(t -> right);
left_rotate(t);
maintain(t -> left);
maintain(t -> right);
maintain(t);
}
}

maintain 的优化

通常我们可以保证性质 a 和性质 b 的满足,因此我们只需要检查情况 1 和情况 2 或者情况 3 和情况 4,这样可以提高速度。所以在那种情况下,我们需要增加一个布尔(boolean)型变量,flag,来避免毫无疑义的判断。

如果 flag 是 false,那么检查情况 1 和情况 2;否则检查情况 3 和情况 4。

代码如下,maintain(t, flag) 与原始 maintain(*t) 区别不大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
void SizeBalancedTree::maintain(SBTNode*& t, bool flag)
{
if(!flag)
{
if((t -> left && t -> left -> left) && (!t -> right || t -> left -> left -> size > t -> right -> size))
{
right_rotate(t);
}
else
{
if((t -> left && t -> left -> right) && (!t -> right || t -> left -> right -> size > t -> right -> size))
{
left_rotate(t -> left);
right_rotate(t);
}
else
return;
}
}
else
{
if((t -> right && t -> right -> right) && (!t -> left || t -> right -> right -> size > t -> left -> size))
{
left_rotate(t);
}
else
{
if((t -> right && t -> right -> left) && (!t -> left || t -> right -> left -> size > t -> left -> size))
{
right_rotate(t -> right);
left_rotate(t);
}
else
return;
}
}
maintain(t -> left, false);
maintain(t -> right, true);
maintain(t, false);
maintain(t, true);
}

唯一要注意的是 Maintain(left[t],true)Maintain(right[t],false) 被省略,这是有依据的,可以证明,不过并没有看懂,参考论文的分析章节。

插入

与带 size 字段的 BST 的插入唯一区别就是最后加了一个 maintain

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void SizeBalancedTree::insert(const int& x, SBTNode*& t)
{
if(t == nullptr)
{
t = new SBTNode(x, nullptr, nullptr, 1);
return;
}
++(t -> size);
if(t -> data > x)
insert(x, t -> left);
else
insert(x, t -> right);
// maintain(t);
maintain(t, x >= t -> data);
}

删除

删除与 手撕平衡树-二叉查找树BST 中提到的带 size 字段的 BST 的删除用的方法不同。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
void SizeBalancedTree::remove(const int& x)
{
int v = remove(x, root);
if(x != v)
insert(v, root);
}

int SizeBalancedTree::remove(const int& x, SBTNode*& t)
{
--(t -> size);
int deleted;
if(x == t -> data || (x < t -> data && !t -> left) || (x > t -> data && !t -> right))
{
deleted = t -> data;
if(!t -> left || !t -> right)
{
if(!t -> left && !t -> right)
{
delete t;
t = nullptr;
}
else
{
SBTNode *tmp = t;
if(tmp -> left)
{
// tmp -> left = nullptr; 必须加上
// 否则 tmp 被 delete 时,tmp -> left 也一并被 delete
t = tmp -> left;
tmp -> left = nullptr;
}
else
{
t = tmp -> right;
tmp -> right = nullptr;
}
delete tmp;
tmp = nullptr;
}
}
else
{
// t -> left 一定均比 deleted + 1 小
t -> data = remove(deleted + 1, t -> left);
}
}
else
{
if(x < t -> data)
deleted = remove(x, t -> left);
else
deleted = remove(x, t -> right);
}
// maintain(t) 可以去掉,需要画图分析
// if(t)
// maintain(t);
return deleted;
}

主要区别在树中没有要删除的元素 x 时的处理上,之前是如果没有要删的元素,就直接返回。

这里的处理如下:

(1) 如果在 SBT 中没有找到被删除元素 x, 我们就删除搜索到的最后一个结点 t 并做后处理(代码中第22行的else判断成功就是这种情况)。

  • t 是叶子节点,则直接删掉(代码中第17行判断成功是这种情况)
  • t 不是叶子节点,则肯定只有一颗子树不为空(因为若两个子树都不为空,对应的是代码中第47行的else判断成功的情况),则删除并做后处理(代码中第22行的else判断成功就是这种情况)。

(2) 如果在 SBT 中找到被删的元素 x, 对应的节点为 t

  • t 是叶子节点,则直接删掉(代码中第17行判断成功是这种情况)
  • t 不是叶子节点,且有一颗子树不为空,则删除并做后处理(代码中第22行的else判断成功就是这种情况)。
  • t 不是叶子节点,且两个子树都不为空(代码中第41行的else判断成功是这种情况),则找到 t 的前驱,将其值交给 t,然后删除 t 的前驱。实现方式: t -> data = remove(deleted + 1, t -> left);,这不调用就把前驱节点删了,并将前驱节点的值返回,因为 t -> left 上的值都比 deleted + 1 小,因此找不到待删除元素,且最后一个访问的节点就是小于 delete + 1 的最大值,即前驱节点的值。

v = remove(x, t) 返回的是被删除的节点的值,如果 v != x 则说明 v 是由于树中没有要删的元素 x 而被删的最后一个访问到的元素,此时将 v 再插入回去(代码第4行判断成功即是这种情况)

在删除 remove(x, t) 之后,如果 t 不为空,还有一步 maintain(t),但是这一步可以省略,可以通过画图分析得到。

$3 SBT 完整代码 (模板)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
struct SBTNode
{
int data;
SBTNode *left;
SBTNode *right;
int size;
SBTNode():left(nullptr),right(nullptr){}
SBTNode(const int& x, SBTNode* p=nullptr, SBTNode* q=nullptr, int s=1):data(x),left(p),right(q),size(s){}
~SBTNode(){}
};

class SizeBalancedTree
{
public:
SizeBalancedTree():root(nullptr){}
~SizeBalancedTree()
{
if(root)
_delete_sub_tree(root);
}

bool find(const int& x) const
{
return find(x, root);
}

void insert(const int& x)
{
insert(x, root);
}

void remove(const int& x)
{
int v = remove(x, root);
if(x != v)
insert(v, root);
}

int greaterthan(const int& x) const
{
return greaterthan(x, root);
}

int lessthan(const int& x) const
{
return lessthan(x, root);
}

private:
SBTNode *root;

void _delete_sub_tree(SBTNode* node)
{
if(node -> left)
_delete_sub_tree(node -> left);
if(node -> right)
_delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

int lessthan(const int& x, SBTNode* t) const;
int greaterthan(const int& x, SBTNode* t) const;
void insert(const int& x, SBTNode*& t);
int remove(const int& x, SBTNode*& t);
SBTNode* find(const int& x, SBTNode* t) const;
void maintain(SBTNode*& t);
void maintain(SBTNode*& t, bool flag);

void left_rotate(SBTNode*& t)
{
// 调用方保证 t -> right 存在
SBTNode *tmp = t -> right;
t -> right = tmp -> left;
tmp -> left = t;
tmp -> size = t -> size;
t -> size = 1;
if(t -> right)
t -> size += t -> right -> size;
if(t -> left)
t -> size += t -> left -> size;
t = tmp;
}

void right_rotate(SBTNode*& t)
{
// 调用方保证 t -> left 存在
SBTNode *tmp = t -> left;
t -> left = tmp -> right;
tmp -> right = t;
tmp -> size = t -> size;
t -> size = 1;
if(t -> right)
t -> size += t -> right -> size;
if(t -> left)
t -> size += t -> left -> size;
t = tmp;
}
};

void SizeBalancedTree::maintain(SBTNode*& t)
{
if((t -> left && t -> left -> left) && (!t -> right || t -> left -> left -> size > t -> right -> size))
{
right_rotate(t);
maintain(t -> right);
maintain(t);
return;
}
if((t -> left && t -> left -> right) && (!t -> right || t -> left -> right -> size > t -> right -> size))
{
left_rotate(t -> left);
right_rotate(t);
maintain(t -> left);
maintain(t -> right);
maintain(t);
return;
}
if((t -> right && t -> right -> right) && (!t -> left || t -> right -> right -> size > t -> left -> size))
{
left_rotate(t);
maintain(t -> left);
maintain(t);
return;
}
if((t -> right && t -> right -> left) && (!t -> left || t -> right -> left -> size > t -> left -> size))
{
right_rotate(t -> right);
left_rotate(t);
maintain(t -> left);
maintain(t -> right);
maintain(t);
}
}

void SizeBalancedTree::maintain(SBTNode*& t, bool flag)
{
if(!flag)
{
if((t -> left && t -> left -> left) && (!t -> right || t -> left -> left -> size > t -> right -> size))
{
right_rotate(t);
}
else
{
if((t -> left && t -> left -> right) && (!t -> right || t -> left -> right -> size > t -> right -> size))
{
left_rotate(t -> left);
right_rotate(t);
}
else
return;
}
}
else
{
if((t -> right && t -> right -> right) && (!t -> left || t -> right -> right -> size > t -> left -> size))
{
left_rotate(t);
}
else
{
if((t -> right && t -> right -> left) && (!t -> left || t -> right -> left -> size > t -> left -> size))
{
right_rotate(t -> right);
left_rotate(t);
}
else
return;
}
}
maintain(t -> left, false);
maintain(t -> right, true);
maintain(t, false);
maintain(t, true);
}

int SizeBalancedTree::lessthan(const int& x, SBTNode* t) const
{
if(t -> data >= x)
{
if(!t -> left)
return 0;
return lessthan(x, t -> left);
}
// t -> data < x
int ans = 1;
if(t -> left)
ans += t -> left -> size;
if(t -> right)
ans += lessthan(x, t -> right);
return ans;
}

int SizeBalancedTree::greaterthan(const int& x, SBTNode* t) const
{
if(t -> data <= x)
{
if(!t -> right)
return 0;
return greaterthan(x, t -> right);
}
// t -> data > x
int ans = 1;
if(t -> right)
ans += t -> right -> size;
if(t -> left)
ans += greaterthan(x, t -> left);
return ans;
}

SBTNode* SizeBalancedTree::find(const int& x, SBTNode* t) const
{
if(t == nullptr) return nullptr;
else if(t -> data > x) return find(x, t -> left);
else if(t -> data < x) return find(x, t -> right);
else return t;
}

void SizeBalancedTree::insert(const int& x, SBTNode*& t)
{
if(t == nullptr)
{
t = new SBTNode(x, nullptr, nullptr, 1);
return;
}
++(t -> size);
if(t -> data > x)
insert(x, t -> left);
else
insert(x, t -> right);
// maintain(t);
maintain(t, x >= t -> data);
}

int SizeBalancedTree::remove(const int& x, SBTNode*& t)
{
--(t -> size);
int deleted;
if(x == t -> data || (x < t -> data && !t -> left) || (x > t -> data && !t -> right))
{
deleted = t -> data;
if(!t -> left || !t -> right)
{
if(!t -> left && !t -> right)
{
delete t;
t = nullptr;
}
else
{
SBTNode *tmp = t;
if(tmp -> left)
{
// tmp -> left = nullptr; 必须加上
// 否则 tmp 被 delete 时,tmp -> left 也一并被 delete
t = tmp -> left;
tmp -> left = nullptr;
}
else
{
t = tmp -> right;
tmp -> right = nullptr;
}
delete tmp;
tmp = nullptr;
}
}
else
{
// t -> left 一定均比 deleted + 1 小
t -> data = remove(deleted + 1, t -> left);
}
}
else
{
if(x < t -> data)
deleted = remove(x, t -> left);
else
deleted = remove(x, t -> right);
}
// maintain(t) 可以去掉,需要画图分析
// if(t)
// maintain(t);
return deleted;
}

$4 最适合用 SBT 的题: 面试题 10.10. 数字流的秩

lessthan 函数稍有修改。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
int SizeBalancedTree::lessthan(const int& x, SBTNode* t) const
{
if(t -> data > x)
{
if(!t -> left)
return 0;
return lessthan(x, t -> left);
}
// t -> data < x
int ans = 1;
if(t -> left)
ans += t -> left -> size;
if(t -> right)
ans += lessthan(x, t -> right);
return ans;
}

class StreamRank {
public:
StreamRank() {
sbt = SizeBalancedTree();
}

void track(int x) {
sbt.insert(x);
}

int getRankOfNumber(int x) {
return sbt.lessthan(x);
}

private:
SizeBalancedTree sbt;
};

$5 用 SBT 解 315. 计算右侧小于当前元素的个数

本题只用到了 insertlessthan

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution_5 {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return {};
int n = nums.size();
vector<int> result(n);
SizeBalancedTree sbt;
result[n - 1] = 0;
sbt.insert(nums[n - 1]);
for(int i = n - 2; i >= 0; --i)
{
result[i] = sbt.lessthan(nums[i]);
sbt.insert(nums[i]);
}
return result;
}
};

$6 一些分析

高度分析,maintain 的时间复杂度分析,删除的分析,优化的 maintainmaintain(left[t],true)maintain(right[t],false) 可以被省略的证明。

参考中文版论文中的第 6 小节。论文如下:


Share