【模板集锦】自定义比较函数:既有额外信息又有复杂控制逻辑的情况

  |  

摘要: 自定义比较函数:C++/Python,排序/取最值/堆

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


今天我们看一个比较基础的题,leetcode第524题,通过删除字母匹配到字典里最长单词。算法就是先将字典中所有字符串按长度和字典序进行双关键字排序,然后从大到小枚举,看是否是 s 的子序列。第一步是一个自定义排序问题,第二步是一个子序列匹配问题。这两个都是常规问题常规算法。

不过亮点是如果把判定子序列的部分塞进自定义排序的比较逻辑里面,就可以直接对数组取最大值就可以解决了。只是取最大值用到的比较函数就复杂了,既需要额外的信息(查询的串),也有复杂的控制逻辑(子序列匹配)。

因此本文我们以此题为例,看一下自定义比较逻辑时既有额外信息又有复杂控制逻辑,如何通过自定义比较函数来实现,以及这种自定义比较函数再排序、堆、取最值这三种场景下分别怎么用。

题目描述

给你一个字符串 s 和一个字符串数组 dictionary 作为字典,找出并返回字典中最长的字符串,该字符串可以通过删除 s 中的某些字符得到。

如果答案不止一个,返回长度最长且字典序最小的字符串。如果答案不存在,则返回空字符串。

提示:

1
2
3
4
1 <= s.length <= 1000
1 <= dictionary.length <= 1000
1 <= dictionary[i].length <= 1000
s 和 dictionary[i] 仅由小写英文字母组成

样例

示例 1:
输入:s = “abpcplea”, dictionary = [“ale”,”apple”,”monkey”,”plea”]
输出:”apple”
示例 2:
输入:s = “abpcplea”, dictionary = [“a”,”b”,”c”]
输出:”a”

算法:自定义排序 + 子序列匹配

首先将 dictionary 中的单词排序,排序规则是如果长度相等,则字典序小的放前面;如果长度不等,则长度大的放前面。这种自定义排序在 leetcode 还是很常见的,下面这篇文章中整理了一些题,有时间可以集中刷:自定义排序

dictionary 排序后,按顺序枚举 dictionary 中的单词 word,检查 word 是否是 s 的子序列。判定子序列的问题作为子问题在 leetcode 有题: 392. 判断子序列。在文章 子序列匹配 中给出了子序列判定的两种算法:动态规划和贪心+双指针,并总结了子序列匹配的几个变种题目与常见方法。后面的代码中用的是贪心 + 双指针算法。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
struct Cmp
{
Cmp(){}
bool operator()(const string& word1, const string& word2) const
{
if(word1.size() == word2.size())
return word1 < word2;
return word1.size() > word2.size();
}
};

class Solution {
public:
string findLongestWord(string s, vector<string>& d) {
if(s.empty() || d.empty()) return "";
Cmp cmp;
sort(d.begin(), d.end(), cmp);
for(const string& word: d)
if(check(s, word))
return word;
return "";
}

private:
bool check(const string& s, const string& word) const
{
// word 是否可以通过 origin 删除某些字符得到
int n = s.size();
int m = word.size();
if(m > n) return false;
int i = 0, j = 0;
while(i < n && j < m)
{
if(s[i] == word[j])
++j;
++i;
if(m - j > n - i)
return false;
}
return j == m;
}
};

代码 (Python)

关于 Python3 中自定义比较函数的写法,可以参考这篇文章 Python3自定义排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from functools import cmp_to_key

class Cmp:
def __call__(self, word1, word2):
# 如果 word1 < word2 则返回 True
# 此函数定义什么叫 word1 < word2
if len(word1) == len(word2):
if word1 < word2:
return -1
elif word1 > word2:
return 1
else:
return 0
if len(word1) > len(word2):
return -1
elif len(word1) < len(word2):
return 1
else:
return 0

class Solution:
def findLongestWord(self, s: str, dictionary: List[str]) -> str:
if len(s) == 0 or len(dictionary) == 0:
return ""
mycmp = Cmp()
dictionary.sort(key=cmp_to_key(mycmp))
for word in dictionary:
if self.check(s, word):
return word
return ""

