力扣480-滑动窗口中位数

  |  

摘要: 滑动窗口中位数:一题多解。对堆的各种改进。

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们看一个对于堆这种数据结构的非常经典的问题,480. 滑动窗口中位数,涉及到对堆的各种改进。


$1 题目

中位数是有序序列最中间的那个数。如果序列的大小是偶数,则没有最中间的数;此时中位数是最中间的两个数的平均数。

例如:

[2,3,4],中位数是 3
[2,3],中位数是 (2 + 3) / 2 = 2.5
给你一个数组 nums,有一个大小为 k 的窗口从最左端滑动到最右端。窗口中有 k 个数,每次窗口向右移动 1 位。你的任务是找出每次窗口移动后得到的新窗口中元素的中位数,并输出由它们组成的数组。

提示:

1
2
你可以假设 k 始终有效,即:k 始终小于输入的非空数组的元素个数。
与真实值误差在 10 ^ -5 以内的答案将被视作正确答案。

示例:

给出 nums = [1,3,-1,-3,5,3,6,7],以及 k = 3。

窗口位置 中位数


[1 3 -1] -3 5 3 6 7 1
1 [3 -1 -3] 5 3 6 7 -1
1 3 [-1 -3 5] 3 6 7 -1
1 3 -1 [-3 5 3] 6 7 3
1 3 -1 -3 [5 3 6] 7 5
1 3 -1 -3 5 [3 6 7] 6
因此,返回该滑动窗口的中位数数组 [1,-1,-1,3,5,6]。

$2 题解

算法1: 对顶堆 + 标记删除

对顶堆的原理与实现可以参考文章 对顶堆

(1) 对顶堆

维护一个大顶堆,一个小顶堆,大顶堆,将大顶堆放左边,小顶堆放右边,两个队头对着。

从左向右看两个堆中的数据,恰好从大顶堆底到大顶堆顶到小顶堆顶到小顶堆底是递增的。

维护的时候始终保持两堆的数据量平衡,即左堆数据量始终等于右堆或者比右堆多一个。

平衡

1
2
3
4
5
左堆大小 > 右堆大小加1:
将左堆堆顶弹出,压进右堆

左堆大小 < 右堆大小
将右堆堆顶弹出,压进左堆

插入

1
2
3
如果左堆空或者数据小于等于左堆的堆顶,则插入左堆
如果数据比左堆的堆顶大,则插入右堆
平衡对顶堆

查询

1
2
3
4
左堆大小 = 右堆大小 + 1:
左堆堆顶为中位数
左堆大小 = 右堆大小:
两个堆顶的平均数为中位数

(2) 删除

  • 不用索引堆处理堆的按 key 删除问题:标记删除+惰性更新, 堆的标记删除

本题由于在推进滑窗的时候,需要将 numa[i - k] 从堆中删除,将 nums[i] 插入。

其中删除这一步比较难受。因为普通的堆是没有根据 key 将堆中元素删除的,如果要实现这个功能,一个办法是索引堆,但是码量就大了。

注意到一个事实:如果不删除 nums[i - k],堆中就会有多余的元素,但是其实堆中有多余元素问题不大,只要两堆仍然是平衡的,就不影响查询中位数的正确性。

因此可以用一种类似于惰性更新(Ref 惰性更新)的办法,当窗口右端点推进到 i 时,先将 nums[i - 1] 标记为删除。

当访问到该节点时,也就是访问堆顶时该节点刚好在堆顶,则将该节点真正从堆中删除(从堆顶弹出)。

标记删除用支持重复元素的哈希集合 label 即可。

迭代时,额外维护两个变量 left_cnt, right_cnt 分别表示左堆和右堆的实际元素个数。

