heapdict:扩展标准库数据结构,支持堆中指定元素删改,以Dijkstra的优化为例

  |  

摘要: 扩展标准库数据结构,支持堆中指定元素删改。dijkstra 算法的 Python 模板

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


问题背景与 heapdict 简介

heapdict 是一个 Python 库,提供了一个很像 Python dict 的对象,与普通 dict 相比,heapdict 的主要区别是多了以下两个功能:

  • popitem(): 移除并返回优先级最低的 (key, priority) 对,
  • peekitem(): 返回优先级最低的 (key, priority) 对。

因此 heapdict 可以用于做优先级队列。与 heapq 相比,heapdict 还支持改变现有对象的优先级,也就是修改指定 key 的优先级,这种更改优先级是很常见的需求,比如 Dijkstra 和 AStar 等算法中都会遇到。

在文章 迪杰斯特拉算法(Dijkstra) 中我们阐述过这种修改优先级的背景,也就是节点在优先级变好时可能会重复压入。可以通过延迟删除的策略代替直接修改指定节点的优先级,但是队列中的元素个数就会变多,时间复杂度会从 $O(\log V)$ 变成 $O(\log E)$。

改变堆中既有元素的优先级是很麻烦的事情,为了能够在 $O(\log V)$ 时间复杂度完成对堆中指定节点的删除或修改,在以下文章中,我们做过一些尝试:

用下标索引堆优化邻接表的 Prim 算法 中我们基于支持指定节点删改的下标索引堆,将 Prim 中重复压入的节点改为修改堆中既有节点的优先级,这样时间复杂度从 $O(E\log E)$ 到 $O(E\log V)$。

而 heapdcit 直接就可以支持修改指定 key 的优先级,下面我们以 Dijkstra 算法的优化为例,看一下 heapdict 的具体用法。

heapdict 的方法和源码都很少,非常适合作为通过继承的方式扩展 Python 标准库中的数据结构的例子。下面完整地贴出来,看一下想要在原有标准库中的数据结构中增加方法需要怎么做。

heapdict 的 help 信息

从下面的 help 信息中,可以看到,heapdict 的方法主要分为三类:

(1)自定义的方法:

  • clear(self)
  • peekitem(self)
  • popitem(self)

(2)从 collections.abc.MutableMapping 继承的方法:

  • pop(self, key, default=<object object at 0x7fae50070170>)
  • setdefault(self, key, default=None)
  • update(self, other=(), /, **kwds)

(3)从 collections.abc.Mapping 继承的方法:

  • __contains__(self, key)
  • __eq__(self, other)
  • get(self, key, default=None)
  • items(self)
  • keys(self)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
Help on module heapdict:

NAME
heapdict

CLASSES
collections.abc.MutableMapping(collections.abc.Mapping)
heapdict

class heapdict(collections.abc.MutableMapping)
| heapdict(*args, **kw)
|
| Method resolution order:
| heapdict
| collections.abc.MutableMapping
| collections.abc.Mapping
| collections.abc.Collection
| collections.abc.Sized
| collections.abc.Iterable
| collections.abc.Container
| builtins.object
|
| Methods defined here:
|
| __delitem__(self, key)
| Delete self[key].
|
| __getitem__(self, key)
| x.__getitem__(y) <==> x[y]
|
| __init__(self, *args, **kw)
| Initialize self. See help(type(self)) for accurate signature.
|
| __iter__(self)
| Implement iter(self).
|
| __len__(self)
| Return len(self).
|
| __setitem__(self, key, value)
| Set self[key] to value.
|
| clear(self)
| D.clear() -> None. Remove all items from D.
|
| peekitem(self)
| D.peekitem() -> (k, v), return the (key, value) pair with lowest value;
| but raise KeyError if D is empty.
|
| popitem(self)
| D.popitem() -> (k, v), remove and return the (key, value) pair with lowest
| value; but raise KeyError if D is empty.
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __abstractmethods__ = frozenset()
|
| ----------------------------------------------------------------------
| Methods inherited from collections.abc.MutableMapping:
|
| pop(self, key, default=<object object at 0x7fae50070170>)
| D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
| If key is not found, d is returned if given, otherwise KeyError is raised.
|
| setdefault(self, key, default=None)
| D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D
|
| update(self, other=(), /, **kwds)
| D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
| If E present and has a .keys() method, does: for k in E: D[k] = E[k]
| If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v
| In either case, this is followed by: for k, v in F.items(): D[k] = v
|
| ----------------------------------------------------------------------
| Methods inherited from collections.abc.Mapping:
|
| __contains__(self, key)
|
| __eq__(self, other)
| Return self==value.
|
| get(self, key, default=None)
| D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.
|
| items(self)
| D.items() -> a set-like object providing a view on D's items
|
| keys(self)
| D.keys() -> a set-like object providing a view on D's keys

heapdict 源码

从源码中可以看出,heapdict 内部就是继承了 MutableMapping,内部一个 heap 结构加上一个 dict 结构。把一些继承来的关于字典的方法进行定义,然后将新增的方法定义出来即可:

1
2
3
4
5
6
7
8
9
10
11
clear()
popitem()
peekitem()
__setitem__(key, value)
__delitem__(key)
__getitem__(key)
__iter__()
__len__()
_min_heapify(i)
_decrease_key(i)
_swap(i, j)

