手撕平衡树-数组模拟链表实现二叉查找树

  |  

摘要: 用数组模拟链表的方式实现二叉查找树

【对数据分析、人工智能、金融科技、风控服务感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:潮汐朝夕
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们用数组模拟链表的方式实现二叉查找树,并形成代码模板。

关于二叉查找树,在文章 手撕平衡树-二叉查找树BST 中,我们学习了二叉查找树的定义,基于链表的实现,给出了 leetcode 上关于 BST 的题目列表,并且引入了节点旋转操作。

关于数组模拟链表,在文章 用数组模拟双向循环链表 中,我们学习了用数组模拟的方式实现链表的方法,并在一道题上进行了实践,然后在文章 结合链表容易删除的特点使用逆向思维 中给出了另一个例题。在文章 二分图匹配-最大匹配 中,匈牙利算法的代码模板包括基于 vector 和基于数组模拟链表这两种。

为了代码模板的对比方便,我们在这里直接贴一下基于链表的 BST 代码模板。


$0 基于链表的 BST 代码模板

说明: 假设没有重复元素,没有 size 字段。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
struct BSTNode
{
int data;
BSTNode *left;
BSTNode *right;
BSTNode():left(nullptr),right(nullptr){}
BSTNode(const int& x, BSTNode* p=nullptr, BSTNode* q=nullptr):data(x),left(p),right(q){}
~BSTNode()
{
left = nullptr;
right = nullptr;
}
};

class BinarySearchTree
{
public:
BinarySearchTree():root(nullptr){}
~BinarySearchTree()
{
if(root)
_delete_sub_tree(root);
}

bool find(const int& x) const
{
return find(x, root);
}

void insert(const int& x)
{
insert(x, root);
}

void remove(const int& x)
{
remove(x, root);
}

private:
BSTNode *root;

void _delete_sub_tree(BSTNode* node)
{
if(node -> left)
_delete_sub_tree(node -> left);
if(node -> right)
_delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

void insert(const int& x, BSTNode*& t);
void remove(const int& x, BSTNode*& t);
BSTNode* find(const int& x, BSTNode* t) const;
};

BSTNode* BinarySearchTree::find(const int& x, BSTNode* t) const
{
if(t == nullptr) return nullptr;
else if(t -> data > x) return find(x, t -> left);
else if(t -> data < x) return find(x, t -> right);
else return t;
}

void BinarySearchTree::insert(const int& x, BSTNode*& t)
{
if(t == nullptr)
t = new BSTNode(x, nullptr, nullptr);
else if(t -> data > x)
insert(x, t -> left);
else
insert(x, t -> right);
}

void BinarySearchTree::remove(const int& x, BSTNode*& t)
{
if(t == nullptr) return;
if(x < t -> data)
remove(x, t -> left);
if(x > t -> data)
remove(x, t -> right);
if(x == t -> data)
{
// 找到被删节点 t
if(t -> left != nullptr && t -> right != nullptr)
{
// t 有两个子节点
BSTNode *successor = t -> right;
while(successor -> left != nullptr)
successor = successor -> left;
t -> data = successor -> data;
remove(t -> data, t -> right);
}
else
{
// t 只有 1 个子节点,或没有子节点
BSTNode *oldNode = t;
t = (t -> left != nullptr) ? t -> left : t -> right;
delete oldNode;
}
}
}

基于以上 BST 实现的左旋和右旋。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void left_rotate(BSTNode*& t)
{
// 调用方保证 t -> right 存在
BSTNode *tmp = t -> right;
t -> right = tmp -> left;
tmp -> left = t;
t = tmp;
}

void right_rotate(BSTNode*& t)
{
// 调用方保证 t -> left 存在
BSTNode *tmp = t -> left;
t -> left = tmp -> right;
tmp -> right = t;
t = tmp;
}

$1 BST 复习

在二叉树中,有两类非常重要的条件,分别是两类数据结构的基础性质。一个是堆性质,另一个就是 BST 性质。

给定一个二叉树,每个节点有一个数值,称为关键码,BST 性质是指