1
2
3
4
5
6
7
left_cnt 和 right_cnt 的更新时机
插入元素 x:
x 插入到左堆,则 ++left_cnt
x 插入到右堆,则 ++right_cnt
标记删除 x:
x 在左堆 (x <= pq_left.top()): --left_cnt;
x 在右堆 (x > pq_left.top()): --right_cnt;

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class LabelOppositeHeap {
public:
LabelOppositeHeap()
{
pq_right = priority_queue<int, vector<int>, greater<int>>();
pq_left = priority_queue<int>();
left_cnt = 0, right_cnt = 0;
label = unordered_map<int, int>();
}

double query()
{
clear_left_top();
clear_right_top();
double left = pq_left.top();
if(left_cnt == right_cnt + 1)
return left;
double right = pq_right.top();
return ((double)left + (double)right) / 2;
}

void insert(int x)
{
clear_left_top();
clear_right_top();
if(pq_left.empty() || x <= pq_left.top())
{
pq_left.push(x);
++left_cnt;
}
else
{
pq_right.push(x);
++right_cnt;
}
balance();
}

void label_delete(int x)
{
clear_left_top();
clear_right_top();
if(x <= pq_left.top())
--left_cnt;
else
--right_cnt;
++label[x];
balance();
}

private:
priority_queue<int, vector<int>, greater<int>> pq_right;
priority_queue<int> pq_left;
int left_cnt, right_cnt;
unordered_map<int, int> label; // 哈希索引,用作删除标记

void clear_right_top()
{
while(!pq_right.empty() && label.count(pq_right.top()))
{
--label[pq_right.top()];
if(label[pq_right.top()] == 0)
label.erase(pq_right.top());
pq_right.pop();
}
}

void clear_left_top()
{
while(!pq_left.empty() && label.count(pq_left.top()))
{
--label[pq_left.top()];
if(label[pq_left.top()] == 0)
label.erase(pq_left.top());
pq_left.pop();
}
}

void balance()
{
if(left_cnt > right_cnt + 1)
{
pq_right.push(pq_left.top());
pq_left.pop();
++right_cnt;
--left_cnt;
}
else if(left_cnt < right_cnt)
{
pq_left.push(pq_right.top());
pq_right.pop();
++left_cnt;
--right_cnt;
}
}
};

class Solution {
public:
vector<double> medianSlidingWindow(vector<int>& nums, int k) {
LabelOppositeHeap label_opposite_heap;
for(int i = 0; i < k; ++i)
label_opposite_heap.insert(nums[i]);
int n = nums.size();
vector<double> result(n - k + 1);
for(int i = k; i < n; ++i)
{
result[i - k] = label_opposite_heap.query();
label_opposite_heap.label_delete(nums[i - k]);
label_opposite_heap.insert(nums[i]);
}
result[n - k] = label_opposite_heap.query();
return result;
}
};

算法2: 平衡树 (multiset)

用平衡树维护窗口内的元素。用前 k 个元素初始化平衡树,找到这 k 个元素的第 (k+1)/2 对应的节点,并用指针指向它,记为 it_mid,当 k 为奇数时,它就是中位数,当 k 为偶数时,它与它的后继结点的平均数为中位数。

在推进窗口时,始终维护该该指针为中位数,或者是两个中位数的左边哪一个。

1
2
3
4
每次更新,需要插入一个删除一个
插入值小于 it_mid ,会插入到 it_mid 左侧 --it_mid
删除值小于等于 it_mid ,会删除 it_mid 左侧: ++it_mid
(可能删除到 it_mid 本身,需要先自增迭代器再删除)

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution {
public:
vector<double> medianSlidingWindow(vector<int>& nums, int k)
{
multiset<int> setting(nums.begin(), nums.begin() + k);
auto it_mid = next(setting.begin(), (k - 1) / 2);

int n = nums.size();
vector<double> result(n - k + 1);
result[0] = (((double)(*it_mid) + *next(it_mid, 1 - k % 2)) * 0.5);
for (int i = k; i < n; ++i)
{
setting.insert(nums[i]);
if (nums[i] < *it_mid)
--it_mid;

if (nums[i - k] <= *it_mid)
++it_mid;
setting.erase(setting.lower_bound(nums[i - k]));

result[i - k + 1] = (((double)(*it_mid) + *next(it_mid, 1 - k % 2)) * 0.5);
}
return result;
}
};

