基于飞桨复现CVPR 2016 MCNN的过程解析:教你更精确估算人流密度
MCNN是CVPR2016年的一篇论文中提出的一种神经网络。在论文中,作者提出了一种简单而有效的多列卷积神经网络架构(Multi-column Convolutional Neural Network, MCNN),通过使用大小不同的卷积核去适应人的头部大小的变化,将图片映射为人群密度图。
文章中作者收集并标记了一个大型的新数据集(ShanghaiTech数据集),其中包括1198幅图像,并使用了几何自适应高斯核基于数据集中的标记图像推理出人群密度图作为Ground Truth,然后使用图像和Ground Truth图对MCNN模型进行训练。下面本文将重点介绍MCNN的实现原理,并基于飞桨完成模型复现。其中ShanghaiTech数据集可以在飞桨的AI学习与实训社区AI Studio上下载到。
论文地址如下:
https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Zhang_Single-Image_Crowd_Counting_CVPR_2016_paper.pdf
ShanghaiTech数据集的下载地址如下所示:
https://aistudio.baidu.com/aistudio/datasetdetail/10675
什么是MCNN
MCNN受MDNNs[1]的启发,由三列并行的CNN组成,每列CNN卷积核大小不同。为了简化,所有列使用相同的网络结构(即conv-pooling-conv-pooling)。每次池化都会使用2*2的Max Pooling,而激活函数全部选择Relu。堆叠三列CNN的输出特征图,并使用1*1的卷积核将其映射为密度图。MCNN的整体架构图如图1所示:
图 1:用于人群密度图估计的多列卷积神经网络(MCNN)的结构
.
MCNN在训练时,存在数据样本少和梯度消失的问题,受预训练模型RBM[2]的启发,作者将三列CNN单独做预训练,将这些预训练的CNN参数初始化为对应的MCNN参数并微调。需要补充的是,MCNN使用了最简单的均方误差作为损失函数。
论文中使用几何自适应高斯核去计算数据图片的Ground Truth:
对于每个在给定图片中的人头的位置xi,其在图片中可以表示为冲激函数
,计算出其k个最近邻居的距离为
,所以平均距离为
。
为了估计像素xi周围的人群密度,作者将
与方差为
的高斯核进行卷积操作,
与
成正比,即
,经过作者多次试验,发现
时效果最好。
所以最终的Ground Truth人群密度图F(x)应该为:
在图 2中,显示了两张图片的人群密度图。值得说明的是,由于经过了两次下采样,所以预测出人群密度图的分辨率变为原来的1/4。
图 2:原始图像和通过几何自适应高斯核进行卷积获得的相应的人群密度图。
MCNN几乎可以从任何观察角度准确估计单个图像中的人群数,在2016年,取得了人群计数领域state-of-art的成绩。同时作者还指出,仅需要对模型最后几层进行微调,便可以将模型轻松迁移到目标问题,验证了模型的鲁棒性。
在论文中,还有很多细节,本篇不再赘述,具体可以查看原论文:
https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Zhang_Single-Image_Crowd_Counting_CVPR_2016_paper.pdf
基于飞桨复现MCNN
近日,笔者基于飞桨开源深度学习框架复现了MCNN,下面我们将复现的技术细节与开发者分享。
MCNN模型结构定义:
https://github.com/DrRyanHuang/MCNN_Paddlepaddle/blob/master/model.py
01
搭建MCNN网络
使用飞桨搭建MCNN网络十分简单,单独建立三层CNN,将其输出纵向堆叠并使用1*1的卷积核将通道数降为1即可。
为了帮助我们更方便地定义网络,先写一个CONVLMS模型单独定义每一层CNN,以方便后续的操作。
class CONVLMS(fluid.dygraph.Layer):
def __init__(self, channel_list, filter_list, use_bias=False):
# 通过传入参数 `channel_list`和 `filter_list` 定义每一列CNN
# `channel_list` 定义CNN中所有卷积的卷积核个数
# `filter_list` 定义CNN中所有卷积的卷积核size
super(CONVLMS, self).__init__()
# 定义卷积操作
self.conv1 = Conv2D(
num_channels=3,
num_filters=channel_list[0],
filter_size=filter_list[0],
stride=1,
padding=filter_list[0]//2,
act='relu',
bias_attr=use_bias)
# 定义批归一化操作
self.batch_norm1 = fluid.BatchNorm( channel_list[0] )
self.conv2 = Conv2D(
num_channels=channel_list[0],
num_filters=channel_list[1],
filter_size=filter_list[1],
stride=1,
padding=filter_list[1]//2,
act='relu',
bias_attr=use_bias)
self.batch_norm2 = fluid.BatchNorm( channel_list[1] )
# 定义池化操作
self.pool1 = Pool2D(
pool_size = 2,
pool_type = 'max',
pool_stride = 2,
global_pooling = False,)
self.conv3 = Conv2D(
num_channels=channel_list[1],
num_filters=channel_list[2],
filter_size=filter_list[2],
stride=1,
padding=filter_list[2]//2,
act='relu',
bias_attr=use_bias)
self.batch_norm3 = fluid.BatchNorm( channel_list[2] )
self.pool2 = Pool2D(
pool_size = 2,
pool_type = 'max',
pool_stride = 2,
global_pooling = False,)
self.conv4 = Conv2D(
num_channels=channel_list[2],
num_filters=channel_list[3],
filter_size=filter_list[3],
stride=1,
padding=filter_list[3]//2,
act='relu',
bias_attr=use_bias)
self.batch_norm4 = fluid.BatchNorm( channel_list[3] )
def forward(self, input):
conv1 = self.conv1(input)
conv1 = self.batch_norm1(conv1)
conv2 = self.conv2(conv1)
conv2 = self.batch_norm2(conv2)
pool1 = self.pool1(conv2)
conv3 = self.conv3(pool1)
conv3 = self.batch_norm3(conv3)
pool2 = self.pool2(conv3)
conv4 = self.conv4(pool2)
conv4 = self.batch_norm4(conv4)
return conv4
飞桨支持静态图和动态图两种网络定义模式,为方便调试本文选用动态图。如上代码就是定义了MCNN中的每一层CNN。
具体参数含义如下:
- channel_list : 对应MCNN的每一次卷积的通道数列表;
- filter_list : 对应MCNN每一次卷积的卷积核个数;
- use_bias : 是否使用偏置,默认为False。
接着我们定义MCNN:
class MCNN(fluid.dygraph.Layer):
def __init__(self):
super(MCNN,self).__init__()
# 该列表是指每一列CNN的每次卷积的卷积核数量
channel_list_L = [16, 32, 16, 8]
# 该列表是指每一列CNN的每次卷积的卷积核size
filter_list_L = [9, 7, 7, 7]
# 通过之前定义的类 `CONVLMS` 传入参数 `channel_list_L` 和 `filter_list_L`
# 来定义每一列 `CNN`
self.CNN_L = CONVLMS(channel_list_L, filter_list_L)
channel_list_M = [20, 40, 20, 10]
filter_list_M = [7, 5, 5, 5]
self.CNN_M = CONVLMS(channel_list_M, filter_list_M)
channel_list_S = [24, 48, 24, 12]
filter_list_S = [5, 3, 3, 3]
self.CNN_S = CONVLMS(channel_list_S, filter_list_S)
# 定义最后一层CNN, 将多通道转化为单通道
self.convall = Conv2D(
num_channels=30,
num_filters=1,
filter_size=1,
stride=1,
padding=0,
act='relu',
bias_attr=use_bias)
def forward(self, inputs):
cnn_L = self.CNN_L(inputs)
cnn_M = self.CNN_M(inputs)
cnn_S = self.CNN_S(inputs)
# 将三列CNN的结果结合在一起
convall_pre = concat([cnn_L, cnn_M, cnn_S], axis=1)
convall = self.convall(convall_pre)
return convall
MCNN将三列CNN的输出堆叠在一起,并使用1*1的卷积核将其通道数降为1.
02
训练策略
MCNN受到预训练模型的启发,先将三列CNN单独训练,之后将CNN的参数初始化为MCNN对应的参数之后,整体再进行训练,在原论文中,训练策略为批随机梯度下降法。
03
模型复现效果
我们可以对比使用飞桨的训练效果和原论文的训练效果,可以看出在AI Studio平台的算力加持下,基于飞桨的训练效果更加精确,多了更多细节。
图 3:原论文中两张测试集图片的结果对比,从左到右分别是原图,Ground Truth,原论文复现图和基于飞桨的复现结果图
[1] D. Ciresan, U. Meier, and J. Schmidhuber. Multi-column deep neural networks for image classification. In CVPR, pages 3642–3649. IEEE, 2012.
[2] G. Hinton, S. Osindero, and Y. Teh. A fast learning algorithm for deep belief nets. NEURAL COMPUT, 18(7):1527–1554, 2006.
[Thanks for reading.]
·飞桨官网地址·
https://www.paddlepaddle.org.cn/
·飞桨开源框架项目地址·
GitHub:
https://github.com/PaddlePaddle/Paddle
Gitee:
https://gitee.com/paddlepaddle/Paddle
扫描二维码 | 关注我们
微信号 : PaddleOpenSource
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Maven私服搭建
- Java线程状态详解
- 设计模式~命令模式
- 基于DelayQueue实现的带失效时间的缓存
- 基于AQS实现的简单的Semaphore
- 图解:基于B+树索引结构,MySQL可以这么优化
- Android开发笔记:Retrofit + OkHttp3 + coroutines + LiveData打造一款网络请求框架
- Nginx安装与使用
- 基于Redis实现分布式锁
- 通过简单代码示例了解七大软件设计原则
- Flink在新浪微博的在线机器学习和实时数据分析
- Nginx + Keepalived使用文档
- 22+ 高频实用的 JavaScript 片段 (2020年)
- 文件上传C:fakepath解决方案
- Asp.net web api部署在某些服务器上老是404