空间金字塔池化(Spatial Pyramid Pooling, SPP)原理和代码实现(Pytorch)
想直接看公式的可跳至第三节 3.公式修正
一、为什么需要SPP
首先需要知道为什么会需要SPP。
我们都知道卷积神经网络(CNN)由卷积层和全连接层组成,其中卷积层对于输入数据的大小并没有要求,唯一对数据大小有要求的则是第一个全连接层,因此基本上所有的CNN都要求输入数据固定大小,例如著名的VGG模型则要求输入数据大小是 (224*224) 。
固定输入数据大小有两个问题:
1.很多场景所得到数据并不是固定大小的,例如街景文字基本上其高宽比是不固定的,如下图示红色框出的文字。
2.可能你会说可以对图片进行切割,但是切割的话很可能会丢失到重要信息。
综上,SPP的提出就是为了解决CNN输入图像大小必须固定的问题,从而可以使得输入图像高宽比和大小任意。
二、SPP原理
更加具体的原理可查阅原论文:Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
上图是原文中给出的示意图,需要从下往上看:
- 首先是输入层(input image),其大小可以是任意的
- 进行卷积运算,到最后一个卷积层(图中是(conv_5))输出得到该层的特征映射(feature maps),其大小也是任意的
- 下面进入SPP层
- 我们先看最左边有16个蓝色小格子的图,它的意思是将从(conv_5)得到的特征映射分成16份,另外16X256中的256表示的是channel,即SPP对每一层都分成16份(不一定是等比分,原因看后面的内容就能理解了)。
- 中间的4个绿色小格子和右边1个紫色大格子也同理,即将特征映射分别分成4X256和1X256份
那么将特征映射分成若干等分是做什么用的呢? 我们看SPP的名字就是到了,是做池化操作,一般选择MAX Pooling,即对每一份进行最大池化。
我们看上图,通过SPP层,特征映射被转化成了16X256+4X256+1X256 = 21X256的矩阵,在送入全连接时可以扩展成一维矩阵,即1X10752,所以第一个全连接层的参数就可以设置成10752了,这样也就解决了输入数据大小任意的问题了。
注意上面划分成多少份是可以自己是情况设置的,例如我们也可以设置成3X3等,但一般建议还是按照论文中说的的进行划分。
三、SPP公式
理论应该理解了,那么如何实现呢?下面将介绍论文中给出的计算公式,但是在这之前先要介绍两种计算符号以及池化后矩阵大小的计算公式:
1.预先知识
取整符号:
⌊⌋:向下取整符号 ⌊59/60⌋=0,有时也用 floor() 表示
⌈⌉:向上取整符号 ⌈59/60⌉=1, 有时也用ceil() 表示
池化后矩阵大小计算公式:
没有步长(Stride):(h+2p−f+1)∗(w+2p−f+1)
有步长(Stride):⌊h+2p−fs+1⌋*⌊w+2p−fs+1⌋
2.公式
3.公式修正
Kh=⌈hinn⌉=ceil(hinn)
Sh=⌈hinn⌉=ceil(hinn)
ph=⌊kh∗n−hin+12⌋=floor(kh∗n−hin+12)
hnew=2∗ph+hin
Kw=⌈winn⌉=ceil(winn)
Sw=⌈winn⌉=ceil(winn)
pw=⌊kw∗n−win+12⌋=floor(kw∗n−win+12)
wnew=2∗pw+win
kh: 表示核的高度
Sh: 表示高度方向的步长
ph: 表示高度方向的填充数量,需要乘以2
注意核和步长的计算公式都使用的是ceil(),即向上取整,而padding使用的是floor(),即向下取整。
现在再来检验一下: 假设输入数据大小和上面一样是((10, 7, 11)), 池化数量为((4,4)):
Kernel大小为((2,3)),Stride大小为((2,3)),所以Padding为((1,1))。
利用矩阵大小计算公式:⌊(frac{h+2p-f}{s})+1⌋*⌊(frac{w+2p-f}{s})+1⌋得到池化后的矩阵大小为:(4*4)。
四、代码实现(Python)
这里我使用的是PyTorch深度学习框架,构建了一个SPP层,代码如下:
#coding=utf-8
import math
import torch
import torch.nn.functional as F
# 构建SPP层(空间金字塔池化层)
class SPPLayer(torch.nn.Module):
def __init__(self, num_levels, pool_type='max_pool'):
super(SPPLayer, self).__init__()
self.num_levels = num_levels
self.pool_type = pool_type
def forward(self, x):
num, c, h, w = x.size() # num:样本数量 c:通道数 h:高 w:宽
for i in range(self.num_levels):
level = i+1
kernel_size = (math.ceil(h / level), math.ceil(w / level))
stride = (math.ceil(h / level), math.ceil(w / level))
pooling = (math.floor((kernel_size[0]*level-h+1)/2), math.floor((kernel_size[1]*level-w+1)/2))
# 选择池化方式
if self.pool_type == 'max_pool':
tensor = F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1)
else:
tensor = F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1)
# 展开、拼接
if (i == 0):
x_flatten = tensor.view(num, -1)
else:
x_flatten = torch.cat((x_flatten, tensor.view(num, -1)), 1)
return x_flatten
上述代码参考: sppnet-pytorch
为防止原作者将代码删除,我已经Fork了,也可以通过如下地址访问代码: marsggbo/sppnet-pytorch
- Spring Cloud Feign 启动UnsatisfiedDependencyException
- Spring Cloud Zuul结合Smconf配置中心动态进行IP黑名单限制
- 高性能NIO框架Netty入门篇
- Spring Boot Web 静态文件缓存处理
- hbuilder 开发APP填坑经验
- hbuilder APP 定位提示苹果审核不通过
- hbuilder 开发5+ APP采坑记录
- Spring Cloud如何提供API给客户端
- 5分钟学会Spring Boot自定义属性和自动配置
- 创建一个Spring Security OAuth认证服务
- Zipkin和微服务链路跟踪
- Java8真不用再搞循环了?
- 针对事件驱动架构的Spring Cloud Stream
- Spring的三种Circuit Breaker
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- ThinkPHP5.1框架页面跳转及修改跳转页面模版示例
- Python Switch Case三种实现方法代码实例
- PHP正则表达式笔记与实例详解
- php实现微信分享朋友链接功能
- python交互模式基础知识点学习
- 浅析Python 抽象工厂模式的优缺点
- 基于PyTorch的permute和reshape/view的区别介绍
- Laravel配置全局公共函数的方法步骤
- PHP5.5基于mysqli连接MySQL数据库和读取数据操作实例详解
- python–shutil移动文件到另一个路径的操作
- PHP正则表达式处理函数(PCRE 函数)实例小结
- yii2的restful api路由实例详解
- PHP实现的权重算法示例【可用于游戏根据权限来随机物品】
- 用Python爬取LOL所有的英雄信息以及英雄皮肤的示例代码
- Python操作MySQL数据库的示例代码