算法3: 哈希索引对顶堆

将对顶堆中的堆改为哈希索引堆。无脑加哈希索引性能不一定好,对本题来说这个算法的耗时会有点长,有可能过不了。但是这是一种应对按 key 删除的比较直接的方式。

哈希索引堆的原理参考:哈希索引堆

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class IndexMaxHeap
{
public:
IndexMaxHeap()
{
data.assign(1, -1);
_size = 0;
}

void build(const vector<int>& nums)
{
data.clear();
mapping.clear();
int n = nums.size();
data.assign(n + 1, 0);
data[0] = -1;
for(int i = 0; i < n; ++i)
{
data[i + 1] = nums[i];
mapping[nums[i]].insert(i + 1);
}
_size = n;
for(int i = n; i >= 1; --i)
push_down(i);
}

int top()
{
if(empty())
return -1;
return data[1];
}

int pop()
{
if(empty())
return -1;
int ans = data[1];
_remove(1);
return ans;
}

void push(int key)
{
if(_size + 1 >= (int)data.size())
dilatation();
++_size;
data[_size] = key;
mapping[key].insert(_size);
push_up(_size);
}

void remove(int key)
{
if(mapping.count(key) == 0)
return;
int i = *mapping[key].begin();
_remove(i);
}

void change(int key, int new_key)
{
if(mapping.count(key) == 0)
return;
if(key == new_key)
return;
int i = *mapping[key].begin();
mapping[key].erase(i);
if(mapping[key].empty())
mapping.erase(key);
data[i] = new_key;
mapping[new_key].insert(i);
push_up(i);
push_down(i);
}

int size()
{
return _size;
}

bool empty()
{
return _size == 0;
}

private:
vector<int> data; // keys
int _size;
unordered_map<int, unordered_set<int>> mapping; // key -> idxs

void dilatation()
{
vector<int> tmp_data((_size + 1) * 2 + 1);
tmp_data[0] = 1;
for(int i = 1; i <= _size; ++i)
tmp_data[i] = data[i];
data.swap(tmp_data);
}

void _remove(int i)
{
if(i > _size)
return;
mapping[data[_size]].erase(_size);
if(data[_size] != data[i])
{
mapping[data[_size]].insert(i);
mapping[data[i]].erase(i);
if(mapping[data[i]].empty())
mapping.erase(data[i]);
}
data[i] = data[_size--];
push_up(i);
push_down(i);
}

void push_up(int i)
{
if(i > _size) return;
while(i / 2 > 0 && data[i / 2] < data[i])
{
mapping[data[i]].erase(i);
mapping[data[i]].insert(i / 2);
mapping[data[i / 2]].erase(i / 2);
mapping[data[i / 2]].insert(i);
swap(data[i], data[i / 2]);
i /= 2;
}
}

void push_down(int i)
{
int ori = i, left = i * 2, right = i * 2 + 1;
if(left <= _size && data[left] > data[ori])
ori = left;
if(right <= _size && data[right] > data[ori])
ori = right;
if(ori != i)
{
mapping[data[i]].erase(i);
mapping[data[i]].insert(ori);
mapping[data[ori]].erase(ori);
mapping[data[ori]].insert(i);
swap(data[i], data[ori]);
push_down(ori);
}
}
};