def check(self, s, word):
# 判断 word 是否是 s 的子序列
# 这里的算法用双串单向双指针
n = len(s)
m = len(word)
if m > n:
return False
i, j = 0, 0
while i < n and j < m:
if s[i] == word[j]:
j += 1
i += 1
if m - j > n - i:
return False
return j == m

模板集锦:自定义比较函数

下面我们把判断子序列的 check 函数的逻辑塞进自定义排序的比较逻辑中。形成一个复杂的自定义比较函数,其中既有外部信息(查询的字符串 s),又有复杂的控制逻辑(子序列匹配)。

这种情况要给原对象(字典中的字符串)直接定义小于号就不好解决了,因为有外部信息。而通过自定义函数(可调用对象)的方式可以解决,并且在排序、堆、取最值中都可以使用。

在自定义比较函数中,比大小的逻辑是三个关键字:

  1. 能否匹配上 s:匹配不上的更小。
  2. 长度:长度短的更小。
  3. 字典序:字典序大的更小。

C++ 定义 Cmp 结构体

在 Cmp 中,若 word1 比 word2 小,则返回 true。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
struct Cmp
{
string origin;
Cmp(const string& origin):origin(origin){}
bool operator()(const string& word1, const string& word2) const
{
bool flag1 = check(word1), flag2 = check(word2);
if(flag1 == flag2)
{
if(word1.size() == word2.size())
return word1 > word2;
return word1.size() < word2.size();
}
return !flag1 && flag2;
}

bool check(const string& word) const
{
// word 是否可以通过 origin 删除某些字符得到
int n = origin.size();
int m = word.size();
if(m > n) return false;
int i = 0, j = 0;
while(i < n && j < m)
{
if(origin[i] == word[j])
++j;
++i;
if(m - j > n - i)
return false;
}
return j == m;
}
};

取最值

如果 Cmp 中,word1 比 word2 小时返回 true。则 max_element 是取最大、min_element 是取最小。

1
2
3
4
5
6
7
8
9
10
11
class Solution {
public:
string findLongestWord(string s, vector<string>& d) {
if(s.empty() || d.empty()) return "";
Cmp cmp(s);
auto it = max_element(d.begin(), d.end(), cmp);
if(cmp.check(*it))
return *it;
return "";
}
};

排序

如果 Cmp 中,word1 比 word2 小时返回 true,则 sort 为从小到大排序。

1
2
3
4
5
6
7
8
9
10
11
12
class Solution {
public:
string findLongestWord(string s, vector<string>& d) {
if(s.empty() || d.empty()) return "";
Cmp cmp(s);
sort(d.begin(), d.end(), cmp);
string max_word = d.back();
if(cmp.check(max_word))
return max_word;
return "";
}
};

如果 HeapCmp 中,word1 比 word2 小时返回 true,则为最大堆。

这里由于是求最大值,把原数据数组堆化之后,直接取堆头即可,不用一个一个地插入。

这里有动态定义比较逻辑的问题,比较方便的右 std::function 和 lambda 表达式两种方案。

通用多态函数包装器 std::function

由于 HeapCmp 中持有运行时才能确定的状态,这里用 std::function 包装,使其可以作为模板参数传递。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
struct HeapCmp
{
Cmp cmp;
HeapCmp(Cmp cmp):cmp(cmp){}
bool operator()(const string& word1, const string& word2) const
{
return cmp(word1, word2);
}
};

class Solution {
public:
string findLongestWord(string s, vector<string>& d) {
if(s.empty() || d.empty()) return "";
Cmp cmp(s);

function<bool(string, string)> heapcmp = HeapCmp(cmp);

// 直接将 d 堆化
priority_queue<string, vector<string>, decltype(heapcmp)> pq(heapcmp, d);
/* 一个一个地插入,这里不需要
* priority_queue<string, vector<string>, decltype(heapcmp)> pq(heapcmp);
* for(const string& w: d)
* pq.push(w);
*/

string max_word = pq.top();
if(cmp.check(max_word))
return max_word;
return "";
}
};

lambda 表达式

lambda 表达式的优点是它非常灵活,可以根据需要动态地定义比较逻辑。