  1. 该节点的关键莫不小于它的左子树中任意节点的关键码
  2. 该节点的关键莫不大于它的右子树中任意节点的关键码

满足以上性质的二叉树就是二叉查找树(BST),中序遍历是一个单调递增的节点序列。

节点定义

节点编号从 1 开始,l, r 为 0 表示没有子节点。

1
2
3
4
5
struct BSTNode
{
int l, r;
int val;
};

树的建立

为了减少越界以及边界情况的特殊判断,在 BST 中额外插入一个关键码为正无穷和一个关键码为负无穷的节点,由这两个节点构成的 BST 是空的 BST。

在含有 size 字段的 BST 中,需要注意这两个额外的节点是否统计在内。详见后面的带 size 字段的 BST 的实现。

假设 BST 中不含关键码相同的节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
const int SIZE = 1e5;
BSTNode a[SIZE];

int tot, root;
const int INF = 1 << 30;

int New(int val)
{
// 开新节点
a[++tot].val = val;
return tot;
}

void init()
{
// 初始化空树
New(-INF);
New(INF);
root = 1;
a[1].r = 2;
}

树的清空

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void delete_sub_tree(int p)
{
if(a[p].l > 0)
delete_sub_tree(a[p].l);
if(a[p].r > 0)
delete_sub_tree(a[p].r);
a[p].l = 0;
a[p].r = 0;
}

void delete_tree()
{
delete_sub_tree(root);
tot = 0;
}

查找

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
bool find(int val)
{
return find(val, root) != 0;
}

int find(int val, int p)
{
if(p == 0)
return 0; // 没找到, 叶节点的 l, r 均为 0
if(val == a[p].val)
return p; // 找到了
else if(val < a[p].val)
return find(val, a[p].l);
else
return find(val, a[p].r);
}

插入

遇到 val 与 a[p].val 相同的情况,往 val 的右子树插入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void insert(int val)
{
insert(val, root);
}

void insert(int val, int& p)
{
if(p == 0)
p = New(val);
else if(a[p].val > val)
insert(val, a[p].l);
else
insert(val, a[p].r);
}

删除

首先找到关键码为 val 的节点,如果没有子节点或只有一个子节点,删除比较简单。

如果有两个子节点,则先找到其后继节点 successor。

让 successor 代替 p 的位置,删除 successor 节点,由于 successor 没有左子节点,属于比较简单的删除情况。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
void remove(int val)
{
remove(val, root);
}

void remove(int val, int& p)
{
if(p == 0)
return;
if(val < a[p].val)
remove(val, a[p].l);
else if(val > a[p].val)
remove(val, a[p].r);
else
{
if(a[p].l == 0)
{
// 没有左子树
p = a[p].r; // 右子树代替 p 的位置
}
else if(a[p].r == 0)
{
// 没有右子树
p = a[p].l; // 左子树代替 p 的位置
}
else
{
// 既有左子树又有右子树
// 求后继结点
int successor = a[p].r;
while(a[successor].l != 0)
successor = a[successor].l;
a[p].val = a[successor].val;
remove(a[p].val, a[p].r);
}
}
}

$2 带 size 字段的 BST

除了常规的 BST 功能以外,还需要支持以下查询。

问: BST 中小于于 x 的元素个数有多少,或者问 BST 中第 k 小的元素。

与 BST 的变化的点如下

