【随机算法】蓄水池抽样

  |  

摘要: 蓄水池抽样入门

【对数据分析、人工智能、金融科技、风控服务感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:潮汐朝夕
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


各位好,在文章 【随机算法】蒙特卡洛【随机算法】拒绝采样 中,我们了解了蒙特卡洛方法的基本思想,实现蒙特卡洛方法最重要的是根据需求正确地实现对给定分布的随机变量的采样,学习了直接采样和拒绝采样。

本文我们来看另一种采样方式:蓄水池抽样。我们首先给出蓄水池抽样的适用场景、算法流程,以及正确性证明。然后用蓄水池抽样解决力扣上的两道随机算法的题目,382 和 398。

蓄水池抽样

场景

给定一个数据流,数据流长度 N 很大,且 N 直到处理完所有数据之前都不可知。

如何在只遍历一遍数据 $O(N)$ 的情况下,能够随机选取出 m 个不重复的数据。

  • 数据流长度未知,可能很大,将全部数据读入内存的算法不可行。
  • 所有数据过一遍之后出解,时间复杂度为 $O(N)$。
  • 随机选取 m 个数,每个数被选中的概率为 m/N。

算法

解决以上问题的算法是蓄水池抽样,算法如下:

1
2
3
4
5
6
开长度为 m 的蓄水池
接收到第 i 个数据(i 从 0 开始):
若 i < m,则放入蓄水池
若 i >= m,在 [0, i] 范围内取随机数 d:
若 d 落在 [0, m-1] 范围内,则用接收到的第 i 个数据替换蓄水池中的第 d 个数据
若 d == m,跳过第 i 个数据

这里使用已知长度的数组来表示未知长度的数据流,并假设数据流长度大于蓄水池容量 m。

正确性证明

对于前 m 个数,data[0..m-1], p(data[0]) = p(data[1]) = ... = p(data[m-1]) = 1

对于第 m+1 个数,data[m],以 m/(m+1) 概率在本次保留下来,那么前 m 个数中的某个数 data[j], 0<=j<=m-1 在本轮被保留的概率为:

1
2
3
4
5
data[j] 在本轮被保留有两种情况 : 
data[j] 在上轮未被替换且本轮 data[m] 被丢弃;
data[m] 本轮被保留且 data[j] 本轮未被替换

p(data[j]) = 1/(m+1) + m/(m+1) * ((m-1)/m) = m/(m+1)

对于第 m+2 个数,data[m+1],以 m/(m+2) 的概率在本次保留下来,那么前 m+1 个数中的某个数 data[j] 在本轮被保留的概率为:

1
p(data[j]) = m/(m+1) * 2/(m+2) + m/(m+1) * ((m-1)/(m+2))

对于第 i 个数,以 m/i 概率在本轮保留,前 i - 1 个数中的某个数 data[j] 在本轮被保留的概率为:

1
p(data[j]) = m/(i-1) * (i-m)/i + m/(i-1) * (m-1) / i = m / i

Follow Up

下面这些 Follow Up,有时间可以学习一下。

  • 《编程珠玑》 $12
  • 分布式蓄水池抽样
  • 有放回地取 m 个: 用 m 个长为 1 的独立蓄水池
  • 样本有加权时的抽样: 加权蓄水池抽样算法

下面我们看两个例题,算法都是蓄水池抽样,算法流程与前面介绍的一样,因此后面的题目直接给出代码。

题目1

给定一个可能含有重复元素的整数数组,要求随机输出给定的数字的索引。 您可以假设给定的数字一定存在于数组中。

注意:

数组大小可能非常大。 使用太多额外空间的解决方案将不会通过测试。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution {
public:
Solution(vector<int>& nums) {
int random_seed = rand();
dre = std::default_random_engine(random_seed);
dr = std::uniform_real_distribution<double>(0.0, 1.0);
for(int j = 0; j < (int)nums.size(); ++j)
mapping[nums[j]].push_back(j);
}

int pick(int target) {
int m = 1;
int ans = -1;
const vector<int> &nums = mapping[target];
for(int i = 0; i < (int)nums.size(); ++i)
{
if(i < m)
ans = nums[i];
else
{
int random_idx = floor((i + 1) * dr(dre));
if(random_idx < m)
ans = nums[i];
}
}
return ans;
}

private:
std::default_random_engine dre;
std::uniform_real_distribution<double> dr;

unordered_map<int, vector<int>> mapping;
};

题目2

给定一个单链表,随机选择链表的一个节点,并返回相应的节点值。保证每个节点被选的概率一样。

进阶:

如果链表十分大且长度未知,如何解决这个问题?你能否使用常数级空间复杂度实现?

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution {
public:
Solution(ListNode* head) {
this -> head = head;
int random_seed = rand();
std::default_random_engine dre(random_seed);
std::uniform_real_distribution<double> dr;
}

int getRandom() {
int i = 0;
ListNode *iter = head;
int ans = -1;
int m = 1;
while(iter != nullptr)
{
if(i < m)
ans = iter -> val;
else
{
// [0, i]
int random_idx = floor((i + 1) * dr(dre));
if(random_idx < m)
ans = iter -> val;
}
iter = iter -> next;
++i;
}
return ans;
}

private:
ListNode *head;
std::default_random_engine dre;
std::uniform_real_distribution<double> dr;
};

Share