Python标准库-lru_cache缓存

  |  

摘要: Python 标准库 functools 中的 lru_cache

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


functools 模块应用于高阶函数,即参数或返回值为其他函数的函数。 通常来说,此模块的功能适用于所有可调用对象。

该模块有很多功能,在文章 Python标准库-functools-partial 中,我们学习了 Partial 对象相关的内容。本文我们来看一下【缓存】这个话题。涉及的知识点如下

lru_cache() 修饰符将一个函数包装在一个【最近最少使用】的缓存中。函数的参数用于建立散列键。

后续如果有相同的参数,则会从这个缓存取值而不会再次调用函数。

此修饰符还会为函数增加以下两个方法:

cache_info(): 检查缓存状态
cache_clear(): 清空缓存

参考书(中英文版):

lru_cache() 例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import functools

@functools.lru_cache()
def expensive(a, b):
print("expensive({}, {})".format(a, b))
return a * b

MAX = 2

print("First set of calls:")
for i in range(MAX):
for j in range(MAX):
expensive(i, j)
print(expensive.cache_info())

print("\nSecond set of calls:")
for i in range(MAX + 1):
for j in range(MAX + 1):
expensive(i, j)
print(expensive.cache_info())

print("\nClearing cache:")
expensive.cache_clear()
print(expensive.cache_info())

print("\nThird set of calls:")
for i in range(MAX + 1):
for j in range(MAX + 1):
expensive(i, j)
print(expensive.cache_info())

在一组嵌套循环中执行多个函数调用。

第二次调用时有相同的参数值,结果存在缓存中。清空缓存并再次运行循环时,这些值需要重新计算。


为了避免一个长时间运行的进程导致缓存无限扩张,需要指定一个最大大小,默认为128个元素。可以用 maxsize 参数控制。

最大元素限制的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import functools

@functools.lru_cache(maxsize=2)
def expensive(a, b):
print("called expensive({}, {})".format(a, b))
return a * b

def make_call(a, b):
print("({}, {})".format(a, b), end=" ")
pre_hits = expensive.cache_info().hits
expensive(a, b)
post_hits = expensive.cache_info().hits
if post_hits > pre_hits:
print("cache hit")

print("Establish the cache")
make_call(1, 2)
make_call(2, 3)
print("\nUse cached items")
make_call(1, 2)
make_call(2, 3)
print("\nCompute a new value, triggering cache expiration")
make_call(3, 4)

print("\nCache still contains one old item")
make_call(2, 3)

print("\nOldest item needs to be recomputed")
make_call(1, 2)

缓存大小设为 2 个元素,使用第三组不同参数 (3, 4) 时,缓存中最老的元素会被清楚,用此新结果取代。

lru_cache() 管理的缓存中,键必须是可散列的,因此用缓存包装的函数,它的所有参数必须是可散列的

例题

1
2
576. 出界的路径数
https://leetcode.cn/problems/out-of-boundary-paths/

给你一个大小为 m x n 的网格和一个球。球的起始坐标为 [startRow, startColumn] 。你可以将球移到在四个方向上相邻的单元格内(可以穿过网格边界到达网格之外)。你 最多 可以移动 maxMove 次球。

给你五个整数 m、n、maxMove、startRow 以及 startColumn ,找出并返回可以将球移出边界的路径数量。因为答案可能非常大,返回对 1e9 + 7 取余 后的结果。

提示:

1
2
3
4
1 <= m, n <= 50
0 <= maxMove <= 50
0 <= startRow < m
0 <= startColumn < n

示例 1:
输入:m = 2, n = 2, maxMove = 2, startRow = 0, startColumn = 0
输出:6
示例 2:
输入:m = 1, n = 3, maxMove = 3, startRow = 0, startColumn = 1
输出:12

算法: 动态规划

1
2
3
4
5
6
7
8
9
10
11
12
13
14
状态定义
dp[i][j][k] := 位置 (i,j),最多移动 k 次的答案

答案
dp[start_i][start_j][N]

初始化
if(i < 0 || i >= m || j < 0 || j >= n)
return 1;
if(k == 0)
return 0;

状态转移
dp[i][j][k] = sum(dp[x][y][k - 1]) 其中 (x,y) 为与 (i,j) 相邻的四个新方向

下面用记忆化搜索实现这个动态规划算法,Python 的代码中用到了 lru_cache。

代码(C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
public:
int findPaths(int m, int n, int N, int i, int j) {
if(N == 0) return 0;
// i, j, k: m, n, N
vector<vector<vector<int> > > dp(m, vector<vector<int> >(n, vector<int>(N + 1, -1)));
return dfs(i, j, m, n, N, dp);
}

private:
int MOD = 1e9 + 7;

int dfs(int i, int j, int m, int n, int k, vector<vector<vector<int> > >& dp)
{
if(i < 0 || i >= m || j < 0 || j >= n)
return 1;
if(k == 0)
return 0;
if(dp[i][j][k] != -1)
return dp[i][j][k];

dp[i][j][k] = 0;
int dx[4] = {0, 1, 0, -1};
int dy[4] = {1, 0, -1, 0};
for(int d = 0; d < 4; ++d)
{
int x = i + dx[d];
int y = j + dy[d];
dp[i][j][k] = (dp[i][j][k] + dfs(x, y, m, n, k - 1, dp)) % MOD;
}
return dp[i][j][k];
}
};

代码(Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import functools

MOD = int(1e9 + 7)

class Solution:
def findPaths(self, m: int, n: int, maxMove: int, startRow: int, startColumn: int) -> int:
self.m = m
self.n = n
self.dx = [0, 1, 0, -1]
self.dy = [1, 0, -1, 0]
if maxMove == 0:
return 0
return self.dfs(startRow, startColumn, maxMove)

@functools.lru_cache(int(1e8))
def dfs(self, i, j, k):
if i < 0 or i >= self.m or j < 0 or j >= self.n:
return 1
if k == 0:
return 0
ans = 0
for d in range(4):
x = i + self.dx[d]
y = j + self.dy[d]
ans = (ans + self.dfs(x, y, k - 1)) % MOD
return ans

Share