PyMC3入门

  |  

摘要: PyMC3 入门知识,主要参考官网。

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


PyMC3 是一个做贝叶斯分析使用的 Python 库,运行速度快。本文我来学习以下这个框架的基本用法。

PyMC3参考文档如下,包括 API,例子,以及一些背景知识的资料。

下面是一些 PyMC3 的核心特性:

  • Friendly modelling API

PyMC3 allows you to write down models using an intuitive syntax to describe a data generating process.

  • Cutting edge algorithms and model building blocks

Fit your model using gradient-based MCMC algorithms like NUTS
using ADVI for fast approximate inference — including minibatch-ADVI for scaling to large datasets
using Gaussian processes to build Bayesian nonparametric models.

关于贝叶斯方法,简单来说就是下面这三步,PyMC3 完成的也是这三步工作。

step1: 了解数据,建立先验和似然
step2: 利用先验和似然在原始数据的基础上更新模型得到后验概率分布
step3: 使用后验概率更新先前的先验概率,循环迭代执行 step2,直至后验概率趋于收敛。


本文会介绍一些基础入门的 API,更多 API 参考 PyMC3 API,这里记录一下目录,以后方便查找。

  • Distributions
    • Continuous
    • Discrete
    • Multivariate
    • Mixture
    • Timeseries
    • Transformations of a random variable from one space to another.
    • Distribution utility classes and functions
  • Bounded Variables
    • Usage
    • Caveats
    • Bounded Variable API
  • Inference
    • Sampling
    • Variational Inference
  • Generalized Linear Models
  • Gaussian Processes API
    • Implementations
    • Mean Functions
    • Covariance Functions
  • Plots
  • Stats
  • Backends
    • Selecting values from a backend
    • Loading a saved backend
    • ndarray
    • tracetab
  • Math
  • Data
  • Model
  • Graphing Models
  • Random Variables
  • shape_utils
  • ODE

$0 安装 PyMC3

PyMC3 机器学习库,基于Theano, NumPy, SciPy, Pandas, 和 Matplotlib。由于依赖很多,建议新建虚拟环境安装 PyMC3。

1
2
conda create -n pymc3 python=3.9 pip ipython jupyter
pip install pymc3 -i https://pypi.tuna.tsinghua.edu.cn/simple

另外解释以及可视化后验分布时,还需要用到 arviz 这个库,pip 安装即可。

1
pip install arviz

$1 PyMC3 通用 API 快速入门

1
2
3
4
5
6
7
8
9
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import theano.tensor as tt

warnings.simplefilter(action="ignore", category=FutureWarning)

首先看一下 PyMC3 和 ArviZ 的版本。

1
2
3
az.style.use("arviz-darkgrid")
print(f"Running on PyMC3 v{pm.__version__}")
print(f"Running on ArviZ v{az.__version__}")

结果如下

1
2
Running on PyMC3 v3.11.5
Running on ArviZ v0.12.1

$1-1 创建模型

PyMC3 中的模型通过 Model 类来创建。创建的实例可以访问所有随机变量(RVs),计算模型的 logp(对数似然),以及 logp 的梯度

1
2
with pm.Model() as model:
# 模型定义

例如下面这个模型:

1
2
3
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=1, observed=np.random.randn(100))

我们可以查看模型中的 RVs,分为 free RVs 和 observed RVs 两种

1
2
3
print(model.basic_RVs)
print(model.free_RVs)
print(model.observed_RVs)

打印结果如下:

1
2
3
[mu ~ Normal, obs ~ Normal]
[mu ~ Normal]
[obs ~ Normal]

logp 是模型实例的一个方法,返回该模型下,mu 为 0 的似然的自然对数

1
model.logp({"mu": 0})

返回 array(-132.59114888)

$1-2 概率分布

在概率编程中,我们要考虑的是可以直接观察的随机变量,以及不可直接观察的随机变量

  • 可以直接观察的随机变量(Observed RVs)通过似然分布定义。
  • 不可直接观察的随机变量(Unobserved RVs)通过先验分布定义。

