二分图匹配:最大匹配,匈牙利算法

  |  

摘要: 二分图最大匹配的算法:匈牙利算法

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


对于一个无向图 G,有 N 个节点(N >= 2),可以发分成 A, B 两个非空集合,其中 $A \cap B = \emptyset$,并且在同一集合内的点之间没有边相连,那么称该无向图为二分图。

在文章 二分图判定 中,我们用 BFS、DFS、并查集这三种方法实现了二分图判定。

本文我们来学习一下二分图的匹配问题。主要涉及一些基本概念和最大匹配的匈牙利算法。


二分图匹配相关概念

G 上的一个边集合,如果该边集合中任意两条边都没有公共端点,则该边集合称为二分图的一组匹配

包含边数最多的一组匹配,称为二分图的最大匹配


增广路的定义

对于任意一组匹配 S,属于 S 的边称为匹配边、不属于 S 的边称为非匹配边

匹配边的端点为匹配点、其它的点为非匹配点。

如果在二分图中存在一条连接两个非匹配点的路径 path,使得非匹配边与匹配边在 path 上交替出现,则 path 称为匹配 S 的增广路

增广路有以下性质

  1. 长度是奇数
  2. 路径上第 1, 3, 5, …, len 条边是非匹配边,第 2, 4, 6, …, len-1 是匹配边。

基于以上性质,如果我们把 path 上的所有边的状态取反,也就是原来是匹配边的,现在变为非匹配边,原来是非匹配边的,现在变为匹配边,那么新的边集 $S^{‘}$ 也是一组匹配,且匹配边数增加了 1。

因此可以得到图论:

二分图的一组匹配 S 是最大匹配 $\Leftrightarrow$ 图中不存在 S 的增广路。


匈牙利算法(增广路算法)

匈牙利算法用于计算二分图最大匹配。

1
2
3
step1: S 置为空集,即所有边都是非匹配边
step2: 寻找增广路 path,把路径上所有边的匹配状态取反,得到更大的匹配 S'
step3: 重复 step2 直至图中不存在增广路

算法的关键是如何寻找增广路,匈牙利算法的做法如下

1
2
3
4
5
二分图 G 的两个内部不存在边的点集 A, B
依次枚举 A 中的点 x,为其寻找一个 B 中的点匹配
枚举 B 中的点 y
若 y 本身是非匹配点,则 x ~ y 是长度为 1 的增广路
若 y 是匹配点,假设 x' ~ y 是匹配边。且 x' ~ y' 也是 G 中的一条边,则 x ~ y ~ x' ~ y' 是增广路

对于每个 A 中的点,最多遍历整个二分图 1 次,因此时间复杂度 $O(NM)$

该算法的正确性基于贪心策略: 对于每个 A 中的点 x,一旦成为匹配点,那么后续最多只会因为找到增广路而更换匹配的 B 中的点,而不会变回非匹配点。

这种贪心正确的正确性证明,可以找图论书籍中关于匹配的章节学习,这本《图论导引》就不错。

匈牙利算法的实现(模板)

整体框架是 DFS,递归地从 x 出发寻找增广路。若找到增广路,则回溯阶段把各个边的匹配状态取反。

  • 基于用 vector<vector> 的邻接表

其中邻接表部分可以参考文章 邻接表,这里直接给出代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// 邻接表
vector<vector<int>> g;

// 匈牙利算法
int ans = 0;
vector<int> match(N * N, -1);
vector<bool> visited;
for(int i = 0; i < N; ++i)
for(int j = 0; j < N; ++j)
{
if(forbidden[i][j])
continue;
int u = i * N + j;
visited.assign(N * N, false);
// 此时 u 是非匹配点
visited[u] = true;
if(dfs(u, g, visited, match))
++ans;
}

bool dfs(int x, const vector<vector<int>>& g, vector<bool>& visited, vector<int>& match)
{
for(int y: g[x])
{
if(visited[y])
continue;
visited[y] = true;
if(match[y] == -1 || dfs(match[y], g, visited, match))
{
match[y] = x;
return true;
}
}
return false;
}

