2D Segment Tree: Point Update and Submatrix Sum Query


Abstract: the 2D segment tree with point update and submatrix query, principles and implementation

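[If you are interested in algorithms, mathematics, or computer science, feel free to follow me and read more original articles]
My website: 潮汐朝夕的生活实验室
My WeChat official account: 算法题刷刷
My Zhihu: 潮汐朝夕
My GitHub: FennelDumplings
My LeetCode: FennelDumplings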


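This article uses 308. Range Sum Query 2D - Mutable as a template problem to walk through the principles and implementation of the 2D segment tree.

The segment tree in this article supports point update and submatrix query.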


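Template problem: 308. Range Sum Query 2D - Mutable

Given a 2D matrix matrix, handle multiple queries of the following types:

  1. Update the value of a cell in matrix.
  2. Calculate the sum of the elements of matrix inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).

Implement the NumMatrix class:

  • NumMatrix(int[][] matrix) initializes the object with the integer matrix matrix.
  • void update(int row, int col, int val) updates the value of matrix[row][col] to val.
  • int sumRegion(int row1, int col1, int row2, int col2) returns the sum of the elements of matrix inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).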

Constraints:

m == matrix.length
n == matrix[i].length
1 <= m, n <= 200
-1e5 <= matrix[i][j] <= 1e5
0 <= row < m
0 <= col < n
-1e5 <= val <= 1e5
0 <= row1 <= row2 < m
0 <= col1 <= col2 < n
At most 1e4 calls will be made to sumRegion and update.

Example 1:
Input
["NumMatrix", "sumRegion", "update", "sumRegion"]
[[[[3, 0, 1, 4, 2], [5, 6, 3, 2, 1], [1, 2, 0, 1, 5], [4, 1, 0, 1, 7], [1, 0, 3, 0, 5]]], [2, 1, 4, 3], [3, 2, 2], [2, 1, 4, 3]]
Output
[null, 8, null, 10]
Explanation
NumMatrix numMatrix = new NumMatrix([[3, 0, 1, 4, 2], [5, 6, 3, 2, 1], [1, 2, 0, 1, 5], [4, 1, 0, 1, 7], [1, 0, 3, 0, 5]]);
numMatrix.sumRegion(2, 1, 4, 3); // returns 8 (the sum of the red rectangle in the left figure)
numMatrix.update(3, 2, 2); // the matrix changes from the left figure to the right figure
numMatrix.sumRegion(2, 1, 4, 3); // returns 10 (the sum of the red rectangle in the right figure)
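
Before turning to the tree, it can help to have a naive reference to compare against. The sketch below is not part of the original article: NaiveNumMatrix is a hypothetical name, and it answers sumRegion by direct summation. That costs O(rows x cols) per query, but it makes the expected semantics of the example easy to check.

```cpp
#include <vector>

// Hypothetical baseline (an assumption for illustration, not the article's
// method): keeps the matrix as-is, so update is O(1) and sumRegion walks
// every cell of the requested rectangle.
class NaiveNumMatrix {
public:
    explicit NaiveNumMatrix(const std::vector<std::vector<int>>& matrix)
        : data(matrix) {}

    void update(int row, int col, int val) { data[row][col] = val; }

    // Sum over rows [row1, row2] and columns [col1, col2], both inclusive.
    int sumRegion(int row1, int col1, int row2, int col2) const {
        int total = 0;
        for (int r = row1; r <= row2; ++r)
            for (int c = col1; c <= col2; ++c)
                total += data[r][c];
        return total;
    }

private:
    std::vector<std::vector<int>> data;
};
```

A baseline like this is mostly useful as an oracle when testing the segment tree versions on random inputs.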

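2D Segment Tree: Quadtree Implementation

The basic idea of a quadtree is to recursively divide space into four parts: top-left, top-right, bottom-left, and bottom-right. A region of known extent is split into four equal subregions, and the process recurses until the tree reaches a given depth or some stopping condition is met.

For example, in the problems 427. Construct Quad Tree and 558. Quad Tree Intersection, splitting stops once all elements in the region represented by a node are equal.

A 2D segment tree implemented as a quadtree follows the same logic as a 1D segment tree.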
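
The recursive division can be sketched as a single splitting step. The Rect type and the split_quadrants function below are hypothetical names introduced for illustration. The sketch assumes closed index ranges and midpoints computed as (lo + hi) / 2, and it does not split a dimension that has already shrunk to one row or one column.

```cpp
#include <vector>

// A closed rectangle of matrix indices: rows [top, bottom], columns [left, right].
struct Rect {
    int top, bottom, left, right;
};

// Splits a rectangle at the two midpoints. A dimension of size one is not
// split again, so the result has 4 parts, 2 parts, or (for a single cell) none.
std::vector<Rect> split_quadrants(const Rect& r) {
    std::vector<Rect> parts;
    if (r.top == r.bottom && r.left == r.right) return parts;  // single cell
    int rowMid = (r.top + r.bottom) / 2;
    int colMid = (r.left + r.right) / 2;
    bool splitRows = r.top < r.bottom;
    bool splitCols = r.left < r.right;
    // When a dimension is not split, its midpoint equals its only index,
    // so the "first half" below is simply the whole range.
    parts.push_back({r.top, rowMid, r.left, colMid});                 // top-left
    if (splitCols)
        parts.push_back({r.top, rowMid, colMid + 1, r.right});        // top-right
    if (splitRows)
        parts.push_back({rowMid + 1, r.bottom, r.left, colMid});      // bottom-left
    if (splitRows && splitCols)
        parts.push_back({rowMid + 1, r.bottom, colMid + 1, r.right}); // bottom-right
    return parts;
}
```

For instance, a 5 x 5 rectangle yields four parts, while a 1 x 3 strip yields only a left and a right part.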

Node definition

// A quadtree node covering rows [top, bottom] and columns [left, right].
struct QTNode
{
    int top, bottom, left, right;
    int sum; // sum of all cells inside the node's rectangle
    QTNode *topleft, *topright, *bottomleft, *bottomright;
    QTNode(int t, int b, int l, int r, int sum,
           QTNode* tl=nullptr, QTNode* tr=nullptr, QTNode* bl=nullptr, QTNode* br=nullptr)
        :top(t),bottom(b),left(l),right(r),sum(sum),topleft(tl),topright(tr),bottomleft(bl),bottomright(br){}
    ~QTNode(){}
};

Interface

The interface follows the same logic as in the 1D case:

// 1D segment tree interface
build(arr)
point_update(i, v)
range_query(start, end)
// 2D segment tree interface
build(matrix)
point_update(i, j, v)
range_query(top, bottom, left, right)

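These are build, point update, and range query, respectively. The differences are:

  • Build works on a 1D interval in one case and on a 2D rectangle in the other;
  • Update identifies the position by i in one case and by (i, j) in the other;
  • Query describes the range by (start, end) in one case and by (top, bottom, left, right) in the other.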

Build

  • 1D build, build(arr): a leaf is detected by start == end.
  • 2D build, build(matrix): a leaf is detected by top == bottom && left == right.
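
To see what the 2D leaf condition implies for the shape of the tree, the following sketch walks the same recursion as a quadtree build without storing any nodes. BuildStats and build_stats are hypothetical names used only here. It counts the nodes and leaves a build would create and sums the cells along the way, assuming midpoints of (lo + hi) / 2.

```cpp
#include <utility>
#include <vector>

struct BuildStats {
    long long sum = 0;  // sum of all cells in the rectangle
    int nodes = 0;      // nodes a build would create for the rectangle
    int leaves = 0;     // nodes that cover exactly one cell
};

BuildStats build_stats(const std::vector<std::vector<int>>& matrix,
                       int top, int bottom, int left, int right) {
    BuildStats s;
    s.nodes = 1;
    if (top == bottom && left == right) {  // leaf: one row and one column
        s.sum = matrix[top][left];
        s.leaves = 1;
        return s;
    }
    int rowMid = (top + bottom) / 2;
    int colMid = (left + right) / 2;
    // Row and column halves. A dimension of size one contributes one half only,
    // which is why some internal nodes have two children instead of four.
    std::vector<std::pair<int, int>> rows, cols;
    rows.emplace_back(top, rowMid);
    if (top < bottom) rows.emplace_back(rowMid + 1, bottom);
    cols.emplace_back(left, colMid);
    if (left < right) cols.emplace_back(colMid + 1, right);
    for (const auto& r : rows) {
        for (const auto& c : cols) {
            BuildStats child = build_stats(matrix, r.first, r.second, c.first, c.second);
            s.sum += child.sum;
            s.nodes += child.nodes;
            s.leaves += child.leaves;
        }
    }
    return s;
}
```

On a 2 x 3 matrix this visits 9 nodes, 6 of which are leaves: the root, two 1 x 2 strips with two leaves each, and two single-cell children of the root.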

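Update

  • 1D update, point_update(i, v): the relation between i and nodeMid decides which subtree the point i descends into. The position i is reached when nodeLeft == nodeRight && nodeLeft == i.
  • 2D update, point_update(i, j, v): the relation between i and nodeTopBottomMid, together with that between j and nodeLeftRightMid, decides which subtree the point (i, j) descends into. A node's range can have converged to a single index in three ways; only the first is a leaf, while the other two are nodes that are split in one direction only:
    Both directions have converged to a single point: nodeTop == nodeBottom && nodeTop == i && nodeLeft == nodeRight && nodeLeft == j
    Only the vertical direction has converged: nodeTop == nodeBottom && nodeTop == i
    Only the horizontal direction has converged: nodeLeft == nodeRight && nodeLeft == j

The checks on i and j are independent, so the code has many if/else branches and is easy to get wrong.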
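
The choice of subtree for a point can be isolated into a small helper. The Quadrant enum and quadrant_of function are hypothetical names for this sketch. It assumes the same midpoint convention, (lo + hi) / 2, and relies on the fact that a dimension already reduced to one index always compares as "upper" or "left".

```cpp
enum class Quadrant { TopLeft, TopRight, BottomLeft, BottomRight };

// Returns the child of the node [nodeTop, nodeBottom] x [nodeLeft, nodeRight]
// that contains the cell (i, j). The caller is assumed to have checked that
// the node is not a leaf and that (i, j) lies inside the node.
Quadrant quadrant_of(int nodeTop, int nodeBottom, int nodeLeft, int nodeRight,
                     int i, int j) {
    int rowMid = (nodeTop + nodeBottom) / 2;
    int colMid = (nodeLeft + nodeRight) / 2;
    bool upper = i <= rowMid;     // always true when nodeTop == nodeBottom
    bool leftSide = j <= colMid;  // always true when nodeLeft == nodeRight
    if (upper) return leftSide ? Quadrant::TopLeft : Quadrant::TopRight;
    return leftSide ? Quadrant::BottomLeft : Quadrant::BottomRight;
}
```

Because the two comparisons are independent, the four outcomes cover the fully split node as well as the two half-converged cases without extra branches.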

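Query

1D query, range_query(start, end): the condition nodeLeft == start && nodeRight == end means that the queried interval coincides exactly with the interval represented by the node. Otherwise, the relation of start and end to nodeMid decides which subinterval to search:

end <= nodeMid: the queried interval [start, end] lies entirely in the left subtree
nodeMid < start: the queried interval [start, end] lies entirely in the right subtree
otherwise: each of the two subintervals contains part of the queried interval [start, end]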
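
For comparison with the 2D code, here is a minimal 1D sum segment tree written with the same conventions: heap-style indexing from node 1, a leaf test of nodeLeft == nodeRight, and a query that stops on an exact match and otherwise follows the three cases above. SegmentTree1D is a hypothetical class name; this is a sketch, not the article's own listing.

```cpp
#include <vector>

class SegmentTree1D {
public:
    explicit SegmentTree1D(const std::vector<int>& arr)
        : n(static_cast<int>(arr.size())), tree(4 * arr.size() + 4, 0) {
        if (n > 0) build(1, 0, n - 1, arr);
    }

    void point_update(int i, int v) { update(1, 0, n - 1, i, v); }

    // Sum of arr[start..end], both inclusive; requires 0 <= start <= end < n.
    int range_query(int start, int end) const { return query(1, 0, n - 1, start, end); }

private:
    int n;
    std::vector<int> tree;  // tree[node] holds the sum of the node's interval

    void build(int node, int nodeLeft, int nodeRight, const std::vector<int>& arr) {
        if (nodeLeft == nodeRight) { tree[node] = arr[nodeLeft]; return; }
        int nodeMid = (nodeLeft + nodeRight) / 2;
        build(node * 2, nodeLeft, nodeMid, arr);
        build(node * 2 + 1, nodeMid + 1, nodeRight, arr);
        tree[node] = tree[node * 2] + tree[node * 2 + 1];
    }

    void update(int node, int nodeLeft, int nodeRight, int i, int v) {
        if (nodeLeft == nodeRight) { tree[node] = v; return; }
        int nodeMid = (nodeLeft + nodeRight) / 2;
        if (i <= nodeMid) update(node * 2, nodeLeft, nodeMid, i, v);
        else update(node * 2 + 1, nodeMid + 1, nodeRight, i, v);
        tree[node] = tree[node * 2] + tree[node * 2 + 1];  // recompute on the way back
    }

    int query(int node, int nodeLeft, int nodeRight, int start, int end) const {
        if (nodeLeft == start && nodeRight == end) return tree[node];
        int nodeMid = (nodeLeft + nodeRight) / 2;
        if (end <= nodeMid) return query(node * 2, nodeLeft, nodeMid, start, end);
        if (nodeMid < start) return query(node * 2 + 1, nodeMid + 1, nodeRight, start, end);
        // The query straddles the midpoint: split it so each half matches a child.
        return query(node * 2, nodeLeft, nodeMid, start, nodeMid) +
               query(node * 2 + 1, nodeMid + 1, nodeRight, nodeMid + 1, end);
    }
};
```

Splitting the query range at nodeMid before recursing is what keeps the exact-match test valid at every level.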

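2D query, range_query(top, bottom, left, right): the condition nodeTop == top && nodeBottom == bottom && nodeLeft == left && nodeRight == right means that the queried rectangle coincides exactly with the rectangle represented by the node.

The relation of top and bottom to nodeTopBottomMid decides whether the upper and lower halves need to be queried, and the relation of left and right to nodeLeftRightMid decides whether the left and right halves need to be queried.

The two checks are independent, so the code has many if/else branches and is easy to get wrong.

bottom <= nodeTopBottomMid: the queried rectangle [top, bottom, left, right] lies entirely in the two subtrees of the upper half
nodeTopBottomMid < top: the queried rectangle [top, bottom, left, right] lies entirely in the two subtrees of the lower half
otherwise: the subtrees of the upper and lower halves each contain part of the queried rectangle [top, bottom, left, right]

right <= nodeLeftRightMid: the queried rectangle [top, bottom, left, right] lies entirely in the two subtrees of the left half
nodeLeftRightMid < left: the queried rectangle [top, bottom, left, right] lies entirely in the two subtrees of the right half
otherwise: the subtrees of the left and right halves each contain part of the queried rectangle [top, bottom, left, right]

The two dimensions are checked independently, so together they give 3 x 3 = 9 combinations of subtrees to descend into.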
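
The nine combinations can also be viewed as two independent 1D splits whose results are combined. The sketch below uses hypothetical names (Rect, split_interval, split_query) and only computes the sub-rectangles that a query is cut into at one node. Each piece lies inside exactly one child, so a full query would recurse on each piece.

```cpp
#include <utility>
#include <vector>

// A closed rectangle of matrix indices: rows [top, bottom], columns [left, right].
struct Rect {
    int top, bottom, left, right;
};

// Cuts the closed interval [lo, hi] at mid. If the interval lies entirely on
// one side of the cut it is returned unchanged, otherwise it is split into
// [lo, mid] and [mid + 1, hi].
std::vector<std::pair<int, int>> split_interval(int lo, int hi, int mid) {
    std::vector<std::pair<int, int>> parts;
    if (hi <= mid || mid < lo) {
        parts.emplace_back(lo, hi);
    } else {
        parts.emplace_back(lo, mid);
        parts.emplace_back(mid + 1, hi);
    }
    return parts;
}

// Cuts the query rectangle q at the midpoints of the node rectangle.
// The result has 1, 2 or 4 pieces, covering all nine branch combinations.
std::vector<Rect> split_query(const Rect& node, const Rect& q) {
    int rowMid = (node.top + node.bottom) / 2;
    int colMid = (node.left + node.right) / 2;
    std::vector<Rect> pieces;
    for (const auto& rows : split_interval(q.top, q.bottom, rowMid))
        for (const auto& cols : split_interval(q.left, q.right, colMid))
            pieces.push_back({rows.first, rows.second, cols.first, cols.second});
    return pieces;
}
```

Written this way, the branch structure of the hand-expanded if/else chain is easier to audit: every branch corresponds to one row piece paired with one column piece.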

Code (C++, template, linked nodes)

#include <iostream>
#include <vector>
using namespace std;

// A quadtree node covering rows [top, bottom] and columns [left, right].
struct QTNode
{
    int top, bottom, left, right;
    int sum; // sum of all cells inside the node's rectangle
    QTNode *topleft, *topright, *bottomleft, *bottomright;
    QTNode(int t, int b, int l, int r, int sum,
           QTNode* tl=nullptr, QTNode* tr=nullptr, QTNode* bl=nullptr, QTNode* br=nullptr)
        :top(t),bottom(b),left(l),right(r),sum(sum),topleft(tl),topright(tr),bottomleft(bl),bottomright(br){}
    ~QTNode(){}
};

class QT_2D_SegmentTree
{
public:
    QT_2D_SegmentTree()
    {
        root = nullptr;
    }

    ~QT_2D_SegmentTree()
    {
        if(root)
        {
            delete_sub_tree(root);
        }
    }

    void build(int top, int bottom, int left, int right, const vector<vector<int>>& matrix)
    {
        if(top <= bottom && left <= right)
            root = _build(top, bottom, left, right, matrix);
    }

    void point_update(int i, int j, int val)
    {
        _point_update(root, i, j, val);
    }

    int range_query(int row1, int col1, int row2, int col2)
    {
        return _range_query(root, row1, row2, col1, col2);
    }

    // Debug helper: prints every node's rectangle and sum in pre-order.
    void traverse()
    {
        cout << "================" << endl;
        _traverse(root);
        cout << "================" << endl;
    }

private:
    QTNode *root;

    // Frees the subtree rooted at node (post-order).
    void delete_sub_tree(QTNode* node)
    {
        for(QTNode* son: vector<QTNode*>{node -> topleft, node -> topright, node -> bottomleft, node -> bottomright})
        {
            if(son)
            {
                delete_sub_tree(son);
            }
        }
        delete node;
    }

    void _traverse(QTNode* root)
    {
        cout << root -> top << " " << root -> bottom << " " << root -> left << " " << root -> right << endl;
        cout << "sum: " << root -> sum << endl;
        for(auto *son: vector<QTNode*>{root -> topleft, root -> topright, root -> bottomleft, root -> bottomright})
        {
            if(!son) continue;
            _traverse(son);
        }
    }

    int _range_query(QTNode* root, int top, int bottom, int left, int right)
    {
        int nodeTop = root -> top, nodeBottom = root -> bottom;
        int nodeLeft = root -> left, nodeRight = root -> right;
        // The queried rectangle coincides exactly with this node's rectangle.
        if(nodeTop == top && nodeBottom == bottom && nodeLeft == left && nodeRight == right)
        {
            return root -> sum;
        }

        // top and bottom versus nodeTopBottomMid decide whether the upper and lower halves are queried;
        // left and right versus nodeLeftRightMid decide whether the left and right halves are queried.
        // The two checks are independent, hence the many if/else branches.
        int nodeTopBottomMid = (nodeBottom + nodeTop) / 2;
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        if(bottom <= nodeTopBottomMid)
        {
            if(right <= nodeLeftRightMid)
                return _range_query(root -> topleft, top, bottom, left, right);
            else if(nodeLeftRightMid < left)
                return _range_query(root -> topright, top, bottom, left, right);
            else
            {
                int s1 = _range_query(root -> topleft, top, bottom, left, nodeLeftRightMid);
                int s2 = _range_query(root -> topright, top, bottom, nodeLeftRightMid + 1, right);
                return s1 + s2;
            }
        }
        else if(nodeTopBottomMid < top)
        {
            if(right <= nodeLeftRightMid)
                return _range_query(root -> bottomleft, top, bottom, left, right);
            else if(nodeLeftRightMid < left)
                return _range_query(root -> bottomright, top, bottom, left, right);
            else
            {
                int s1 = _range_query(root -> bottomleft, top, bottom, left, nodeLeftRightMid);
                int s2 = _range_query(root -> bottomright, top, bottom, nodeLeftRightMid + 1, right);
                return s1 + s2;
            }
        }
        else
        {
            if(right <= nodeLeftRightMid)
            {
                int s1 = _range_query(root -> topleft, top, nodeTopBottomMid, left, right);
                int s2 = _range_query(root -> bottomleft, nodeTopBottomMid + 1, bottom, left, right);
                return s1 + s2;
            }
            else if(nodeLeftRightMid < left)
            {
                int s1 = _range_query(root -> topright, top, nodeTopBottomMid, left, right);
                int s2 = _range_query(root -> bottomright, nodeTopBottomMid + 1, bottom, left, right);
                return s1 + s2;
            }
            else
            {
                int s1 = _range_query(root -> topleft, top, nodeTopBottomMid, left, nodeLeftRightMid);
                int s2 = _range_query(root -> bottomleft, nodeTopBottomMid + 1, bottom, left, nodeLeftRightMid);
                int s3 = _range_query(root -> topright, top, nodeTopBottomMid, nodeLeftRightMid + 1, right);
                int s4 = _range_query(root -> bottomright, nodeTopBottomMid + 1, bottom, nodeLeftRightMid + 1, right);
                return s1 + s2 + s3 + s4;
            }
        }
    }


    void _point_update(QTNode* root, int i, int j, int val)
    {
        // Use i versus [top, bottom] and j versus [left, right] to find the subtree containing (i, j).
        // After the leaf is updated, recompute the current node's sum while unwinding.
        int nodeTop = root -> top, nodeBottom = root -> bottom;
        int nodeLeft = root -> left, nodeRight = root -> right;
        if(nodeTop == nodeBottom && nodeTop == i && nodeLeft == nodeRight && nodeLeft == j)
        {
            root -> sum = val;
            return;
        }

        int nodeTopBottomMid = (nodeBottom + nodeTop) / 2;
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        if(i <= nodeTopBottomMid)
        {
            if(j <= nodeLeftRightMid)
                _point_update(root -> topleft, i, j, val);
            else
                _point_update(root -> topright, i, j, val);
        }
        else
        {
            if(j <= nodeLeftRightMid)
                _point_update(root -> bottomleft, i, j, val);
            else
                _point_update(root -> bottomright, i, j, val);
        }
        int sum = 0;
        if(root -> topleft) sum += root -> topleft -> sum;
        if(root -> topright) sum += root -> topright -> sum;
        if(root -> bottomleft) sum += root -> bottomleft -> sum;
        if(root -> bottomright) sum += root -> bottomright -> sum;
        root -> sum = sum;
    }

    // Child order: topleft(tl),topright(tr),bottomleft(bl),bottomright(br)
    QTNode* _build(int top, int bottom, int left, int right, const vector<vector<int>>& matrix)
    {
        if(top == bottom && left == right)
            return new QTNode(top, bottom, left, right, matrix[top][left]);

        int nodeTopBottomMid = (top + bottom) / 2;
        int nodeLeftRightMid = (left + right) / 2;
        // Cases where exactly one of top == bottom and left == right holds:
        // the node is split in the other direction only and has two children.
        if(left == right)
        {
            QTNode *topleft = _build(top, nodeTopBottomMid, left, right, matrix);
            QTNode *bottomleft = _build(nodeTopBottomMid + 1, bottom, left, right, matrix);
            int sum = topleft -> sum + bottomleft -> sum;
            return new QTNode(top, bottom, left, right, sum, topleft, nullptr, bottomleft, nullptr);
        }
        if(top == bottom)
        {
            QTNode *topleft = _build(top, bottom, left, nodeLeftRightMid, matrix);
            QTNode *topright = _build(top, bottom, nodeLeftRightMid + 1, right, matrix);
            int sum = topleft -> sum + topright -> sum;
            return new QTNode(top, bottom, left, right, sum, topleft, topright, nullptr, nullptr);
        }

        // Both top < bottom and left < right hold: four children.
        QTNode *topleft = _build(top, nodeTopBottomMid, left, nodeLeftRightMid, matrix);
        QTNode *topright = _build(top, nodeTopBottomMid, nodeLeftRightMid + 1, right, matrix);
        QTNode *bottomleft = _build(nodeTopBottomMid + 1, bottom, left, nodeLeftRightMid, matrix);
        QTNode *bottomright = _build(nodeTopBottomMid + 1, bottom, nodeLeftRightMid + 1, right, matrix);
        int sum = topleft -> sum + topright -> sum + bottomleft -> sum + bottomright -> sum;
        return new QTNode(top, bottom, left, right, sum, topleft, topright, bottomleft, bottomright);
    }

};

class NumMatrix {
public:
    NumMatrix(vector<vector<int>>& matrix) {
        int n = 0, m = 0;
        if(!matrix.empty())
            n = matrix.size(), m = matrix[0].size();
        sttree_2d = QT_2D_SegmentTree();
        sttree_2d.build(0, n - 1, 0, m - 1, matrix);
    }

    void update(int row, int col, int val) {
        sttree_2d.point_update(row, col, val);
    }

    int sumRegion(int row1, int col1, int row2, int col2) {
        return sttree_2d.range_query(row1, col1, row2, col2);
    }

private:
    QT_2D_SegmentTree sttree_2d;
};

Code (C++, template, array-based)

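#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

class Seq_QT_2D_SegmentTree
{
public:
    Seq_QT_2D_SegmentTree()
    {
        st_vec = vector<int>();
        n = -1;
        m = -1;
    }

    void build(const vector<vector<int>>& arr)
    {
        if(arr.empty()) return;
        n = arr.size(), m = arr[0].size();
        // Size the array by tree depth rather than by n * m: children are
        // addressed heap-style (see son()), so a thin matrix such as 1 x 200
        // still reaches indices beyond n * m * 32.
        int depth = 0;
        while((1 << depth) < max(n, m)) ++depth;
        // Levels 0..depth occupy indices 1 .. (4^(depth + 1) - 1) / 3.
        st_vec.assign(((size_t(1) << (2 * (depth + 1))) - 1) / 3 + 1, 0);
        _build(1, 0, n - 1, 0, m - 1, arr);
    }

    void point_update(int i, int j, int v)
    {
        _point_update(1, 0, n - 1, 0, m - 1, i, j, v);
    }

    int range_query(int row1, int col1, int row2, int col2)
    {
        return _range_query(1, 0, n - 1, 0, m - 1, row1, row2, col1, col2);
    }

private:
    vector<int> st_vec;
    int n, m; // number of rows and columns of the original data
    int topleft = 0, topright = 1, bottomleft = 2, bottomright = 3;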

    int _range_query(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, int top, int bottom, int left, int right)
    {
        if(nodeTop == top && nodeBottom == bottom && nodeLeft == left && nodeRight == right)
            return st_vec[node];

        int nodeTopBottomMid = (nodeTop + nodeBottom) / 2;
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        if(bottom <= nodeTopBottomMid)
        {
            if(right <= nodeLeftRightMid)
                return _range_query(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeLeftRightMid, top, bottom, left, right);
            else if(nodeLeftRightMid < left)
                return _range_query(son(node, topright), nodeTop, nodeTopBottomMid, nodeLeftRightMid + 1, nodeRight, top, bottom, left, right);
            else
            {
                int s1 = _range_query(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeLeftRightMid, top, bottom, left, nodeLeftRightMid);
                int s2 = _range_query(son(node, topright), nodeTop, nodeTopBottomMid, nodeLeftRightMid + 1, nodeRight, top, bottom, nodeLeftRightMid + 1, right);
                return s1 + s2;
            }
        }
        else if(nodeTopBottomMid < top)
        {
            if(right <= nodeLeftRightMid)
                return _range_query(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeLeftRightMid, top, bottom, left, right);
            else if(nodeLeftRightMid < left)
                return _range_query(son(node, bottomright), nodeTopBottomMid + 1, nodeBottom, nodeLeftRightMid + 1, nodeRight, top, bottom, left, right);
            else
            {
                int s1 = _range_query(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeLeftRightMid, top, bottom, left, nodeLeftRightMid);
                int s2 = _range_query(son(node, bottomright), nodeTopBottomMid + 1, nodeBottom, nodeLeftRightMid + 1, nodeRight, top, bottom, nodeLeftRightMid + 1, right);
                return s1 + s2;
            }
        }
        else
        {
            if(right <= nodeLeftRightMid)
            {
                int s1 = _range_query(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeLeftRightMid, top, nodeTopBottomMid, left, right);
                int s2 = _range_query(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeLeftRightMid, nodeTopBottomMid + 1, bottom, left, right);
                return s1 + s2;
            }
            else if(nodeLeftRightMid < left)
            {
                int s1 = _range_query(son(node, topright), nodeTop, nodeTopBottomMid, nodeLeftRightMid + 1, nodeRight, top, nodeTopBottomMid, left, right);
                int s2 = _range_query(son(node, bottomright), nodeTopBottomMid + 1, nodeBottom, nodeLeftRightMid + 1, nodeRight, nodeTopBottomMid + 1, bottom, left, right);
                return s1 + s2;
            }
            else
            {
                int s1 = _range_query(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeLeftRightMid, top, nodeTopBottomMid, left, nodeLeftRightMid);
                int s2 = _range_query(son(node, topright), nodeTop, nodeTopBottomMid, nodeLeftRightMid + 1, nodeRight, top, nodeTopBottomMid, nodeLeftRightMid + 1, right);
                int s3 = _range_query(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeLeftRightMid, nodeTopBottomMid + 1, bottom, left, nodeLeftRightMid);
                int s4 = _range_query(son(node, bottomright), nodeTopBottomMid + 1, nodeBottom, nodeLeftRightMid + 1, nodeRight, nodeTopBottomMid + 1, bottom, nodeLeftRightMid + 1, right);
                return s1 + s2 + s3 + s4;
            }
        }
    }

    void _point_update(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, int i, int j, int v)
    {
        if(nodeTop > nodeBottom || nodeLeft > nodeRight)
        {
            cout << "nodeTop > nodeBottom || nodeLeft > nodeRight" << endl;
            return;
        }
        if(nodeTop == nodeBottom && nodeTop == i && nodeLeft == nodeRight && nodeLeft == j)
        {
            st_vec[node] = v;
            return;
        }

        int nodeTopBottomMid = (nodeTop + nodeBottom) / 2;
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        if(i <= nodeTopBottomMid)
        {
            if(j <= nodeLeftRightMid)
                _point_update(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeLeftRightMid, i, j, v);
            else
                _point_update(son(node, topright), nodeTop, nodeTopBottomMid, nodeLeftRightMid + 1, nodeRight, i, j, v);
        }
        else
        {
            if(j <= nodeLeftRightMid)
                _point_update(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeLeftRightMid, i, j, v);
            else
                _point_update(son(node, bottomright), nodeTopBottomMid + 1, nodeBottom, nodeLeftRightMid + 1, nodeRight, i, j, v);
        }
        int sum = st_vec[son(node, topleft)];
        if(nodeTop < nodeBottom) sum += st_vec[son(node, bottomleft)];
        if(nodeLeft < nodeRight) sum += st_vec[son(node, topright)];
        if(nodeTop < nodeBottom && nodeLeft < nodeRight) sum += st_vec[son(node, bottomright)];
        st_vec[node] = sum;
    }

    void _build(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, const vector<vector<int> >& arr)
    {
        if(nodeLeft == nodeRight && nodeTop == nodeBottom)
        {
            st_vec[node] = arr[nodeTop][nodeLeft];
            return;
        }
        int nodeTopBottomMid = (nodeTop + nodeBottom) / 2;
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        if(nodeTop == nodeBottom)
        {
            _build(son(node, topleft), nodeTop, nodeBottom, nodeLeft, nodeLeftRightMid, arr);
            _build(son(node, topright), nodeTop, nodeBottom, nodeLeftRightMid + 1, nodeRight, arr);
            st_vec[node] = st_vec[son(node, topleft)] + st_vec[son(node, topright)];
            return;
        }
        if(nodeLeft == nodeRight)
        {
            _build(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeRight, arr);
            _build(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeRight, arr);
            st_vec[node] = st_vec[son(node, topleft)] + st_vec[son(node, bottomleft)];
            return;
        }
        _build(son(node, topleft), nodeTop, nodeTopBottomMid, nodeLeft, nodeLeftRightMid, arr);
        _build(son(node, topright), nodeTop, nodeTopBottomMid, nodeLeftRightMid + 1, nodeRight, arr);
        _build(son(node, bottomleft), nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeLeftRightMid, arr);
        _build(son(node, bottomright), nodeTopBottomMid + 1, nodeBottom, nodeLeftRightMid + 1, nodeRight, arr);
        st_vec[node] = st_vec[son(node, topleft)] + st_vec[son(node, topright)] + st_vec[son(node, bottomleft)] + st_vec[son(node, bottomright)];
    }

    int son(int node, int son_id)
    {
        // node : 1
        // 0(topleft) -> 2
        // 1(topright) -> 3
        // 2(bottomleft) -> 4
        // 3(bottomright) -> 5
        return node * 4 - 2 + son_id;
    }
};

class NumMatrix {
public:
    NumMatrix(vector<vector<int>>& matrix) {
        seqsttree_2d = Seq_QT_2D_SegmentTree();
        seqsttree_2d.build(matrix);
    }

    void update(int row, int col, int val) {
        seqsttree_2d.point_update(row, col, val);
    }

    int sumRegion(int row1, int col1, int row2, int col2) {
        return seqsttree_2d.range_query(row1, col1, row2, col2);
    }

private:
    Seq_QT_2D_SegmentTree seqsttree_2d;
};
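
How large the backing array of the array-based quadtree has to be depends on the depth of the tree rather than on the number of cells, because children are addressed as 4 * node - 2 .. 4 * node + 1. The helpers below are hypothetical names for a sketch of that calculation, assuming the root sits at index 1 and ranges are halved at (lo + hi) / 2. They show, for example, that a 1 x 200 matrix reaches index 21846, which is more than 1 * 200 * 32 = 6400.

```cpp
#include <algorithm>
#include <cstddef>

// Number of levels below the root: the larger dimension decides, since a
// dimension of size s needs ceil(log2(s)) halvings to reach a single index.
int quadtree_depth(int n, int m) {
    int depth = 0;
    while ((1 << depth) < std::max(n, m)) ++depth;
    return depth;
}

// With the root at index 1 and children at 4 * node - 2 .. 4 * node + 1,
// the nodes of levels 0..d occupy indices 1 .. (4^(d + 1) - 1) / 3.
// One extra slot accounts for the unused index 0.
std::size_t quadtree_array_size(int n, int m) {
    int depth = quadtree_depth(n, m);
    return ((std::size_t(1) << (2 * (depth + 1))) - 1) / 3 + 1;
}

// Index of the leaf reached by always descending into the top-left child,
// whose ranges have size ceil(s / 2) in each dimension.
std::size_t leftmost_leaf_index(int n, int m) {
    std::size_t node = 1;
    int rows = n, cols = m;
    while (rows > 1 || cols > 1) {
        node = 4 * node - 2;
        rows = (rows + 1) / 2;
        cols = (cols + 1) / 2;
    }
    return node;
}
```

The bound from quadtree_array_size is an upper limit on any index the recursion can touch, so allocating that many entries is sufficient for both square and very thin matrices.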

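2D Segment Tree: Tree-of-Trees Implementation

In a tree-shaped data structure, each node is no longer a plain node but another tree-shaped data structure.

First build a segment tree over the rows, then build a segment tree over the columns under every node of the row tree. Take a $2 \times 3$ matrix as an example: the outer (row) tree has three nodes, covering rows [0, 1], [0, 0] and [1, 1], and each of them owns an inner tree over the columns [0, 2].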
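
One way to picture the inner trees is through the array each of them is built on. For an outer node covering rows [top, bottom], entry j of that array is the sum of column j over those rows. The column_sums helper below is a hypothetical name for a sketch that computes this array directly. The actual structure never materializes it and instead derives it from the children's inner trees.

```cpp
#include <cstddef>
#include <vector>

// For the outer node covering rows [top, bottom], returns the values at the
// leaves of its inner (column) segment tree: entry j is the sum of column j
// over those rows. Assumes a non-empty rectangular matrix.
std::vector<int> column_sums(const std::vector<std::vector<int>>& matrix,
                             int top, int bottom) {
    std::vector<int> sums(matrix[0].size(), 0);
    for (int r = top; r <= bottom; ++r)
        for (std::size_t c = 0; c < sums.size(); ++c)
            sums[c] += matrix[r][c];
    return sums;
}
```

For the matrix {{1, 2, 3}, {4, 5, 6}}, the two outer leaves hold inner trees over {1, 2, 3} and {4, 5, 6}, and the outer root holds an inner tree over {5, 7, 9}.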

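Point update

Updating an inner tree works just as in an ordinary segment tree; only the update of the outer tree is slightly different.

(1) The current outer-tree node is a leaf

Update the inner tree that belongs to this outer node. Because this is a point update, the value is written directly once the leaf of the inner tree is reached.

(2) The current outer-tree node is not a leaf

Update the inner tree that belongs to this outer node and locate the inner leaf in the same way. The leaf cannot be written directly, though: its value has to be recomputed from the same leaf in the inner trees of the current outer node's children (because the update is a point update in the outer tree as well as in the inner tree).

The inner trees of all outer nodes have the same shape, so "recomputing from the same leaf in the inner trees of the current outer node's children" is easy to implement.

With linked nodes, "finding the same leaf in the inner trees of different outer nodes" is cumbersome, so this algorithm is usually not implemented with pointers.

Also note that the tree-of-trees 2D segment tree does not support range updates with ordinary lazy propagation: a lazy tag on an outer node would itself apply to a range, and such tags cannot be merged.

Range query

The procedure is the same as for an ordinary segment tree query: the row range is decomposed on the outer tree, and for each selected outer node the column range is queried in its inner tree.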
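
The two update cases and the two-level query can be condensed into a short sketch. TinyTreeOfTrees is a hypothetical class, not the article's template: it starts from an all-zero matrix so that no build step is needed, and its query uses containment tests (returning 0 for disjoint ranges) instead of the exact-match convention, purely to keep the code short.

```cpp
#include <vector>

class TinyTreeOfTrees {
public:
    TinyTreeOfTrees(int rows, int cols)
        : n(rows), m(cols), st(4 * rows, std::vector<int>(4 * cols, 0)) {}

    void point_update(int i, int j, int v) { update_outer(1, 0, n - 1, i, j, v); }

    // Sum over rows [top, bottom] and columns [left, right], both inclusive.
    int range_query(int top, int bottom, int left, int right) const {
        return query_outer(1, 0, n - 1, top, bottom, left, right);
    }

private:
    int n, m;
    std::vector<std::vector<int>> st;  // st[outer node][inner node]

    void update_inner(int outer, bool outerIsLeaf, int node, int l, int r, int j, int v) {
        if (l == r) {
            if (outerIsLeaf) st[outer][node] = v;  // case (1): write the value
            else st[outer][node] = st[2 * outer][node] + st[2 * outer + 1][node];  // case (2)
            return;
        }
        int mid = (l + r) / 2;
        if (j <= mid) update_inner(outer, outerIsLeaf, 2 * node, l, mid, j, v);
        else update_inner(outer, outerIsLeaf, 2 * node + 1, mid + 1, r, j, v);
        st[outer][node] = st[outer][2 * node] + st[outer][2 * node + 1];
    }

    void update_outer(int outer, int t, int b, int i, int j, int v) {
        if (t < b) {  // update the child first so case (2) reads fresh values
            int mid = (t + b) / 2;
            if (i <= mid) update_outer(2 * outer, t, mid, i, j, v);
            else update_outer(2 * outer + 1, mid + 1, b, i, j, v);
        }
        update_inner(outer, t == b, 1, 0, m - 1, j, v);
    }

    int query_inner(int outer, int node, int l, int r, int left, int right) const {
        if (right < l || r < left) return 0;
        if (left <= l && r <= right) return st[outer][node];
        int mid = (l + r) / 2;
        return query_inner(outer, 2 * node, l, mid, left, right) +
               query_inner(outer, 2 * node + 1, mid + 1, r, left, right);
    }

    int query_outer(int outer, int t, int b, int top, int bottom, int left, int right) const {
        if (bottom < t || b < top) return 0;
        if (top <= t && b <= bottom) return query_inner(outer, 1, 0, m - 1, left, right);
        int mid = (t + b) / 2;
        return query_outer(2 * outer, t, mid, top, bottom, left, right) +
               query_outer(2 * outer + 1, mid + 1, b, top, bottom, left, right);
    }
};
```

Because every inner tree has the same shape, the inner node index reached for column j is identical in the parent and in both children, which is what makes the case (2) lookup a plain array access.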

Code (C++, template, array-based)

#include <vector>
using namespace std;

class Seq_TT_2D_SegmentTree
{
public:
    Seq_TT_2D_SegmentTree()
    {
        st_vec = vector<vector<int>>();
        n = -1;
        m = -1;
    }

    void build(const vector<vector<int>>& arr)
    {
        if(arr.empty()) return;
        n = arr.size(), m = arr[0].size();
        // 4x per dimension: the usual bound for a heap-indexed binary segment tree
        st_vec.assign(n * 4, vector<int>(m * 4));
        _build(1, 0, n - 1, 0, m - 1, arr);
    }

    void point_update(int i, int j, int v)
    {
        _point_update(1, 0, n - 1, 0, m - 1, i, j, v);
    }

    int range_query(int row1, int col1, int row2, int col2)
    {
        return _range_query(1, 0, n - 1, 0, m - 1, row1, row2, col1, col2);
    }

private:
    vector<vector<int> > st_vec; // st_vec[outer node][inner node]
    int n, m; // number of rows and columns of the original data

    int _range_subquery(int node, int nodeLeft, int nodeRight, int left, int right, int f_node)
    {
        if(left == nodeLeft && nodeRight == right)
            return st_vec[f_node][node];
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        if(right <= nodeLeftRightMid)
            return _range_subquery(node * 2, nodeLeft, nodeLeftRightMid, left, right, f_node);
        else if(nodeLeftRightMid < left)
            return _range_subquery(node * 2 + 1, nodeLeftRightMid + 1, nodeRight, left, right, f_node);
        else
        {
            int s1 = _range_subquery(node * 2, nodeLeft, nodeLeftRightMid, left, nodeLeftRightMid, f_node);
            int s2 = _range_subquery(node * 2 + 1, nodeLeftRightMid + 1, nodeRight, nodeLeftRightMid + 1, right, f_node);
            return s1 + s2;
        }
    }

    int _range_query(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, int top, int bottom, int left, int right)
    {
        if(top <= nodeTop && nodeBottom <= bottom)
            return _range_subquery(1, nodeLeft, nodeRight, left, right, node);
        int nodeTopBottomMid = (nodeTop + nodeBottom) / 2;
        if(bottom <= nodeTopBottomMid)
            return _range_query(node * 2, nodeTop, nodeTopBottomMid, nodeLeft, nodeRight, top, bottom, left, right);
        else if(nodeTopBottomMid < top)
            return _range_query(node * 2 + 1, nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeRight, top, bottom, left, right);
        else
        {
            int s1 = _range_query(node * 2, nodeTop, nodeTopBottomMid, nodeLeft, nodeRight, top, nodeTopBottomMid, left, right);
            int s2 = _range_query(node * 2 + 1, nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeRight, nodeTopBottomMid + 1, bottom, left, right);
            return s1 + s2;
        }
    }

    void _point_subupdate(int node, int nodeLeft, int nodeRight, int j, int v, int f_node, int f)
    {
        // f_node: index of the outer node
        // f: relation between the outer node's nodeTop and nodeBottom
        // f = 0: nodeTop < nodeBottom (the outer node is not a leaf)
        // f = 1: nodeTop == nodeBottom (the outer node is a leaf)
        if(nodeLeft == nodeRight)
        {
            if(f) st_vec[f_node][node] = v;
            else st_vec[f_node][node] = st_vec[f_node * 2][node] + st_vec[f_node * 2 + 1][node];
        }
        else
        {
            int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
            if(j <= nodeLeftRightMid)
                _point_subupdate(node * 2, nodeLeft, nodeLeftRightMid, j, v, f_node, f);
            else
                _point_subupdate(node * 2 + 1, nodeLeftRightMid + 1, nodeRight, j, v, f_node, f);
            st_vec[f_node][node] = st_vec[f_node][node * 2] + st_vec[f_node][node * 2 + 1];
        }
    }

    void _point_update(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, int i, int j, int v)
    {
        if(nodeTop == nodeBottom)
        {
            _point_subupdate(1, nodeLeft, nodeRight, j, v, node, 1);
            return;
        }
        int nodeTopBottomMid = (nodeTop + nodeBottom) / 2;
        if(i <= nodeTopBottomMid)
            _point_update(node * 2, nodeTop, nodeTopBottomMid, nodeLeft, nodeRight, i, j, v);
        else
            _point_update(node * 2 + 1, nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeRight, i, j, v);
        _point_subupdate(1, nodeLeft, nodeRight, j, v, node, 0);
    }

    void _subbuild(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, const vector<vector<int>>& arr, int f_node)
    {
        if(nodeLeft == nodeRight)
        {
            if(nodeTop == nodeBottom) st_vec[f_node][node] = arr[nodeTop][nodeLeft];
            else st_vec[f_node][node] = st_vec[f_node * 2][node] + st_vec[f_node * 2 + 1][node];
            return;
        }
        int nodeLeftRightMid = (nodeLeft + nodeRight) / 2;
        _subbuild(node * 2, nodeTop, nodeBottom, nodeLeft, nodeLeftRightMid, arr, f_node);
        _subbuild(node * 2 + 1, nodeTop, nodeBottom, nodeLeftRightMid + 1, nodeRight, arr, f_node);
        st_vec[f_node][node] = st_vec[f_node][node * 2] + st_vec[f_node][node * 2 + 1];
    }

    void _build(int node, int nodeTop, int nodeBottom, int nodeLeft, int nodeRight, const vector<vector<int> >& arr)
    {
        if(nodeTop == nodeBottom)
        {
            _subbuild(1, nodeTop, nodeBottom, nodeLeft, nodeRight, arr, node);
            return;
        }
        int nodeTopBottomMid = (nodeTop + nodeBottom) / 2;
        _build(node * 2, nodeTop, nodeTopBottomMid, nodeLeft, nodeRight, arr);
        _build(node * 2 + 1, nodeTopBottomMid + 1, nodeBottom, nodeLeft, nodeRight, arr);
        _subbuild(1, nodeTop, nodeBottom, nodeLeft, nodeRight, arr, node);
    }
};

class NumMatrix {
public:
    NumMatrix(vector<vector<int>>& matrix) {
        seqsttree_2d = Seq_TT_2D_SegmentTree();
        seqsttree_2d.build(matrix);
    }

    void update(int row, int col, int val) {
        seqsttree_2d.point_update(row, col, val);
    }

    int sumRegion(int row1, int col1, int row2, int col2) {
        return seqsttree_2d.range_query(row1, col1, row2, col2);
    }

private:
    Seq_TT_2D_SegmentTree seqsttree_2d;
};
