树的遍历:祖先链上的统计

  |  

摘要: 树形前缀就是祖先链

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


各位好,今天我们看一个树的遍历问题。主要涉及到树上的祖先链的概念。对于树来说,祖先链可以理解为树形前缀,在文章 树形前缀和:在树的DFS过程中维护祖先链的和 中讨论过类似题目,可以参考。

题目

给定二叉树的根节点 root,找出存在于 不同 节点 A 和 B 之间的最大值 V,其中 V = |A.val - B.val|,且 A 是 B 的祖先。

(如果 A 的任何子节点之一为 B,或者 A 的任何子节点是 B 的祖先,那么我们认为 A 是 B 的祖先)

提示:

1
2
树中的节点数在 2 到 5000 之间。
0 <= Node.val <= 1e5

示例 1:

输入:root = [8,3,10,1,6,null,14,null,null,4,7,13]
输出:7
解释:
我们有大量的节点与其祖先的差值,其中一些如下:
|8 - 3| = 5
|3 - 7| = 4
|8 - 1| = 7
|10 - 13| = 3
在所有可能的差值中,最大值 7 由 |8 - 1| = 7 得出。

示例 2:

输入:root = [1,null,2,null,0,3]
输出:3

题解

算法: 树的遍历

我们要求的是二叉树中所有祖先链上的最大极差。因此我们需要在 DFS 的过程中,维护祖先链上的最大值和最小值

假设当前遍历到节点 $u$,我们要做下面几件事:

  • (1) 其中一个点定为 $u$,另一个点为 $u$ 的某棵子树上的点,可以形成的最大极差是多少。这个值是一个候选答案,在遍历过程中维护。
  • (2) 返回以 $u$ 为根的子树的最大值和最小值。用于支持回溯阶段上述 (1) 的计算。

这样在整个 DFS 走完以后,可以得到答案。

具体地,我们需要先递归地处理 $u$ 的各个子节点 $v$,返回对于以 $v$ 为根的子树来说的祖先链的最大值 $M_{v}$ 和最小值 $m_{v}$,所有子节点都返回后,得到的所有子树上的最大值和最小值分别记为 $M$ 和 $m$:

然后完成以下两步计算,对应上面的 (1)(2):

(1)

(2)

在实现时,可以先计算 (2) 中的 $M_{u}$ 和 $m_{u}$,然后用 $M_{u}$ 和 $m_{u}$ 代替 (1) 中的 $M$ 和 $m$,这样可以避免复杂的边界判断。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
public:
int maxAncestorDiff(TreeNode* root) {
int ans = 0;
_postOrder(root, ans);
return ans;
}

private:
using PII = pair<int, int>;

PII _postOrder(TreeNode* node, int& ans)
{
// 返回 {min, max}
PII left({INT_MAX, INT_MIN}), right({INT_MAX, INT_MIN});
if(node -> left)
left = _postOrder(node -> left, ans);
if(node -> right)
right = _postOrder(node -> right, ans);

int m = node -> val;
m = min(m, left.first);
m = min(m, right.first);
int M = node -> val;
M = max(M, left.second);
M = max(M, right.second);

ans = max(ans, abs(node -> val - m));
ans = max(ans, abs(M - node -> val));

return PII(m, M);
}
};

代码 (Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
INF = int(1e10)

class Solution:
def maxAncestorDiff(self, root: Optional[TreeNode]) -> int:
ans = 0

def dfs(node: int) -> List[int]:
left = right = [INF, -INF]
if node.left:
left = dfs(node.left)
if node.right:
right = dfs(node.right)

m = node.val
m = min(m, left[0])
m = min(m, right[0])
M = node.val
M = max(M, left[1])
M = max(M, right[1])

nonlocal ans
ans = max(ans, abs(node.val - m))
ans = max(ans, abs(M - node.val))

return [m, M]

dfs(root)
return ans

Share