贝叶斯深度学习:桥接PyMC3和Lasagne构建层次神经网络
编辑部翻译组
编译:西西、wally
作者:Thomas Wiecki
今天,我们将使用Lasagne构建一个更有趣的模型,这是一个灵活的Theano图书馆,用于构建各种类型的神经网络。你可能知道,PyMC3还使用了Theano,因此在Lasagne中建立了人工神经网络(ANN),将贝叶斯先验放在参数上,然后在PyMC3中使用变分推理(ADVI)来估计模型。
由于Lasagne的优秀表现,我们可以轻松地建立一个具有最大汇集层的分层贝叶斯卷积ANN,在MNIST上实现98%的准确性。
数据集:MNIST
我们将使用手写数字的经典MNIST数据集。 与之前的博客文章相反,MNIST是具有合理数量的维度和数据点的有实际挑战性的ML任务(当然不如像ImageNet那样有挑战性)。
Loading data...
模型说明
我想像应该可以把Lasagne和PyMC3搭在一起,因为他们都依赖于Theano。 然而,目前还不清楚它将会是多么困难。 幸运的是,第一个实验做得很好,但有一些潜在的方法可以使这更容易。 我开设了一个GitHub issue在Lasagne's的报告里,在这几天后,PR695被合并,允许他们更好的整合。
首先,Lasagne创建一个具有2个完全连接的隐藏层(每个具有800个神经元)的ANN,这几乎是从教程中直接采用的Lasagne代码。 当使用lasagne.layers.DenseLayer创建图层时,我们可以传递一个函数init,该函数必须返回一个用作权重和偏差矩阵的Theano表达式。
接下来,为ANN创建权重函数。 因为PyMC3要求每个随机变量具有不同的名称,我们创建一个类并且是唯一命名的先验。
在这里,priors充当了调节者的角色,试图保持ANN small的权重。它在数学上等价于一个L2的损失项,作为通常的做法是将大的权重惩罚到目标函数中。
下面是一些设置小批量ADVI的函数。
放在一起
让我们用小批量的ADVI来运行ANN:
确保一切聚合:
Accuracy on test data = 89.81%
分层神经网络:学习数据的正则化
上面我们只是固定了所有层的sd = 0.1,但是可能第一层应该有不同于第二层的值。也许开始时是0.1,要么太小或太大。在贝叶斯建模中,很常见的是在这种情况下放置hyperprior,并学习最佳正则化应用到数据中去。这节省了我们在超参数优化中对参数进行调优的时间。
Accuracy on test data = 92.25999999999999%
我们得到一个很小但很好的boost在准确性上。 我们来看看超参数后面的部分:
有趣的是,它们都是不同的,这表明改变正规化数量在网络的每一层是有意义的。
卷积神经网络
但到目前为止,在PyMC3中实现也很简单。有趣的是,我们现在可以构建更复杂的ANNs,像卷积神经网络:
Accuracy on test data = 98.03%
更高的精度。我也尝试了这个层次模型,但它实现了较低的精度(95%),我认为是由于过度拟合。
让我们更多地利用我们在贝叶斯框架中的产出,并在我们的预测中探索不确定性。正如我们的预测是分类的,我们不能简单地计算预测标准差。相反,我们计算的是卡方统计量,它告诉我们样本的均匀程度。越均匀,我们的不确定性越高。我不确定这是否是最好的方法。
正如我们所看到的,当模型出错时,答案会更加不确定(即提供的答案更加均匀)。你可能会说,你从一个普通的ANN那里得到了同样的效果,但事实并非如此。
这篇文章在后续会翻译
结论
通过桥接Lasagne和PyMC3,并通过使用小批量的ADVI来训练贝叶斯神经网络,在一个合适的和复杂的数据集上(MNIST),我们在实际的贝叶斯深度学习问题上迈出了一大步。
我还认为这说明了PyMC3的好处。通过使用一种常用的语言(Python)和抽象计算后端(Theano),我们能够很容易地利用该生态系统的强大功能,并以一种从未考虑过的方式使用PyMC3。我期待着将它扩展到新的领域。
- 终端图像处理系列 - OpenGL ES 2.0 - 3D基础(矩阵投影)
- XssHtml – 基于白名单的富文本XSS过滤类
- fireeyee解剖新型Android恶意软件
- WordPress系统暴力破解测试工具 – wpbf
- RecyclerView notifyItem闪烁的问题
- 独家: iOS是如何收集用户的地理信息的
- Hygieia 为何物?DevOps 利器也
- 汽车攻击离你很近:一分钟变成汽车黑客
- LIFX智能灯泡漏洞泄露WIFI密码
- android ViewPager+Fragment之懒加载
- 逆向APK进行smali注入实现“秒破WIFI”
- 我所理解的Android 启动模式
- 搭建开源入侵检测系统Snort,并实现与防火墙联动
- 如何在Python中为长短期记忆网络扩展数据
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法