  • 节点定义: 除了增加 size 字段,没有变化
  • 查找: 不变
  • 插入和删除: 需要动态维护 size
  • 新增的需求查询:小于 x 的元素有多少个

这里直接给出代码模板

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
struct BSTNode
{
int l, r;
int val;
int size;
};

const int SIZE = 2e5;
BSTNode a[SIZE];

int tot, root;
const int INF = 1 << 30;

int New(int val)
{
// 开新节点
a[++tot].val = val;
a[tot].size = 1;
return tot;
}

void init()
{
// 初始化空树
New(-INF);
New(INF);
root = 1;
a[1].r = 2;
}

void delete_tree()
{
delete_sub_tree(root);
tot = 0;
}

void delete_sub_tree(int p)
{
if(a[p].l > 0)
delete_sub_tree(a[p].l);
if(a[p].r > 0)
delete_sub_tree(a[p].r);
a[p].l = 0;
a[p].r = 0;
}

int find(int val, int p)
{
if(p == 0)
return 0; // 没找到, 叶节点的 l, r 均为 0
if(val == a[p].val)
return p; // 找到了
else if(val < a[p].val)
return find(val, a[p].l);
else
return find(val, a[p].r);
}

bool find(int val)
{
return find(val, root) != 0;
}

void insert(int val, int& p)
{
if(p == 0)
{
p = New(val);
return;
}
++a[p].size;
if(a[p].val > val)
insert(val, a[p].l);
else
insert(val, a[p].r);
}

void insert(int val)
{
insert(val, root);
}

void remove(int val, int& p)
{
if(p == 0)
return;
--a[p].size;
if(val < a[p].val)
remove(val, a[p].l);
else if(val > a[p].val)
remove(val, a[p].r);
else
{
if(a[p].l == 0)
{
// 没有左子树
p = a[p].r; // 右子树代替 p 的位置
}
else if(a[p].r == 0)
{
// 没有右子树
p = a[p].l; // 左子树代替 p 的位置
}
else
{
// 既有左子树又有右子树
// 求后继结点
int successor = a[p].r;
while(a[successor].l != 0)
successor = a[successor].l;
a[p].val = a[successor].val;
remove(a[p].val, a[p].r);
}
}
}

void remove(int val)
{
remove(val, root);
}

int lessthan(int x, int p)
{
if(a[p].val >= x)
{
if(a[p].l == 0)
return 0;
return lessthan(x, a[p].l);
}
// a[p].val < x
int ans = 1;
if(a[p].l > 0)
ans += a[a[p].l].size;
if(a[p].r > 0)
ans += lessthan(x, a[p].r);
return ans;
}

int lessthan(int x)
{
// -INF 的点需要去掉
return lessthan(x, root) - 1;
}

例子

以下是利用这颗带 size 字段的 BST 解力扣315 题,315. 计算右侧小于当前元素的个数,本题只用到了插入,和需求功能的查询。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return {};
int n = nums.size();
vector<int> result(n);
init();
result[n - 1] = 0;
insert(nums[n - 1]);
for(int i = n - 2; i >= 0; --i)
{
result[i] = lessthan(nums[i]);
insert(nums[i]);
}
delete_tree();
return result;
}
};

$3 左旋和右旋

  • 左旋 zag(p) 操作通过更改两个指针将左边两个结点的结构转变成右边的结构,
  • 右边的结构也可以通过相反的操作 zig(p) 来转变成左边的结构。

右旋 (zig)

1
2
3
4
5
6
7
8
void zig(int& p)
{
// 调用方保证 a[p].l 不为 0
int tmp = a[p].l;
a[p].l = a[tmp].r;
a[tmp].r = p;
p = tmp;
}

左旋 (zag)

1
2
3
4
5
6
7
8
void zag(int& p)
{
// 调用方保证 a[p].r 不为 0
int tmp = a[p].r;
a[p].r = a[tmp].l;
a[tmp].l = p;
p = tmp;
}

含 size 字段的旋转

左旋和右旋的改动均是在 p = tmp 之前加入以下两步对 size 的更新。

1
2
3
4
5
a[p].size = 1;
if(a[p].r > 0)
a[p].size += a[a[p].r].size;
if(a[p].l > 0)
a[p].size += a[a[p].l].size;
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
void zig(int &p)
{
// 调用方保证 a[p].l 不为 0
int tmp = a[p].l;
a[p].l = a[tmp].r;
a[tmp].r = p;
a[tmp].size = a[p].size;
a[p].size = 1;
if(a[p].r > 0)
a[p].size += a[a[p].r].size;
if(a[p].l > 0)
a[p].size += a[a[p].l].size;
p = tmp;
}

void zag(int &p)
{
// 调用方保证 a[p].r 不为 0
int tmp = a[p].r;
a[p].r = a[tmp].l;
a[tmp].l = p;
a[tmp].size = a[p].size;
a[p].size = 1;
if(a[p].r > 0)
a[p].size += a[a[p].r].size;
if(a[p].l > 0)
a[p].size += a[a[p].l].size;
p = tmp;
}

Share