自定义堆或优先级队列的比较规则:基于比较函数或键函数

  |  

摘要: C++ / Python 的堆或优先级队列中,自定义比较函数

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


各位好,我们知道排序可以自定义比较规则,参考文章 自定义排序题目清单:自定义定义 Cmp 结构体并重载 () 运算符。类似地,堆或优先级队列也有自定义比较规则的问题,

本题算法本身很简单,就是堆中维护的是自定义对象,重点看一下在 C++ 和 Python 中的堆或优先级队列中如何通过比较函数或键函数自定义比较规则。

$0 Python/C++ 自定义比较函数/键函数

C++ 中可以自定义比较函数,常见方法就是定义 HeapCmp 结构体并重载其 () 运算符,然后再定义优先级队列时提供 HeapCmp 的实例即可,参考文章 C++中对堆或优先级队列priority_queue自定义比较函数:自定义 HeapCmp 结构体并重载 () 的方式


Python3 在排序时支持指定键函数,在文章 Python3自定义排序:直接定义键函数,或先定义比较函数再转换为键函数 中有详细阐述。有两种方法:

方法一是先定义 HeapCmp 类,在其 __call__ 方法中实现比较函数的逻辑,这样 HeapCmp 实例就相当于比较函数,然后再通过 cmp_to_key 转换为键函数即可。这样转换成的键函数输入原始对象,输出的可以理解为还是该对象本身,只是增加了自定义的 __lt__,在这个 __lt__ 中实际调用的就是由 HeapCmp 给出的比较函数。

方法二是定义 HeapKey 类,在其 __call__ 方法中实现键函数的逻辑,返回一个带有 __lt__ 的对象作为比较键。这样 HeapKey 的实例就相当于一个键函数,它输入原始对象,输出用于比较的键。


但是 Python 标准库的 heqpq 和 PriorityQueue 中并不直接支持指定键函数,于是就有两种妥协的思路:

如果定义了 HeapCmp 再通过 cmp_to_key 得到一个类包装器 mykey,用于给原对象包装一个 __lt__(逻辑来源于 HeapCmp)。这样可以直接压入 mykey(myobj) 在堆中维护,弹出 item 时,其 item.obj 就是原对象。

如果定义了 HeapKey 进而直接得到键函数 mykey,可以构造元组 (mykey(myobj), myobj),但这样的话堆中如果有两个元组的第一位元素相同,还是会启动对第二位的原始对象的比较,而这里原始对象可能并没有 __lt__,这就产生问题。所以在Python3自定义排序中可行的直接定义键函数的方式在 heapq 这里不适用,根本的解法还是给原对象定义 __lt__


综上,在 Python3 中如果想在 heapq 和 PriorityQueue 中自定义比较逻辑的方法如下,(参考文章 在heapq中指定键函数:实现自定义比较逻辑):

(1) 优先给原对象直接定义 __lt__,相当于 C++ 中的重载对象中的小于号,参考文章 自定义排序/最值/堆的比较规则:自定义对象的小于方法

(2) 如果原对象是数值型这种基础类型但比较有另外的逻辑,或者比较时需要未知的额外信息,不方便直接给原对象定义 __lt__。可以先定义 HeapCmp,在其中放入额外信息和比较逻辑,然后通过 cmp_to_key 转换为一个类包装器 mykey,相当于给原对象包装了一个 __lt__


$1 题目

在一个仓库里,有一排条形码,其中第 i 个条形码为 barcodes[i]。

请你重新排列这些条形码,使其中两个相邻的条形码 不能 相等。 你可以返回任何满足该要求的答案,此题保证存在答案。

提示:

1
2
1 <= barcodes.length <= 10000
1 <= barcodes[i] <= 10000

示例 1:
输入:[1,1,1,2,2,2]
输出:[2,1,2,1,2,1]

示例 2:
输入:[1,1,1,1,2,2,3,3]
输出:[1,3,1,3,2,1,2,1]

$2 题解

算法:贪心

对每个出现过的字符计数,维护在 Item 结构体中。

用堆或者用排序的方式,每次取出当前计数最大的 item,往 result 数组中填数。

填数时先从 0 开始填下标为偶数的位置,填到尾之后再从 1 开始填下标为奇数的位置。

代码 (C++,模板)

参考文章:C++中对堆或优先级队列priority_queue自定义比较函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
struct Item
{
int v, cnt;
Item(int v, int cnt):v(v),cnt(cnt){}
Item():v(-1),cnt(0){}
};

