树上差分

  |  

摘要: 树上差分算法

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


在文章 前缀和与差分,我们对一个序列定义了前缀和序列和差分序列,根据差分序列的前缀和序列是原序列,原序列区间上的增减转化为了前缀和序列上左端点加 1,右端点减 1。

类似地,在树上也可以做这样的简化,区间操作对应为路径操作,前缀和对应为子树和。本文我们就来学习这种树上差分的技巧。

树上差分

如果要对一系列树上的路径进行操作,比如将路径上所有点的权加 1,或者将路径上所有边的权加 1。操作完成后,问某个点或某条边的权是多少。

对于这类问题,可以使用树上差分来维护路径上的加法,与 差分维护区间加法 的思想一样。

再一次对路径的操作中,树上差分将路径上的重要节点进行修改,作为差分数组的值。最后在求值时,利用 DFS 求出差分数组的前缀和,进而得到要计算的原值。

树上差分的算法要用到 LCA,相关的内容参考 最近公共祖先问题,这里直接用结论和代码模板。

点差分与边差分的原理差不多,下面分别来看。


边差分

问题

给定一棵树,共 N 个节点,编号为 0 ~ N - 1,初始时边的权均为 0。

执行若干次操作,每次操作给定两个节点 u、v,将 uv 路径上所有边的权加 x。

某两个节点 u,v 的路径 uv 的权值和。

算法

对于图中的 n 个节点,建立差分数组 diff,diff[u] 可以视为节点 $u$ 的点权。

树上任一节点都有一个对应的深度,一条边连接的两个节点的深度又是不同的。如果一条边 $e = uv$,其中 $u$ 的深度较大,定义 $e$ 的边权为以较深的节点 $u$ 为根的子树中所有节点的 diff 值之和。这样就可以用子树的 diff 值之和来表示边权

如图,假如有一次操作是把 $u$ 到 $v$ 之间的路径全部加 $x$。在前面的定义下,就记 diff[u] += xdiff[v] += x。这样凡是子树中含节点 $u$ 或 $v$ 的节点,对应的子树 diff 值之和(对应一个边权)就会包含这个 $x$,于是我们只需要进行一次 DFS,过程中更新以每个节点为根的子树 diff 值之和(这里的子树和类似于前缀和)。

此时我们发现 $lca(u, v)$ 加了 $2x$,但 $lca(u, v)$ 这个节点对应的子树 diff 值之和表示的是 $lca(u, v)$ 与 $parent(lca(u, v))$ 的边权。这并不是路径 $uv$ 上的边,因此应该 diff[lca(y, v)] -= 2 * x。这样使得加 x 的效果只局限在 $u, v$,影响不会扩大到 $lca(u, v)$ 的父节点。

至此我们就可以总结一下,给定两节点 u, v,求路径 uv 的边权和,树上的边差分的完整算法:

1
2
3
4
5
6
7
8
9
定义差分数组 diff
枚举每个操作 (u, v, x),对差分数组执行以下操作
diff[u] += x
diff[v] += x
diff[lca(u, v)] -= 2 * x
在树上进行一次 DFS,算出每个节点对应的子树的 diff 值之和 sums。sums[u] 表示子树 u 的 diff 值之和
对查询 (u, v),进行以下操作:
枚举 u 到 lca(u, v) 路径的各个节点 node,ans += sums[node]
枚举 v 到 lca(u, v) 路径的各个节点 node,ans += sums[node]

点差分

问题

给定一棵树,共 N 个节点,编号为 0 ~ N - 1,初始时点的权均为 0。

执行若干次操作,每次操作给定两个节点 x、y,将 xy 路径上所有节点的权加 1。

询问某个节点 x 的权。

算法

与边差分的处理方式类似,还是定义一个差分数组 diff,diff[u] 可以视为另一个点权。节点 $u$ 的点权表示为以 $u$ 为根的子树的所有节点的 diff 值的和

如果有一次操作将 $uv$ 路径的所有节点的点权加 $x$,在前面的定一下,记 diff[u] += x, diff[v] += x,这样凡是子树中含节点 $u$, $v$ 的节点,对应的子树的 diff 值之和(对应一个点权)就会包含这个 $x$。于是只需要进行一次 DFS,过程中更新以每个节点为根的子树 diff 值之和即可(这里的子树和类似于前缀和)。

此时发现 $lca(u, v)$ 的点权按定义依然是加了 $2x$,但 $lca(u, v)$ 属于 $uv$ 路径上的节点,应该加 $x$,因此 diff[lca(u, v)] -= x 即可。此时 $lca(u, v)$ 的父节点的点权是加 $x$ 的,但它并不是 $uv$ 路径上的节点,因此需要 diff[parent[lca(u, v)]] -= x

至此我们就可以总结一下,给定节点 u,求 uv 的点权,树上的点差分的完整算法:

1
2
3
4
5
6
7
8
定义差分数组 diff 
枚举每个操作 (u, v, x),对差分数组执行以下操作
diff[u] += x
diff[v] += x
diff[lca(u, v)] -= x
diff[parent[lca(u, v)][0]] -= x
在树上进行一次 DFS,算出每个节点对应的子树的 diff 值之和 sums。sums[u] 表示子树 u 的 diff 值之和
对查询 u,返回 sums[u] 即可

模板题

一个牛棚有 N 个隔间,编号 1 ~ N,之间安装了 N - 1 根管道,所有隔间都连通。

有 K 条运输牛奶的路线,第 i 条路线从隔间 $s_{i}$ 运输到隔间 $t_{i}$。

一条路线会给它的两个端点处的隔间,以及中途经过的所有隔间增加 1 个单位的运输压力。

求压力最大的隔间的压力是多少。

说明:

1
2
3
4
5
6
7
8
9
10
输入格式
第一行输入两个整数 N 和 K。
接下来 N − 1 行每行输入两个整数 x 和 y,其中 x != y。表示一根在牛棚 x 和 y 之间的管道。
接下来 K 行每行两个整数 s 和 t,描述一条从 s 到 t 的运输牛奶的路线。

输出格式
一个整数,表示压力最大的隔间的压力是多少。

2 <= N <= 5e4
1 <= K <= 1e5

输入输出样例
输入
5 10
3 4
1 5
4 2
5 4
5 4
5 4
3 5
4 3
4 3
1 3
3 5
5 4
1 5
3 4
输出
9

代码 (模板,C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <iostream>
#include <vector>
#include <cmath>
#include <fstream>

using namespace std;

// ----- 树上倍增 -----

void get_parent(const vector<vector<int>>& g, int u, int prev, vector<int>& d, vector<int>& parent)
{
for(int v: g[u])
{
if(v == prev)
continue;
d[v] = d[u] + 1;
parent[v] = u;
get_parent(g, v, u, d, parent);
}
}

void get_fa(const vector<int>& parent, vector<vector<int>>& fa, const int N, const int M)
{
// fa[i][j] := 从 i 爬 2 ^ j 步所到点的 id
// fa[i][0] := 从 i 爬 1 不所到点的 id
for(int i = 1; i <= N; ++i)
fa[i][0] = parent[i];
for(int j = 1; j < M; ++j)
fa[0][j] = -1;
for(int j = 1; j < M; ++j)
for(int i = 1; i <= N; ++i)
{
if(fa[i][j - 1] == -1)
fa[i][j] = -1;
else
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}

// ----- LCA -----

int lowbit(int n)
{
return n & (-n);
}

int highbit(int n)
{
int p = lowbit(n);
while(p != n)
{
n -= p;
p = lowbit(n);
}
return p;
}

int lca(int x, int y, const vector<int>& d, const vector<vector<int>>& fa)
{
// d[x] >= d[y]
if(d[x] < d[y])
return lca(y, x, d, fa);
// 将 y 向上调整直到和 x 一个深度
int delta = d[x] - d[y];
while(delta > 0)
{
x = fa[x][log2(highbit(delta))];
delta -= highbit(delta);
}
if(x == y)
return x;
int M = fa[0].size();
while(true)
{
if(fa[x][0] == fa[y][0])
break;
int k = 0;
while(k <= M)
{
if(fa[x][k] == -1 || fa[y][k] == -1)
break;
if(fa[x][k] == fa[y][k])
break;
++k;
}
x = fa[x][k - 1];
y = fa[y][k - 1];
}
return fa[x][0];
}

// --- 树上差分 ---

void dfs(const vector<vector<int>>& g, int u, int prev, const vector<int>& diff, vector<int>& sums)
{
for(int v: g[u])
{
if(v == prev)
continue;
dfs(g, v, u, diff, sums);
sums[u] += sums[v];
}
sums[u] += diff[u];
}

int main()
{
int N, K;
cin >> N >> K;

vector<vector<int>> g(N + 1);
for(int i = 0; i < N - 1; ++i)
{
int x, y;
cin >> x >> y;
g[x].push_back(y);
g[y].push_back(x);
}

vector<int> d(N + 1);
int M = log2(N) + 1;
vector<int> parent(N + 1);
get_parent(g, 1, -1, d, parent); // 视 1 为根

vector<vector<int>> fa(N + 1, vector<int>(M));
get_fa(parent, fa, N, M);

vector<int> diff(N + 1);
for(int i = 0; i < K; ++i)
{
int s, t;
cin >> s >> t;
diff[s] += 1;
diff[t] += 1;
diff[lca(s, t, d, fa)] -= 1;
diff[parent[lca(s, t, d, fa)]] -= 1;
}

vector<int> sums(N + 1);
dfs(g, 1, -1, diff, sums);

int ans = 0;
for(int s: sums)
ans = max(ans, s);

cout << ans << endl;
}

Share