下面是用数组模拟链表的邻接表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
const int E_SIZE = 8e4; // 总边数
const int V_SIZE = 2e4; // 总点数

int head[V_SIZE]; // head[i] 的值是 ver 下标,相当于链表节点指针
int ver[E_SIZE]; // 边的终点,相当于链表节点的 v 字段
// int edge[E_SIZE]; // 边的权重,相当于链表节点的 w 字段
int next_[E_SIZE]; // 相当于链表节点的 next 字段

int tot;
// tot 表示 node 数组(这里是 ver 和 next)已使用的最右位置,而不是链表长度

void init()
{
tot = 0;
memset(head, -1, sizeof(head));
memset(next_, -1, sizeof(next_));
}

void add(int u, int v)
{
// 增加有向边 (u, v),这里没有边权
ver[++tot] = y;
// edge[tot] = w;
next_[tot] = head[x];
head[x] = tot; // 在表头 x 处插入
}

// 访问从 u 出发的所有边
for(int i = head[u]; i != -1; i = next_[i])
{
int v = ver[i];
// int w = edge[i];
// 找到有向边 (x, y) 其权值为 w
}

下面是匈牙利算法 dfs 过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
bool visited[V_SIZE];
int match[V_SIZE];

bool dfs(int u)
{
// 返回是否找到增广路
for(int i = head[u]; i != -1; i = next_[i])
{
int v = ver[i];
if(visited[v])
continue;
visited[v] = true;
if(match[v] == -1 || dfs(match[v]))
{
match[v] = u; // 若找到增广路,回溯阶段把各个边的匹配状态取反
return true;
}
}
return false;
}

int ans = 0;
memset(match, -1, sizeof(match));
for(int u = 1; u <= n; ++i)
{
memset(visited, 0, sizeof(visited));
if(dfs(u))
++ans;
}

模板题: 棋盘覆盖

给定一个 N 行 N 列的棋盘,已知某些格子禁止放置。

求最多能往棋盘上放多少块的长度为 2、宽度为 1 的骨牌,骨牌的边界与格线重合(骨牌占用两个格子),并且任意两张骨牌都不重叠。

1
2
3
4
5
6
7
8
9
10
输入格式
第一行包含两个整数 N 和 t,其中 t 为禁止放置的格子的数量。
接下来 t 行每行包含两个整数 x 和 y,表示位于第 x 行第 y 列的格子禁止放置,行列数从 1 开始。

输出格式
输出一个整数,表示结果。

数据范围
1<=N<=100,
0<=t<=100

输入样例:
8 0
输出样例:
32

算法: 二分图最大匹配

二分图匹配有两个要点

  1. 节点能分成独立的两个集合,每个集合内部有 0 条边
  2. 每个节点只能与 1 条匹配的边相连

把实际问题抽象成二分图匹配问题时,要寻找具有以上两条性质的对象。

本题中,任意两张骨牌不重叠,也就是每个格子只能被一张骨牌覆盖,与要点 2 对应。而骨牌大小是 1 * 2,与要点 1 对应。因此,我们可以把棋盘上未被禁止的各自作为节点,骨牌作为无向边。

代码 (C++)

  • 使用模板1: vector<vector> 的邻接表
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <cstring>
#include <vector>
#include <iostream>
#include <fstream>

using namespace std;

bool dfs(int u, const vector<vector<int>>& g, vector<bool>& visited, vector<int>& match)
{
for(int v: g[u])
{
if(visited[u])
continue;
visited[u] = true;
if(match[v] == -1 || dfs(match[v], g, visited, match))
{
match[v] = u;
return true;
}
}
return false;
}

