【前沿】TensorFlow Pytorch Keras代码实现深度学习大神Hinton NIPS2017 Capsule论文
【导读】10月26日,深度学习元老Hinton的NIPS2017 Capsule论文《Dynamic Routing Between Capsules》终于在arxiv上发表。今天相关关于这篇论文的TensorFlowPytorchKeras实现相继开源出来,让我们来看下。
论文地址:https://arxiv.org/pdf/1710.09829.pdf
摘要:Capsule 是一组神经元,其活动向量(activity vector)表示特定实体类型的实例化参数,如对象或对象部分。我们使用活动向量的长度表征实体存在的概率,向量方向表示实例化参数。同一水平的活跃 capsule 通过变换矩阵对更高级别的 capsule 的实例化参数进行预测。当多个预测相同时,更高级别的 capsule 变得活跃。我们展示了判别式训练的多层 capsule 系统在 MNIST 数据集上达到了最好的性能效果,比识别高度重叠数字的卷积网络的性能优越很多。为了达到这些结果,我们使用迭代的路由协议机制:较低级别的 capsule 偏向于将输出发送至高级别的 capsule,有了来自低级别 capsule 的预测,高级别 capsule 的活动向量具备较大的标量积。
CapsNet-PyTorch
python依赖包
- Python 3
- PyTorch
- TorchVision
- TorchNet
- TQDM
- Visdom
使用说明
第一步 在capsule_network.py
文件中设置训练epochs,batch size等
BATCH_SIZE = 100NUM_CLASSES = 10NUM_EPOCHS = 30NUM_ROUTING_ITERATIONS = 3
Step 2 开始训练. 如果本地文件夹中没有MNIST数据集,将运行脚本自动下载到本地. 确保 PyTorch可视化工具Visdom正在运行。
$ sudo python3 -m visdom.server & python3 capsule_network.py
基准数据
经过30个epoche的训练手写体数字的识别率达到99.48%. 从下图的训练进度和损失图的趋势来看,这一识别率可以被进一步的提高 。
采用了PyTorch中默认的Adam梯度优化参数并没有用到动态学习率的调整。 batch size 使用100个样本的时候,在雷蛇GTX 1050 GPU上每个Epochs 用时3分钟。
待完成
- 扩展到除MNIST以外的其他数据集。
Credits
主要借鉴了以下两个 TensorFlow 和 Keras 的实现:
- Keras implementation by @XifengGuo
- TensorFlow implementation by @naturomics
Many thanks to @InnerPeace-Wu for a discussion on the dynamic routing procedure outlined in the paper.
CapsNet-Tensorflow
Python依赖包
- Python
- NumPy
- Tensorflow (I'm using 1.3.0, not yet tested for older version)
- tqdm (for displaying training progress info)
- scipy (for saving image)
使用说明
训练
*第一步 * 用git命令下载代码到本地.
$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow
第二部 下载MNIST数据集(http://yann.lecun.com/exdb/mnist/), 移动并解压到data/mnist
文件夹.(当你用复制wget
命令到你的终端是注意渠道花括号里的反斜杠)
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
$ gunzip data/mnist/*.gz
第三部 开始训练:
$ pip install tqdm # install it if you haven't installed yet
$ python train.py
tqdm包并不是必须的,只是为了可视化训练过程。如果你不想要在train.py
中将循环for in step ...
改成 ``for step in range(num_batch)就行了。
评估
$ python eval.py --is_training False
结果
错误的运行结果(Details in Issues #8):
- training loss
- test acc
Epoch |
49 |
51 |
---|---|---|
test acc |
94.69 |
94.71 |
Results after fixing Issues #8:
关于capsule的一点见解
- 一种新的神经单元(输入向量输出向量,而不是标量)
- 常规算法类似于Attention机制
- 总之是一项很有潜力的工作,有很多工作可以在之上开展
待办:
- 完成MNIST的实现Finish the MNIST version of capsNet (progress:90%)
- 在其他数据集上验证capsNet
- 调整模型结构
- 一篇新的投稿在ICLR2018上的后续论文(https://openreview.net/pdf?id=HJWLfGWRb) about capsules(submitted to ICLR 2018),
CapsNet-Keras
依赖包
- Keras
- matplotlib
使用方法
训练
第一步 安装 Keras:
$ pip install keras
第二步 用 git
命令下载代码到本地.
$ git clone https://github.com/xifengguo/CapsNet-Keras.git
$ cd CapsNet-Keras
第三步 训练:
$ python capsulenet.py
一次迭代训练(default 3).
$ python capsulenet.py --num_routing 1
其他参数包括想 batch_size, epochs, lam_recon, shift_fraction, save_dir
可以以同样的方式使用。 具体可以参考 capsulenet.py
测试
假设你已经有了用上面命令训练好的模型,训练模型将被保存在 result/trained_model.h5
. 现在只需要使用下面的命令来得到测试结果。
$ python capsulenet.py --is_training 0 --weights result/trained_model.h5
将会输出测试结果并显示出重构后的图片。测试数据使用的和验证集一样 ,同样也可以很方便的在新数据上验证,至于要按照你的需要修改下代码就行了。
如果你的电脑没有GPU来训练模型,你可以从https://pan.baidu.com/s/1hsF2bvY下载预先训练好的训练模型
结果
主要结果
运行 python capsulenet.py
: epoch=1 代表训练一个epoch 后的结果 在保存的日志文件中,epoch从0开始。
Epoch |
1 |
5 |
10 |
15 |
20 |
---|---|---|---|---|---|
train_acc |
90.65 |
98.95 |
99.36 |
99.63 |
99.75 |
vali_acc |
98.51 |
99.30 |
99.34 |
99.49 |
99.59 |
损失和准确度:
一次常规迭代后的结果
运行 python CapsNet.py --num_routing 1
Epoch |
1 |
5 |
10 |
15 |
20 |
---|---|---|---|---|---|
train_acc |
89.64 |
99.02 |
99.42 |
99.66 |
99.73 |
vali_acc |
98.55 |
99.33 |
99.43 |
99.57 |
99.58 |
每个 epoch 在单卡GTX 1070 GPU上大概需要110s
注释: 训练任然是欠拟合的,欢迎在你自己的机器上验证。学习率decay还没有经过调试, 我只是试了一次,你可以接续微调。
测试结果
运行 python capsulenet.py --is_training 0 --weights result/trained_model.h5
模型结构:
其他实现代码
- Kaggle (this version as self-contained notebook):
- MNIST Dataset running on the standard MNIST and predicting for test data
- MNIST Fashion running on the more challenging Fashion images.
- TensorFlow:
- naturomics/CapsNet-Tensorflow Very good implementation. I referred to this repository in my code.
- InnerPeace-Wu/CapsNet-tensorflow I referred to the use of tf.scan when optimizing my CapsuleLayer.
- LaoDar/tf_CapsNet_simple
- PyTorch:
- nishnik/CapsNet-PyTorch
- timomernick/pytorch-capsule
- gram-ai/capsule-networks
- andreaazzini/capsnet.pytorch
- leftthomas/CapsNet
- MXNet:
- AaronLeong/CapsNet_Mxnet
- Lasagne (Theano):
- DeniskaMazur/CapsNet-Lasagne
- Chainer:
- soskek/dynamic_routing_between_capsules
参考网址链接:
https://github.com/gram-ai/capsule-networks
https://github.com/naturomics/CapsNet-Tensorflow
https://github.com/XifengGuo/CapsNet-Keras
特别提示:
请关注专知公众号,后台回复“MLDL” 就可以获取机器学习&深度学习知识资料大全集的pdf下载链接
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法