关于 PyMC3 中的概率分布,还可以看这个 Tutorial Probability_Distributions

(1) 例子: pm.Normal

这里通过 pm.Normal 看一下概率分布的实例都有哪些方法

pm.Normal 多元正态分布,在 pymc3.distributions.continuous 中,它的 PDF 如下:

其中 $\tau$ 与标准差 $\sigma$ 的转换关系如下:

下面我们画一下正态分布在几种 $\mu$ 和 $\sigma$ 下的 PDF 的图像。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st

plt.style.use("seaborn-darkgrid")

x = np.linspace(-5, 5, 1000)
mus = [0, 0, 0, -2]
sigmas = [0.4, 1, 2, 0.4]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

for mu, sigma in zip(mus, sigmas):
pdf = st.norm.pdf(x, mu, sigma)
ax.plot(x, pdf, label=r"$\mu$ = {}, $\sigma$ = {}".format(mu, sigma))

ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("f(x)", fontsize=12)
ax.legend(loc=1)
plt.show()

有了以上关于正态分布的基础知识,我们看一下 pm.Normal 的参数,这些参数的含义可以与前面的公式对应起来:

  • mu: 均值
  • sigma: 标准差 (sigma > 0),反映分散程度
  • tau: 精度 (tau > 0),反映聚集程度

下面我们看看用以这些参创建的 pm.Normal 实例有哪些方法:

  • logp(value): 计算正态分布在 value 下的对数似然。
  • random(point=None, size=None): 从正态分布中取随机数,point 是条件,size 是抽样的个数。

除此之外还有很多其它方法,其中有一些是从 pymc3.distributions.distribution.Distribution 继承来的,具体可以将 pm.Normal 实例传进 dir() 查看。

(2) Unobserved RV

Unobserved RV 的意思是该变量是不会被观测到的,会被拟合算法改变。定义方式是 name(str), parameter keyword arguments,例如对于正态先验,我们可以向下面这样定义:

1
2
with pm.Model() as model:
x = pm.Normal("x", mu=0, sigma=1)

与模型实例的 logp 方法差不多,我们可以输出 x = 0 的对数似然

1
x.logp({"x": 0})

返回 array(-0.91893853)

如果要求某个值的 CDF,可以用 x.distribution.logcdf,例如下面这行代码是求 0.0 的 CDF,也就是密度函数在 (-inf, 0.0] 上的积分。

1
x.distribution.logcdf(0.0).eval()

(3) Observed RV

Observed RV 的意思是该变量是被观测到的,不会被拟合算法改变。其定义与 Unobserved RV 一样,只是需要传数据给 observed 参数:

1
2
with pm.Model():
obs = pm.Normal("x", mu=0, sigma=1, observed=np.random.randn(100))

与 Unobserved RV 一样,可以使用 obs.logpobs.distribution.logcdf

1
2
obs.logp({"x": 0.0})
obs.distribution.logcdf(0.0).eval()

(4) 确定性的变换 pm.Deterministic

在 PyMC3 中我们可以对 RVs 做代数运算。

1
2
3
4
5
6
7
with pm.Model():
x = pm.Normal("x", mu=0, sigma=1)
y = pm.Gamma("y", alpha=1, beta=1)
plus_2 = x + 2
summed = x + y
squared = x ** 2
sined = pm.math.sin(x)

上面这种写法,结果不会自动保存,如果要追踪变换后的变量,我们必须用 pm.Deterministic

1
2
3
with pm.Model():
x = pm.Normal("x", mu=0, sigma=1)
plus_2 = pm.Deterministic("x plus 2", x + 2)

这样 PyMC3 就会跟踪 plus_2 了。

(5) Bounded RVs 的自动变换

为了对模型的采样更高效,PyMC3 自动将 bounded RVs 变换为 Unbounded,例如:

1
2
with pm.Model() as model:
x = pm.Uniform("x", lower=0, upper=1)