class IndexMinHeap
{
public:
IndexMinHeap()
{
data.assign(1, -1);
_size = 0;
}

void build(const vector<int>& nums)
{
data.clear();
mapping.clear();
int n = nums.size();
data.assign(n + 1, 0);
data[0] = -1;
for(int i = 0; i < n; ++i)
{
data[i + 1] = nums[i];
mapping[nums[i]].insert(i + 1);
}
_size = n;
for(int i = n; i >= 1; --i)
push_down(i);
}

int top()
{
if(empty())
return -1;
return data[1];
}

int pop()
{
if(empty())
return -1;
int ans = data[1];
_remove(1);
return ans;
}

void push(int key)
{
if(_size + 1 >= (int)data.size())
dilatation();
++_size;
data[_size] = key;
mapping[key].insert(_size);
push_up(_size);
}

void remove(int key)
{
if(mapping.count(key) == 0)
return;
int i = *mapping[key].begin();
_remove(i);
}

void change(int key, int new_key)
{
if(mapping.count(key) == 0)
return;
if(key == new_key)
return;
int i = *mapping[key].begin();
mapping[key].erase(i);
if(mapping[key].empty())
mapping.erase(key);
data[i] = new_key;
mapping[new_key].insert(i);
push_up(i);
push_down(i);
}

int size()
{
return _size;
}

bool empty()
{
return _size == 0;
}

private:
vector<int> data; // keys
int _size;
unordered_map<int, unordered_set<int>> mapping; // key -> idxs

void dilatation()
{
vector<int> tmp_data((_size + 1) * 2 + 1);
tmp_data[0] = 1;
for(int i = 1; i <= _size; ++i)
tmp_data[i] = data[i];
data.swap(tmp_data);
}

void _remove(int i)
{
if(i > _size)
return;
mapping[data[_size]].erase(_size);
if(data[_size] != data[i])
{
mapping[data[_size]].insert(i);
mapping[data[i]].erase(i);
if(mapping[data[i]].empty())
mapping.erase(data[i]);
}
data[i] = data[_size--];
push_up(i);
push_down(i);
}

void push_up(int i)
{
if(i > _size) return;
while(i / 2 > 0 && data[i / 2] > data[i])
{
mapping[data[i]].erase(i);
mapping[data[i]].insert(i / 2);
mapping[data[i / 2]].erase(i / 2);
mapping[data[i / 2]].insert(i);
swap(data[i], data[i / 2]);
i /= 2;
}
}

void push_down(int i)
{
int ori = i, left = i * 2, right = i * 2 + 1;
if(left <= _size && data[left] < data[ori])
ori = left;
if(right <= _size && data[right] < data[ori])
ori = right;
if(ori != i)
{
mapping[data[i]].erase(i);
mapping[data[i]].insert(ori);
mapping[data[ori]].erase(ori);
mapping[data[ori]].insert(i);
swap(data[i], data[ori]);
push_down(ori);
}
}
};

class IndexOppositeHeap
{
public:
IndexOppositeHeap(){}

void insert(int x)
{
if(heap_left.empty() || x <= heap_left.top())
heap_left.push(x);
else
heap_right.push(x);
balance();
}

double query()
{
double left = heap_left.top();
if(heap_left.size() == heap_right.size() + 1)
return left;
double right = heap_right.top();
return ((double)left + (double)right) / 2;
}

void index_delete(int x)
{
if(!heap_left.empty() && x <= heap_left.top())
heap_left.remove(x);
else
heap_right.remove(x);
balance();
}

private:
IndexMaxHeap heap_left;
IndexMinHeap heap_right;

void balance()
{
if(heap_left.size() > heap_right.size() + 1)
{
heap_right.push(heap_left.top());
heap_left.pop();
}
else if(heap_left.size() < heap_right.size())
{
heap_left.push(heap_right.top());
heap_right.pop();
}
}
};

class Solution {
public:
vector<double> medianSlidingWindow(vector<int>& nums, int k) {
IndexOppositeHeap index_opposite_heap;
for(int i = 0; i < k; ++i)
{
index_opposite_heap.insert(nums[i]);
}
int n = nums.size();
vector<double> result(n - k + 1);
for(int i = k; i < n; ++i)
{
result[i - k] = index_opposite_heap.query();
index_opposite_heap.index_delete(nums[i - k]);
index_opposite_heap.insert(nums[i]);
}
result[n - k] = index_opposite_heap.query();
return result;
}
};

Share