代码与上面使用 std::function 的代码有点像,可以理解为把 HeapCmp 中的比较逻辑放到了 lambda 表达式中,但是实例化 priority_queue 的代码没变,可以对比着看。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
public:
string findLongestWord(string s, vector<string>& d) {
if(s.empty() || d.empty()) return "";
Cmp cmp(s);

// 使用 lambda 创建一个比较函数对象,并捕获 cmp 实例
auto heapcmp = [cmp](const string& word1, const string& word2) {
// 比较逻辑依赖于 cmp
return cmp(word1, word2);
};

// 直接将 d 堆化
priority_queue<string, vector<string>, decltype(heapcmp)> pq(heapcmp, d);
/* 一个一个地插入,这里不需要
* priority_queue<string, vector<string>, decltype(heapcmp)> pq(heapcmp);
* for(const string& w: d)
* pq.push(w);
*/

string max_word = pq.top();
if(cmp.check(max_word))
return max_word;
return "";
}
};

Python 定义 Cmp 类

在 Cmp 中,若 word1 比 word2 小,则返回 -1。若 word2 比 word1 小,则返回 1,若相等则返回 0。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Cmp:
def __init__(self, origin):
self.origin = origin

def __call__(self, word1, word2):
# 如果 word1 < word2 则返回 True
# 此函数定义什么叫 word1 < word2
flag1 = self.check(word1)
flag2 = self.check(word2)
if flag1 == flag2:
if len(word1) == len(word2):
if word1 > word2:
return -1
elif word1 < word2:
return 1
else:
return 0
elif len(word1) < len(word2):
return -1
else:
return 1
elif not flag1 and flag2:
return -1
else
return 1

def check(self, word):
# 判断 word 是否是 origin 的子序列
# 这里的算法用双串单向双指针
n = len(self.origin)
m = len(word)
if m > n:
return False
i, j = 0, 0
while i < n and j < m:
if self.origin[i] == word[j]:
j += 1
i += 1
if m - j > n - i:
return False
return j == m

取最值

如果 Cmp 中当 word1 比 word2 小时返回 -1,则 max 为取最大值,min 为取最小值。

1
2
3
4
5
6
7
8
9
10
11
12
from functools import cmp_to_key

class Solution:
def findLongestWord(self, s: str, dictionary: List[str]) -> str:
if len(s) == 0 or len(dictionary) == 0:
return ""
mycmp = Cmp(s)
mykey = cmp_to_key(mycmp)
max_word = max(dictionary, key=mykey)
if mycmp.check(max_word):
return max_word
return ""

排序

如果 Cmp 中当 word1 比 word2 小时返回 -1,则 sort 为从小到大排序。

1
2
3
4
5
6
7
8
9
10
11
12
13
from functools import cmp_to_key

class Solution:
def findLongestWord(self, s: str, dictionary: List[str]) -> str:
if len(s) == 0 or len(dictionary) == 0:
return ""
mycmp = Cmp(s)
mykey = cmp_to_key(mycmp)
dictionary.sort(key=mykey)
max_word = dictionary[-1]
if mycmp.check(max_word):
return max_word
return ""

如果 HeapCmp 中当 word1 比 word2 小时返回 -1,则 heapq 为小顶堆。

下面代码中,HeapCmp 继承 Cmp,在比较时将 Cmp 中的结果取相反数,就形成了 word1 比 word2 小时返回 1,此时 heapq 为大顶堆。

这里由于是求最大值,把原数据数组堆化之后,直接取堆头即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from functools import cmp_to_key
import heapq

class HeapCmp(Cmp):
def __call__(self, word1, word2):
return -super().__call__(word1, word2)

class Solution:
def findLongestWord(self, s: str, dictionary: List[str]) -> str:
if len(s) == 0 or len(dictionary) == 0:
return ""

mycmp = HeapCmp(s)
mykey = cmp_to_key(mycmp)

# 直接将原数组数组堆化
heap_data = [mykey(word) for word in dictionary]
heapq.heapify(heap_data)

# 一个一个地插入,这里不需要
# heap_data = []
# for word in dictionary:
# heapq.heappush(heap_data, mykey(word))

max_word = heap_data[0].obj
if mycmp.check(max_word):
return max_word
return ""

Share