上面这个 x 是限定范围 [0, 1] 的,但是我们通过 model.free_RVs 查看,结果如下:

1
[x_interval__ ~ TransformedDistribution]

这里的 x_interval__ 是从 x 转换的范围在 (-inf, +inf) 的变量。默认的变换方式是对数几率,公式如下:

未经变换的 Bounded RVs 也是会追踪的。可以通过 model.deterministics 查看,结果如下:

1
[x ~ Uniform]

如果不想要这种内部的变换,可以在创建分布的时候传入 transform=None

1
2
3
4
with pm.Model() as model:
x = pm.Uniform("x", lower=0, upper=1, transform=None)

print(model.free_RVs)

(6) 通过 transform 参数传入自定义变换

之前提到传入 transform=None 后,可以取消对限定范围的随机变量自动变换,transform 这个参数还可以传入自定义的变换方法

pymc3.distributions.transforms 中有一些定义好的变换函数,可以传给 transform,例如下面的例子

1
2
3
4
5
6
7
8
9
10
import pymc3.distributions.transforms as tr

with pm.Model() as model:
# use the default log transformation
x1 = pm.Gamma("x1", alpha=1, beta=1)
# specify a different transformation
x2 = pm.Gamma("x2", alpha=1, beta=1, transform=tr.log_exp_m1)

print("The default transformation of x1 is: " + x1.transformation.name)
print("The user specified transformation of x2 is: " + x2.transformation.name)

打印结果如下

1
2
The default transformation of x1 is: log
The user specified transformation of x2 is: log_exp_m1

PyMC3 不提供显式的方法将一个分布转换为另一个分布。但是可以通过传入反变换给 transform 参数。

具体地定义 transform 时,可以继承 pymc3.distributions.transforms 中的 ElemwiseTransform 这个类,然后实现 forward, backwardjacobian_det 这三个方法。

下面是一个 LogNormal 的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import pymc3 as pm
import pymc3.distributions.transforms as tr
import matplotlib.pyplot as plt
import theano.tensor as tt

class Exp(tr.ElemwiseTransform):
name = "exp"

def backward(self, x):
return tt.log(x)

def forward(self, x):
return tt.exp(x)

def jacobian_det(self, x):
return -tt.log(x)

with pm.Model() as model:
x1 = pm.Normal("x1", 0.0, 1.0, transform=Exp())
x2 = pm.Lognormal("x2", 0.0, 1.0)

lognorm1 = model.named_vars["x1_exp__"]
lognorm2 = model.named_vars["x2"]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
x = np.linspace(0.0, 10.0, 100)
ax.plot(x
,np.exp(lognorm1.distribution.logp(x).eval())
,"--"
,alpha=0.5
,label="log(y) ~ Lognormal(0, 1)"
)
ax.plot(
x,
np.exp(lognorm2.distribution.logp(x).eval()),
alpha=0.5,
label="y ~ Lognormal(0, 1)",
)
ax.legend()
plt.show()

上面的代码中,名称为 x1_exp__ 的变量即为 LogNormal 分布的。

transform 还可以传入又先后顺序的若干个变换,通过 tr.Chain 包装一下即可传入。

下面是一个例子,通过 ordered 和 logodds 两个变换组成 chain,实现 2D RVs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import pymc3.distributions.transforms as tr
import matplotlib.pyplot as plt
import pymc3 as pm

Order = tr.Ordered()
Logodd = tr.LogOdds()
chain_tran = tr.Chain([Logodd, Order])

with pm.Model() as m0:
x = pm.Uniform("x", 0.0, 1.0, shape=2, transform=chain_tran, testval=[0.1, 0.9])
trace = pm.sample(5000, tune=1000, progressbar=False, return_inferencedata=False)

_, ax = plt.subplots(1, 2, figsize=(10, 5))
for ivar, varname in enumerate(trace.varnames):
ax[ivar].scatter(trace[varname][:, 0], trace[varname][:, 1], alpha=0.01)
ax[ivar].set_xlabel(varname + "[0]")
ax[ivar].set_ylabel(varname + "[1]")
ax[ivar].set_title(varname)
plt.tight_layout()