int main()
{
// 输入
int N, t;
cin >> N >> t;
vector<vector<bool>> forbidden(N, vector<bool>(N, false));
for(int i = 0; i < t; ++i)
{
int x, y;
cin >> x >> y;
forbidden[x - 1][y - 1] = true;
}

// 建图
int dx[4] = {1, -1, 0, 0};
int dy[4] = {0, 0, 1, -1};
vector<vector<int>> g(N * N);
for(int i = 0; i < N; ++i)
for(int j = 0; j < N; ++j)
{
if(forbidden[i][j])
continue;
int u = i * N + j;
for(int d = 0; d < 4; ++d)
{
int x = i + dx[d];
int y = j + dy[d];
if(x < 0 || x >= N || y < 0 || y >= N)
continue;
if(forbidden[x][y])
continue;
int v = x * N + y;
g[u].push_back(v);
}
}

// 匈牙利算法
int ans = 0;
vector<int> match(N * N, -1);
vector<bool> visited;
for(int i = 0; i < N; ++i)
for(int j = 0; j < N; ++j)
{
if(forbidden[i][j])
continue;
int u = i * N + j;
visited.assign(N * N, false);
// 此时 u 是非匹配点
visited[u] = true;
if(dfs(u, g, visited, match))
++ans;
}

cout << ans << endl;
}
  • 使用模板2: 用数组模拟链表的邻接表
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <cstring>
#include <iostream>
#include <fstream>

using namespace std;

// 邻接表
const int E_SIZE = 8e4; // 总边数
const int V_SIZE = 2e4; // 总点数

int head[V_SIZE]; // head[i] 的值是 ver 下标,相当于链表节点指针
int ver[E_SIZE]; // 边的终点,相当于链表节点的 v 字段
int next_[E_SIZE]; // 相当于链表节点的 next 字段

// tot 表示 node 数组(这里是 ver 和 next)已使用的最右位置,而不是链表长度
int tot;

void init()
{
tot = 0;
memset(head, -1, sizeof(head));
memset(next_, -1, sizeof(next_));
}

void add(int x, int y)
{
// 增加有向边 (x, y)
ver[++tot] = y;
next_[tot] = head[x];
head[x] = tot; // 在表头 x 处插入
}

// 匈牙利算法
bool visited[V_SIZE];
bool forbidden[V_SIZE];
int match[V_SIZE];

int key(int x, int y, int N)
{
// 棋盘坐标与图节点编号的映射
return (x - 1) * N + (y - 1);
}

bool dfs(int u)
{
// 返回是否找到增广路
for(int i = head[u]; i != -1; i = next_[i])
{
int v = ver[i];
if(visited[v])
continue;
visited[v] = true;
if(match[v] == -1 || dfs(match[v]))
{
match[v] = u; // 若找到增广路,回溯阶段把各个边的匹配状态取反
return true;
}
}
return false;
}

int main()
{
// 输入
int N, t;
cin >> N >> t;
memset(forbidden, false, sizeof(forbidden));
for(int i = 0; i < t; ++i)
{
int x, y;
cin >> x >> y;
forbidden[key(x, y, N)] = true;
}

// 建图
int dx[4] = {1, -1, 0, 0};
int dy[4] = {0, 0, 1, -1};
init();
for(int i = 1; i <= N; ++i)
for(int j = 1; j <= N; ++j)
{
int u = key(i, j, N);
if(forbidden[u])
continue;
for(int d = 0; d < 4; ++d)
{
int x = i + dx[d];
int y = j + dy[d];
if(x < 1 || x > N || y < 1 || y > N)
continue;
int v = key(x, y, N);
if(forbidden[v])
continue;
add(u, v);
}
}

int ans = 0;
memset(match, -1, sizeof(match));
for(int i = 1; i <= N; ++i)
for(int j = 1; j <= N; ++j)
{
int u = key(i, j, N);
if(forbidden[u])
continue;
memset(visited, false, sizeof(visited));
// 此时 u 是非匹配点
if(dfs(u))
++ans;
}
cout << ans / 2 << endl;
}

Share