卷积神经网络之 - GoogLeNet / Inception-v1
本文建议阅读时间 10 min
欢迎加入计算机视觉交流群
大纲
简介
论文地址:https://arxiv.org/abs/1409.4842
Inception 是一个代号,是 Google 提出的一种深度卷积网络架构(PS:有一部电影的英文名就是它,中文名叫做盗梦空间)。
Inception 的第一个版本也叫 GoogLeNet,在 2014 年 ILSVRC(ImageNet 大规模视觉识别竞赛)的图像分类竞赛提出的,它对比 ZFNet(2013 年的获奖者)和 AlexNet (2012 年获胜者)有了显着改进,并且与 VGGNet(2014 年亚军)相比错误率相对较低。
GoogLeNet 一词包含 LeNet ,这是在向第一代卷积神经网络为代表的 LeNet 致敬。
1×1 卷积
1×1 卷积最初是在 NiN 网络中首次提出的,使用它的目的是为了增加网络非线性表达的能力。
GoogLeNet 也使用了 1×1 卷积,但是使用它的目的是为了设计瓶颈结构,通过降维来减少计算量,不仅可以增加网络的深度和宽度,也有一定的防止过拟合的作用。
以下对比使用 1×1 卷积前后计算量的对比,看看 1×1 卷积是如何有效减少参数的。
未使用 1×1 卷积
使用 1×1 卷积
从上表可以看出,使用 1×1 卷积,参数量只有未使用 1×1 卷积的 5.3/112.9=4.7%,大幅度减少了参数量。
Inception 模块
Inception 模块
上图中,图 (a) 是不带 1×1 卷积的版本,不具备降维的作用,图 (b) 是带 1×1 卷积的版本,具有降维的作用,可以降低参数量。
Inception 模块有四条通路,包括三条卷积通路和一条池化通路,具有不同的卷积核大小,不同卷积核大小可以提取不同尺度的特征。最后将不同通路提取的特征 concate 起来,不同通路得到的特征图大小是一致的。
总体架构
GoogLenet 网络的结构如下,总共有 22 层,主干网络都是全部使用卷积神经网络,仅仅在最终的分类上使用全连接层。
GoogLeNet
可以在 GoogLeNet 看到多个 softmax 分支,网络越深,越影响梯度的回传,作者希望通过不同深度的分支增加梯度的回传,用于解决梯度消失问题,并提供一定的正则化,所以在训练阶段使用多个分支结构来进行训练,它们产生的损失加到总损失中,并设置占比权重为 0.3,但是这些分支结构在推理阶段不使用。它们的详细参数可以看下图的注释。
GoogLeNet 分支结构
各层网络的具体参数见下表:
GoogLeNet 网络中各层参数的详细信息
代码实现
Inception 模块
class Inception(nn.Module):
def __init__(self,in_ch,out_ch1,mid_ch13,out_ch13,mid_ch15,out_ch15,out_ch_pool_conv,auxiliary=False):
# auxiliary 用来标记是否要有一条 softmax 分支
super(Inception,self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_ch,out_ch1,kernel_size=1,stride=1),
nn.ReLU())
self.conv13 = nn.Sequential(
nn.Conv2d(in_ch,mid_ch13,kernel_size=1,stride=1),
nn.ReLU(),
nn.Conv2d(mid_ch13,out_ch13,kernel_size=3,stride=1,padding=1),
nn.ReLU())
self.conv15 = nn.Sequential(
nn.Conv2d(in_ch,mid_ch15,kernel_size=1,stride=1),
nn.ReLU(),
nn.Conv2d(mid_ch15,out_ch15,kernel_size=5,stride=1,padding=2),
nn.ReLU())
self.pool_conv1 = nn.Sequential(
nn.MaxPool2d(3,stride=1,padding=1),
nn.Conv2d(in_ch,out_ch_pool_conv,kernel_size=1,stride=1),
nn.ReLU())
self.auxiliary = auxiliary
if auxiliary:
self.auxiliary_layer = nn.Sequential(
nn.AvgPool2d(5,3),
nn.Conv2d(in_ch,128,1),
nn.ReLU())
def forward(self,inputs,train=False):
conv1_out = self.conv1(inputs)
conv13_out = self.conv13(inputs)
conv15_out = self.conv15(inputs)
pool_conv_out = self.pool_conv1(inputs)
outputs = torch.cat([conv1_out,conv13_out,conv15_out,pool_conv_out],1) # depth-wise concat
if self.auxiliary:
if train:
outputs2 = self.auxiliary_layer(inputs)
else:
outputs2 = None
return outputs, outputs2
else:
return outputs
实验结果
GoogLeNet 使用了多种方法来进行测试,从而提升精度,如模型集成,最多同时使用 7 个模型进行融合;多尺度测试,使用 256、288、320、352 等尺度对测试集进行测试;crop 裁剪操作,最多达到 144 个不同方式比例的裁剪。
在集成 7 个模型,使用 144 种不同比例裁剪方式下,在 ILSVRC 竞赛中 Top-5 降到 6.67,GoogLeNet 的表现优于之前的其他深度学习网络,并在 ILSVRC 2014 上获奖。
GoogLeNet 分类性能
参考:
- https://medium.com/coinmonks/paper-review-of-googlenet-inception-v1-winner-of-ilsvlc-2014-image-classification-c2b3565a64e7
- https://github.com/pytorch/vision/blob/master/torchvision/models/googlenet.py
- 完整部署CentOS7.2+OpenStack+kvm 云平台环境(4)--用OZ工具制作openstack镜像
- centos下部署NTP时间服务器同步环境记录
- ASP.NET MVC扩展库
- centos7.2部署vnc服务记录
- nginx访问报错:Too many open files accept:
- iptables之NAT端口转发设置
- 使用Combres 库 ASP.NET 网站优化
- jQuery和asp.net mvc相关资源链接
- JavaScriptSerializer 序列化json 时间格式
- Nginx反向代理+负载均衡简单实现(https方式)
- 在网页中给Flash加上超级链接
- ASP.NET MVC HandleErrorAttribute 和 远程链接
- javascript实现数字转大写金额的函数
- 如何在GridView的Footer内显示总计?
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 【TBase开源版测评】Hello, TBase
- linux定位问题常用命令
- 聊聊claudb的zset command
- 腾讯云语音识别v1签名算法详解
- MySQL案例:关于JSON的一个bug
- Confluence 如何查看页面树
- 聊聊claudb的pubsub command
- Nginx证书和Tomcat证书能相互转化吗,请看这里
- 你真的理解 Webpack?请回答下列问题
- docker浅入深出3
- 绘图代码|多组学数据可视化的高端玩法
- Java单元测试——Mock技术配置
- 简单的场景分析LinearLayout 源码
- 避免栽坑之掌握Jenkins工作原理
- 如何检测JavaScript中的死循环?