(7) RVs 的列表/高维 RVs/随机向量 — shape

前面我们看了如何创建标量 RVs,其中涉及到下面的一些细节

  • Observed RVs 和 Unobserved RVs 的问题
  • 用 pm.Deterministic 跟踪对 RVS 做代数运算后的随机变量
  • Bounded RVs 自动变换为 (-inf, inf) 的随机变量
  • transform 传入自定义变换,可以是 PyMC3 自带的,也可以继承 ElemwiseTransform 后实现接口自定义,可以用 Chain 整合又先后顺序的多个变换。

很多模型中,是需要多个 RVs 的,此时我们可以通过 chape 参数传入,例如

1
2
3
with pm.Model() as model:
# good:
x = pm.Normal("x", mu=0, sigma=1, shape=10)

这样 x 是一个随机向量,可以对它做线性代数操作,例如

1
2
3
with model:
y = x[0] * x[1] # full indexing is supported
x.dot(x.T) # Linear algebra is supported

(8) 初始化 — testval

PyMC3 会自动初始化,我们可以通过 x.tag.test_value 查看。也可以自定义,在创建随机变量时候通过 testval 参数穿进去即可。

1
2
3
4
with pm.Model():
x = pm.Normal("x", mu=0, sigma=1, shape=5, testval=np.random.randn(5))

print(x.tag.test_value)

$2 推理

在模型建立后,我们需要通过执行推理来估计后验分布

PyMC3 支持两种推理的大类:采样和变分推断

$2-1 采样 — pm.sample()

可以通过 pm.sample() 执行 MCMC 采样算法,如果没有传入参数,会自动选择采样器以及自动初始化

pm.sample() 可以传入 return_inferencedata=True,这样 sample() 会返回 arviz.data.inference_data.InferenceData 对象,否则返回的是 pymc3.backends.base.MultiTrace 对象。

InferenceData 相比 MultiTrace 的优势是可以保存为文件,也可以从文件载入,并且可以持有元数据,例如日期、版本。

Arviz 是一个用于解释和可视化后验分布的框架,除了支持 PyMC3 以外,还支持 pystan(另一个贝叶斯推断的框架), 关于 ArziZ 更多内容可以看 ArviZ Quickstart

mp.sample 例子 — 从后验分布中采样

1
2
3
4
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=1, observed=np.random.randn(100))
idata = pm.sample(2000, tune=1500, return_inferencedata=True)

打印信息如下

1
2
3
4
5
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu]
Sampling 4 chains for 1_500 tune and 2_000 draw iterations (6_000 + 8_000 draws total) took 4 seconds. 4 chains, 0 divergences]

可以看出,对于连续型模型,PyMC3 会自动使用 NUTS 采样器,并且会给采样器寻找更好的初始参数。下面打印一下 idata 中的一些信息

1
2
idata.posterior.dims: Frozen({'chain': 4, 'draw': 2000})
idata.posterior["mu"].shape: (4, 2000)

结合 sample 的参数和打印结果,可以看到这里我们每个 chain 采样了 2000 个样本,以及 1500 轮迭代调整(tuning)参数。tuning 的数据默认不保留,如果要保留可以设 discard_tuned_samples=False

chains 个数默认会由 CPU 核的个数决定。我们还可以通过 cores 和 chains 参数设置,看下面的例子

1
2
3
4
5
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=1, observed=np.random.randn(100))

idata = pm.sample(cores=4, chains=6, return_inferencedata=True)

打印结果

1
2
3
4
5
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (6 chains in 4 jobs)
NUTS: [mu]
Sampling 6 chains for 1_000 tune and 1_000 draw iterations (6_000 + 6_000 draws total) took 3 seconds. 6 chains, 0 divergences]

这个例子中,我们设定 cores 为 4,chains 为 6,draw 和 tune 用默认值 1000。

1
2
print("idata.posterior.dims: {}".format(idata.posterior.dims))
print("idata.posterior[\"mu\"].shape: {}".format(idata.posterior["mu"].shape))

