从递归/搜索出发到记忆化搜索

  |  

摘要: 原问题的解依赖子问题的解 -> 先写出递归算法 -> 然后发现有很多重复子问题 -> 增加记忆化

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


记忆化搜索是一种结合了搜索和动态规划的优点的算法。相应地,理解记忆化搜索算法就有搜索(递归)和动态规划两条路线。

(1)有时我们发现待解决的问题呈现出问题的解依赖于同一个问题的更小实例的解这样的特性,与递归的思想吻合。分析好子问题可以直接解决无需递归的情况,以及如何从子问题结果组装成原问题结果,就可以写出递归的算法解决问题了。但有时候会遇到子问题重复计算的情况,时间复杂度一般是指数型,此时可以开一个备忘录数组,记录子问题的解,形成记忆化搜索

(2)有时我们发现待解决的问题除了问题的解依赖于同一个问题的更小实例的解这种递归特性,还进一步地有最优子结构和重复子问题的特性。此时我们就会直接考虑用动态规划解决,设计好状态,再列出状态转移方程就可以写出动态规划算法了。但是会遇到两个问题,一个是动态规划过程要遍历所有状态,而其中可能有一些是推出最终结果过程中始终没有用到的状态,用记忆化搜索就可以排除掉这些没有用到的状态,更进一步用搜索的话还有可能通过剪枝去掉大量无效状态;另一个是状态的设计、状态转移的方向,状态转移方程等动态规划的要素很多时候很难想,但是用搜索就很容易想,在搜索过程中用记忆化的方式把重复子问题做个记录,这就形成了类似于前面的(1)的思考路线。

总结一下,记忆化搜索可以理解为搜索的形式加上动态规划的思想,结合了搜索的容易想、可以避免对结果无影响的无效状态的计算、可能剪枝的好处,以及动态规划的避免重复子问题的好处。

本文我们通过一个题目,来看一下首先写出递归算法,然后发现有重复子问题,进而增加记忆化形成记忆化搜索算法的过程。

题目

使用下面描述的算法可以扰乱字符串 s 得到字符串 t :

  1. 如果字符串的长度为 1 ,算法停止
  2. 如果字符串的长度 > 1 ,执行下述步骤:
  • 在一个随机下标处将字符串分割成两个非空的子字符串。即,如果已知字符串 s ,则可以将其分成两个子字符串 x 和 y ,且满足 s = x + y 。
  • 随机 决定是要「交换两个子字符串」还是要「保持这两个子字符串的顺序不变」。即,在执行这一步骤之后,s 可能是 s = x + y 或者 s = y + x 。
  • 在 x 和 y 这两个子字符串上继续从步骤 1 开始递归执行此算法。

给你两个 长度相等 的字符串 s1 和 s2,判断 s2 是否是 s1 的扰乱字符串。如果是,返回 true ;否则,返回 false 。

提示:

1
2
3
s1.length == s2.length
1 <= s1.length <= 30
s1 和 s2 由小写英文字母组成

示例 1:
输入:s1 = “great”, s2 = “rgeat”
输出:true
解释:s1 上可能发生的一种情形是:
“great” —> “gr/eat” // 在一个随机下标处分割得到两个子字符串
“gr/eat” —> “gr/eat” // 随机决定:「保持这两个子字符串的顺序不变」
“gr/eat” —> “g/r / e/at” // 在子字符串上递归执行此算法。两个子字符串分别在随机下标处进行一轮分割
“g/r / e/at” —> “r/g / e/at” // 随机决定:第一组「交换两个子字符串」,第二组「保持这两个子字符串的顺序不变」
“r/g / e/at” —> “r/g / e/ a/t” // 继续递归执行此算法,将 “at” 分割得到 “a/t”
“r/g / e/ a/t” —> “r/g / e/ a/t” // 随机决定:「保持这两个子字符串的顺序不变」
算法终止,结果字符串和 s2 相同,都是 “rgeat”
这是一种能够扰乱 s1 得到 s2 的情形,可以认为 s2 是 s1 的扰乱字符串,返回 true

示例 2:
输入:s1 = “abcde”, s2 = “caebd”
输出:false

示例 3:
输入:s1 = “a”, s2 = “a”
输出:true

题解

算法:递归/搜索

首先判断两个字符串 s1 和 s2 使用的字符及其个数是否完全相同。如果不相同,则 s1 肯定不是 s2 的扰乱字符串。