struct HeapCmp
{
bool operator()(const Item& i1, const Item& i2) const
{
return i1.cnt < i2.cnt;
}
};

class Solution {
public:
vector<int> rearrangeBarcodes(vector<int>& barcodes) {
int n = barcodes.size();
if(n <= 2)
return barcodes;
const int MAX_VAL = 10000;
vector<int> cnts(MAX_VAL + 1);
for(int i: barcodes)
++cnts[i];
priority_queue<Item, vector<Item>, HeapCmp> pq;
for(int i = 1; i <= MAX_VAL; ++i)
{
if(cnts[i] == 0)
continue;
pq.push(Item(i, cnts[i]));
}
vector<int> result(n, -1);
bool odd = true;
int iter = 0;
while(!pq.empty())
{
Item item = pq.top();
pq.pop();
int cnt = item.cnt;
while(iter < n && cnt > 0)
{
result[iter] = item.v;
iter += 2;
--cnt;
}
if(iter >= n && odd)
{
odd = false;
iter = 1;
while(iter < n && cnt > 0)
{
result[iter] = item.v;
iter += 2;
--cnt;
}
}
}
return result;
}
};

代码 (Python,模板)

参考文章:在heapq中实现自定义比较逻辑,这里使用 PriorityQueue。

先写比较函数 HeapCmp,再转换为键函数

根据前面的分析,先定义 HeapCmp 再通过 cmp_to_key 实际上就相当于对原对象包装了一个自定义 __lt__

堆中压入弹出的是包装了 __lt__ 后的原对象。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from functools import cmp_to_key
from collections import Counter
from queue import PriorityQueue

class Item:
def __init__(self, v: int=-1, cnt: int=0):
self.v = v
self.cnt = cnt

def __repr__(self):
return "v: {}, cnt: {}".format(self.v, self.cnt)


class HeapCmp:
def __call__(self, i1: Item, i2: Item) -> bool:
# 在堆中按 cnt 从大到小
if i1.cnt < i2.cnt:
return 1
elif i1.cnt > i2.cnt:
return -1
else:
return 0

class Solution:
def rearrangeBarcodes(self, barcodes: List[int]) -> List[int]:
n = len(barcodes)
if n <= 2:
return barcodes
MAX_VAL = 10000
cnts = Counter(barcodes)

mycmp = HeapCmp()
mykey = cmp_to_key(mycmp)
pq = PriorityQueue()
for i in range(1, MAX_VAL + 1):
if cnts[i] == 0:
continue
item = Item(i, cnts[i])
pq.put(mykey(item))

result = [-1] * n
odd = True
_iter = 0
while not pq.empty():
item = pq.get()
# print(item.obj)
cnt = item.obj.cnt
while _iter < n and cnt > 0:
result[_iter] = item.obj.v
_iter += 2
cnt -= 1
if _iter >= n and odd:
odd = False
_iter = 1
while _iter < n and cnt > 0:
result[_iter] = item.obj.v
_iter += 2
cnt -= 1
return result

直接自定义 __lt__

根据前面的分析,如果有比较简单的提取键的逻辑,也不建议直接写成键函数,而是在原对象的 __lt__ 中对键进行比较。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from collections import Counter
from queue import PriorityQueue

class Item:
def __init__(self, v: int=-1, cnt: int=0):
self.v = v
self.cnt = cnt

def __lt__(self, other: Self) -> bool:
# 在堆中按 cnt 从大到小
return self.cnt > other.cnt

def __repr__(self):
return "v: {}, cnt: {}".format(self.v, self.cnt)

class Solution:
def rearrangeBarcodes(self, barcodes: List[int]) -> List[int]:
n = len(barcodes)
if n <= 2:
return barcodes
MAX_VAL = 10000
cnts = Counter(barcodes)

pq = PriorityQueue()
for i in range(1, MAX_VAL + 1):
if cnts[i] == 0:
continue
item = Item(i, cnts[i])
pq.put(item)

result = [-1] * n
odd = True
_iter = 0
while not pq.empty():
item = pq.get()
# print(item)
cnt = item.cnt
while _iter < n and cnt > 0:
result[_iter] = item.v
_iter += 2
cnt -= 1
if _iter >= n and odd:
odd = False
_iter = 1
while _iter < n and cnt > 0:
result[_iter] = item.v
_iter += 2
cnt -= 1
return result

Share