idata 信息的打印结果如下

1
2
idata.posterior.dims: Frozen({'chain': 6, 'draw': 1000})
idata.posterior["mu"].shape: (6, 1000)

如果我们要获取单个 chain 的数据,可以这样写

1
idata.posterior["mu"].sel(chain=1).shape

结果

1
idata.posterior["mu"].sel(chain=1).shape: (1000,)

PyMC3 中的各种 sampler

各种 sampler 在 pm.step_methods

1
list(filter(lambda x: x[0].isupper(), dir(pm.step_methods)))

结果如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
['BinaryGibbsMetropolis',
'BinaryMetropolis',
'CategoricalGibbsMetropolis',
'CauchyProposal',
'CompoundStep',
'DEMetropolis',
'DEMetropolisZ',
'DEMetropolisZMLDA',
'ElemwiseCategorical',
'EllipticalSlice',
'HamiltonianMC',
'LaplaceProposal',
'MLDA',
'Metropolis',
'MetropolisMLDA',
'MultivariateNormalProposal',
'NUTS',
'NormalProposal',
'PGBART',
'PoissonProposal',
'RecursiveDAProposal',
'Slice',
'UniformProposal']

除了 NUTS 外,最常用的 step-methods 是 Metropolis 和 Slice,对于大多数连续型模型 NUTS 是最好的。

pm.sample() 自定义 sampler 的代码如下

1
2
3
4
5
6
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=1, observed=np.random.randn(100))

step = pm.Metropolis()
trace = pm.sample(1000, step=step)

我们还可以对不同随机变量用不同的 step methods。看下面的例子

1
2
3
4
5
6
7
8
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
sd = pm.HalfNormal("sd", sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=sd, observed=np.random.randn(100))

step1 = pm.Metropolis(vars=[mu])
step2 = pm.Slice(vars=[sd])
idata = pm.sample(10000, step=[step1, step2], cores=4, return_inferencedata=True)
1
2
3
4
5
6
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Metropolis: [mu]
>Slice: [sd]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 10 seconds.chains, 0 divergences]
The number of effective samples is smaller than 25% for some parameters.

$2-2 分析采样结果 — arviz

下面的分析图均为前面的对不同随机变量用不同的 step methods例子,下面把这个模型的代码抄一遍

1
2
3
4
5
6
7
8
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
sd = pm.HalfNormal("sd", sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=sd, observed=np.random.randn(100))

step1 = pm.Metropolis(vars=[mu])
step2 = pm.Slice(vars=[sd])
idata = pm.sample(10000, step=[step1, step2], cores=4, return_inferencedata=True)

trace plot

最常用的分析采样结果的图

1
az.plot_trace(idata)

R-hat

另一个 MCMC 里面常用的 metric 是 R-hat, 也就是 Gelman-Rubin statistic,我们可以在 az.summary 来看

1
2
df = az.summary(idata)
print(df)
1
2
3
     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean   ess_sd  ess_bulk  ess_tail  r_hat
mu 0.017 0.100 -0.177 0.201 0.001 0.001 6240.0 4844.0 6223.0 5827.0 1.0
sd 1.001 0.072 0.873 1.141 0.000 0.000 37175.0 36446.0 38148.0 28721.0 1.0

forestplot

1
az.plot_forest(idata, r_hat=True)

后验分布

1
az.plot_posterior(idata)

energy plot

对于高维模型(shape参数),看所有参数的 traces 比较难受。如果用 NUTS 我们可以看 energy plot

1
2
3
4
5
6
with pm.Model() as model:
x = pm.Normal("x", mu=0, sigma=1, shape=100)
idata = pm.sample(cores=4, return_inferencedata=True)

az.plot_energy(idata);
plt.show()

$2-3 变分推断

PyMC3 支持多种变分推断技术,这种方法的特点是很快,但是会损失一些精度,使得预测有偏差。

可以通过 pymc3.fit() 来执行变分推断。