内部的数据结构为 heap + dict,其中 dict 中的数据为 key -> [value, key, idx]: value 为优先级,idx 为该元素在 heap 数组中的位置。heap 中的数据也为 [value, key, idx],代码中记其为 wrapper

这与之前我们探讨过的哈希索引堆比较像,可以与文章 哈希索引堆:维护 key 到堆中节点的哈希映射,支持对给定 key 的删改操作 对比看,原理大致相同。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping


def doc(s):
if hasattr(s, '__call__'):
s = s.__doc__

def f(g):
g.__doc__ = s
return g
return f


class heapdict(MutableMapping):
__marker = object()

def __init__(self, *args, **kw):
self.heap = []
self.d = {}
self.update(*args, **kw)

@doc(dict.clear)
def clear(self):
del self.heap[:]
self.d.clear()

@doc(dict.__setitem__)
def __setitem__(self, key, value):
if key in self.d:
self.pop(key)
wrapper = [value, key, len(self)]
self.d[key] = wrapper
self.heap.append(wrapper)
self._decrease_key(len(self.heap)-1)

def _min_heapify(self, i):
n = len(self.heap)
h = self.heap
while True:
# calculate the offset of the left child
l = (i << 1) + 1
# calculate the offset of the right child
r = (i + 1) << 1
if l < n and h[l][0] < h[i][0]:
low = l
else:
low = i
if r < n and h[r][0] < h[low][0]:
low = r

if low == i:
break

self._swap(i, low)
i = low

def _decrease_key(self, i):
while i:
# calculate the offset of the parent
parent = (i - 1) >> 1
if self.heap[parent][0] < self.heap[i][0]:
break
self._swap(i, parent)
i = parent

def _swap(self, i, j):
h = self.heap
h[i], h[j] = h[j], h[i]
h[i][2] = i
h[j][2] = j

@doc(dict.__delitem__)
def __delitem__(self, key):
wrapper = self.d[key]
while wrapper[2]:
# calculate the offset of the parent
parentpos = (wrapper[2] - 1) >> 1
parent = self.heap[parentpos]
self._swap(wrapper[2], parent[2])
self.popitem()

@doc(dict.__getitem__)
def __getitem__(self, key):
return self.d[key][0]

@doc(dict.__iter__)
def __iter__(self):
return iter(self.d)

def popitem(self):
"""D.popitem() -> (k, v), remove and return the (key, value) pair with lowest\nvalue; but raise KeyError if D is empty."""
wrapper = self.heap[0]
if len(self.heap) == 1:
self.heap.pop()
else:
self.heap[0] = self.heap.pop()
self.heap[0][2] = 0
self._min_heapify(0)
del self.d[wrapper[1]]
return wrapper[1], wrapper[0]

@doc(dict.__len__)
def __len__(self):
return len(self.d)

def peekitem(self):
"""D.peekitem() -> (k, v), return the (key, value) pair with lowest value;\n but raise KeyError if D is empty."""
return (self.heap[0][1], self.heap[0][0])


del doc
__all__ = ['heapdict']

例子:Dijkstra 最短路算法的优化

题目

带权图最短路径算法与实现 中我们以模板题为例,给出了几个常见的最短路径算法,其中一种是用堆实现的 dijkstra 算法。题目、算法原理、C++ 代码都可以参考那篇文章,这里直接给出 Python 的代码。

Dijkstra 算法

堆中维护元组 (到源的最短距离,节点编号)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
INF = int(1e9)

def dijkstra_heap(g: List[List[int]], start: int, N: int) -> List[int]:
# N 个节点,节点编号从 1 开始
d = [INF] * (N + 1)
d[start] = 0

# 堆中维护 (到start最短距离,节点编号)
heap_data = []
heapq.heappush(heap_data, (0, start))

while heap_data:
min_d, u = heapq.heappop(heap_data)
if d[u] < min_d:
continue
for son in g[u]:
v, w = son
if d[v] <= d[u] + w:
continue
d[v] = d[u] + w
heapq.heappush(heap_data, (d[v], v))

for i in range(1, N + 1):
if d[i] == INF:
d[i] = -1

return d


class Solution:
def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
g = [[] for _ in range(n+1)]
for edge in times:
g[edge[0]].append((edge[1], edge[2]))

d = dijkstra_heap(g, k, n)
ans = 0
for i in range(1, n + 1):
if d[i] == -1:
return -1
ans = max(ans, d[i])

return ans

优化

以下代码中,heapdict 的插入、删除、修改都是通过字典的操作完成。只有 popitem 是关于堆的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def dijkstra_heap(g: List[List[int]], start: int, N: int) -> List[int]:
# N 个节点,节点编号从 1 开始
d = [INF] * (N + 1)
d[start] = 0

# 堆中维护 (到start最短距离,节点编号)
hd = heapdict()
hd[start] = 0

while len(hd) > 0:
u, min_d = hd.popitem()
for son in g[u]:
v, w = son
if d[v] <= d[u] + w:
continue
d[v] = d[u] + w
hd[v] = d[v]

for i in range(1, N + 1):
if d[i] == INF:
d[i] = -1

return d

Share