记 s1 和 s2 的长度均为 n,如果 s1 和 s2 互为扰乱字符串,则存在某个位置 i,使得以下两种情况之一成立:

  • s1[0..i-1]s2[0..i-1] 互为扰乱字符串、s1[i..n-1]s2[i..n-1]` 互为扰乱字符串。
  • s1[0..i-1]s2[n-i..n-1] 互为扰乱字符串、s1[i..n-1]s2[0..n-i-1]` 互为扰乱字符串。

dfs(l1, r1, l2, r2) 表示 s1[l1..r1]s2[l2..r2] 是否构成扰乱字符串。我们要解决的是 dfs(0, n - 1, 0, n - 1),这样 dfs(l1, r1, l2, r2) 就是子问题。因此可以考虑用递归求解。

从递归的角度看,我们还需要边界条件。而本题中递归的边界条件是长度为 1 的情况,也就是 r1 - l1 + 1 == 1 的情况,如果 s1[l1] == s2[l2] 则为 true,否则返回 false。

这样我们就走通了将原问题拆分成子问题,递归地求解子问题,由子问题的解组装原问题的解这一流程。

也可以从搜索的角度理解以上流程,子问题相当于搜索过程持有的状态集合,状态集合的一种取值对应于一种子问题

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
public:
bool isScramble(string s1, string s2) {
return solve(s1, 0, s1.size() - 1, s2, 0, s2.size() - 1);
}

bool solve(const string& s1, int l1, int r1, const string& s2, int l2, int r2)
{
if(r1 - l1 != r2 - l2)
return false;
int n = r1 - l1 + 1;
if(n == 1)
{
if(s1[l1] == s2[l2])
return true;
else
return false;
}
for(int i = 1; i < n; ++i)
{
if(solve(s1, l1, l1 + i - 1, s2, l2, l2 + i - 1)
&& solve(s1, l1 + i, r1, s2, l2 + i, r2))
return true;
if(solve(s1, l1, l1 + i - 1, s2, r2 - i + 1, r2)
&& solve(s1, l1 + i, r1, s2, l2, r2 - i))
return true;
}
return false;
}
};

算法:记忆化搜索

我们会发现以上递归求解子问题的过程中会遇到很多重复子问题,也就是搜索过程中会遇到很多已经搜索过的状态集合。

因此可以使用记忆化搜索,其备忘录 dp[s] 就是记录状态集合 s 的信息,有两种常见的记法:

  • dp[s] 表示状态集合 s 表示的子问题的最优解,从子问题的最优解可以推出原问题的最优解,也就是具有最优子结构,这样记忆化搜索可以理解为动态规划。
  • dp[s] 表示状态集合 s 是否已经搜索过,如果之前搜索过,则直接从 dfs(s) 返回,这样可以将记忆化搜素理解为剪枝。

由于 n 的范围为 0 ~ 30,因此 (l1, r1) 可以用一个数 l1 * 100 + r1 表示,同样地 (l2, r2) 可以用 l2 * 100 + r2。将这两个四位数拼接到一起,形成一个八位数,依然可以用一个 int 表示,因此可以用一个数 (l1 * 100 + r1) * 10000 + (l2 * 100 + r2) 表示状态集合 s

代码 (C++)

这里我们用第二种记法,也就是 dp[s] 表示状态集合 s 是否已经搜索过。如果 s 搜索过,那么 s 的结果就不是 true,因为如果是 true,那么在上一次搜索到 s 时就已经结束搜索了,因此直接返回 false 即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
bool isScramble(string s1, string s2) {
dp = unordered_set<int>();
return solve(s1, 0, s1.size() - 1, s2, 0, s2.size() - 1);
}

unordered_set<int> dp;

bool solve(const string& s1, int l1, int r1, const string& s2, int l2, int r2)
{
if(r1 - l1 != r2 - l2)
return false;
int n = r1 - l1 + 1;
if(n == 1)
{
if(s1[l1] == s2[l2])
return true;
else
return false;
}
int s = (l1 * 100 + r1) * 10000 + (l2 * 100 + r2);
if(dp.count(s) > 0)
return false;
for(int i = 1; i < n; ++i)
{
if(solve(s1, l1, l1 + i - 1, s2, l2, l2 + i - 1)
&& solve(s1, l1 + i, r1, s2, l2 + i, r2))
return true;
if(solve(s1, l1, l1 + i - 1, s2, r2 - i + 1, r2)
&& solve(s1, l1 + i, r1, s2, l2, r2 - i))
return true;
}
dp.insert(s);
return false;
}
};

Share