Approximation 对象

1
2
3
4
5
6
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
sd = pm.HalfNormal("sd", sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=sd, observed=np.random.randn(100))

approx = pm.fit()

返回的 Approximation 对象可以做很多事情,比如从近似的的后验分布中抽样

1
2
approxdata = approx.sample(500)
print(approxdata)

返回的 approxdata 是 MultiTrace 对象。

1
<MultiTrace: 1 chains, 500 iterations, 3 variables>

面向对象接口

以 full-rank ADVI 为例

1
2
3
4
5
mu = pm.floatX([0.0, 0.0])
cov = pm.floatX([[1, 0.5], [0.5, 1.0]])
with pm.Model() as model:
pm.MvNormal("x", mu=mu, cov=cov, shape=2)
approx = pm.fit(method="fullrank_advi")

相应的面向对象接口如下

1
2
3
with pm.Model() as model:
pm.MvNormal("x", mu=mu, cov=cov, shape=2)
approx = pm.FullRankADVI().fit()

采样 10000 个数据画 kde 图

1
2
3
plt.figure()
trace = approx.sample(10000)
az.plot_kde(trace["x"][:, 0], trace["x"][:, 1])

Stein Variational Gradient Descent(SVGD)

1
2
3
4
5
6
7
8
9
10
w = pm.floatX([0.2, 0.8])
mu = pm.floatX([-0.3, 0.5])
sd = pm.floatX([0.1, 0.1])
with pm.Model() as model:
pm.NormalMixture("x", w=w, mu=mu, sigma=sd)
approx = pm.fit(method=pm.SVGD(n_particles=200, jitter=1.0))

plt.figure()
trace = approx.sample(10000)
az.plot_dist(trace["x"]);

关于变分推断,需要其它资料学习。

$3 后验预测采样

$3-1 sample_posterior_predictive

sample_posterior_predictive() 函数在测试数据上执行预测。

1
2
3
4
5
6
7
8
9
10
data = np.random.randn(100)
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
sd = pm.HalfNormal("sd", sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=sd, observed=data)

idata = pm.sample(return_inferencedata=True)

with model:
post_pred = pm.sample_posterior_predictive(idata.posterior)

将后延预测加入到 InferenceData

1
az.concat(idata, az.from_pymc3(posterior_predictive=post_pred), inplace=True)

下面画 posterior/prior predictive checks 的图

1
2
3
4
fig, ax = plt.subplots()
az.plot_ppc(idata, ax=ax)
ax.axvline(data.mean(), ls="--", color="r", label="True mean")
ax.legend(fontsize=10);

$3-2 在未知数据上预测

很多时候我们想在没见过的数据上预测,这比较常见于概率机器学习,贝叶斯深度学习。

pm.Data 容器是对 theano.shared 的封装,可以传给 PyMC3。

由于 PyMC3 中的模型都是符号表达式。theano.shared 提供一种在符号表达式中定位数据的方法,使得数据可以修改。

1
2
3
4
5
6
7
8
9
10
11
12
x = np.random.randn(100)
y = x > 0

with pm.Model() as model:
# create shared variables that can be changed later on
x_shared = pm.Data("x_obs", x)
y_shared = pm.Data("y_obs", y)

coeff = pm.Normal("x", mu=0, sigma=1)
logistic = pm.math.sigmoid(coeff * x_shared)
pm.Bernoulli("obs", p=logistic, observed=y_shared)
idata = pm.sample(return_inferencedata=True)

现在我们要在未知数据上预测,需要修改 x_shared 和 y_shared。

1
2
3
4
5
6
7
8
9
10
11
12
13
with model:
# change the value and shape of the data
pm.set_data(
{
"x_obs": [-1, 0, 1.0],
# use dummy values with the same shape:
"y_obs": [0, 0, 0],
}
)

post_pred = pm.sample_posterior_predictive(idata.posterior)

print(post_pred["obs"].mean(axis=0))

打印结果

1
[0.02175 0.49925 0.97675]

Share