
图源:~kruschke/BMLR/
模型要拟合到什么程度才有用?通用结构被称为变分推理(variational inference)。无需细想,我们可以假设,我们希望找到一个可以得到最大对数似然函数 p_w(z | x)的模型,其中 w 是模型的参数(分布参数),z 是我们的隐变量(隐藏层的神经元输出,从参数 w 的分布采样得到),x 是输入数据样本。这就是我们的模型了。我们在 Pyro 中引入了一个实例来介绍这个模型,该简单实例包含所有隐变量 q_(z)的一些分布,其中 ф 被称为变分参数。这种分布必须近似于训练最好的模型参数的「实际」分布。
训练目标是使得 [log(p_w(z|x))—log(q_ф(z))] 的期望值相对于有指导的输入数据和样本最小化。在这里我们不探讨训练的细节,因为这里面的知识量太大了,此处就先当它是一个可以优化的黑箱吧。
对了,为什么需要编程呢?因为我们通常将这种概率模型(如神经网络)定义为变量相互关联的有向图,这样我们就可以直接显示变量间的依赖关系:


图源:
而且,概率编程语言起初就被用于定义此类模型并在模型上做推理。
不同于在模型中使用 dropout 或 L1 正则化,你可以把它当作你数据中的隐变量。考虑到所有的权重其实是分布,你可以从中抽样 N 次得到输出的分布,通过计算该分布的标准差,你就知道能模型有多靠谱。作为成果,我们可以只用少量的数据来训练这些模型,而且我们可以灵活地在变量之间添加不同的依赖关系。
我还没有太多关于贝叶斯建模的经验,但是我从 Pyro 和 PyMC3 中了解到,这类模型的训练过程十分漫长且很难定义正确的先验分布。而且,处理从分布中抽取的样本会导致误解和歧义。
我已经从 抓取了每日 Ethereum(以太坊)的价格数据。其中包括典型的 OHLCV(高开低走),另外还有关于 Ethereum 的每日推特量。我们将使用七日的价格、开盘及推特量数据来预测次日的价格变动情况。

价格、推特数、大盘变化
上图是一些数据样本——蓝线对应价格变化,黄线对应推特数变化,绿色对应大盘变化。它们之间存在某种正相关(0.1—0.2)。因此我们希望能利用好这些数据中的模式对模型进行训练。模式识别分类
首先,我想验证简单线性分类器在任务中的表现结果(并且我想直接使用 Pyro tutorial————的结果)。我们按照以下操作在 PyTorch 上定义我们的模型(详情参阅官方指南:)
class RegressionModel(nn.Module): def __init__(self, p): super(RegressionModel, self).__init__() self.linear = nn.Linear(p, 1) def forward(self, x): # x * w + b return self.linear(x)
以上是我们以前用过的简单确定性模型,下面是用 Pyro 定义的概率模型:
def model(data):
def model(data):
# Create unit normal priors over the parameters
mu = Variable(torch.zeros(1, p)).type_as(data)
sigma = Variable(torch.ones(1, p)).type_as(data)
bias_mu = Variable(torch.zeros(1)).type_as(data)
bias_sigma = Variable(torch.ones(1)).type_as(data)
w_prior, b_prior = Normal(mu, sigma), Normal(bias_mu, bias_sigma)
priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
lifted_module = pyro.random_module("module", regression_model, priors)
lifted_reg_model = lifted_module()
with pyro.iarange("map", N, subsample=data):
x_data = data[:, :-1]
y_data = data[:, -1]
# run the regressor forward conditioned on inputs
prediction_mean = lifted_reg_model(x_data).squeeze()
pyro.sample("obs",
Normal(prediction_mean, Variable(torch.ones(data.size(0))).type_as(data)),
obs=y_data.squeeze())
本文来自电脑杂谈,转载请注明本文网址:
http://www.pc-fly.com/a/jisuanjixue/article-70150-2.html