faster-rcnn中ROI_POOIING层的解读
时间:2022-05-06
本文章向大家介绍faster-rcnn中ROI_POOIING层的解读,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
在没有出现sppnet之前,RCNN使用corp和warp来对图片进行大小调整,这种操作会造成图片信息失真和信息丢失。sppnet这个模型推出来之后(关于这个网络的描述,可以看看之前写的一篇理解:http://www.cnblogs.com/gongxijun/p/7172134.html),rg大神沿用了sppnet的思路到他的下一个模型中fast-rcnn中,但是roi_pooling和sppnet的思路虽然相同,但是实现方式还是不同的.我们看一下网络参数:
layer {
name: "roi_pool5"
type: "ROIPooling"
bottom: "conv5_3"
bottom: "rois"
top: "pool5"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625 # 1/16
}
结合源代码,作者借助了sppnet的空域金字塔pool方式,但是和sppnet并不同的是,作者在这里只使用了(pooled_w,pooled_h)这个尺度,来将得到的每一个特征图分成(pooled_w,pooled_h),然后对每一块进行max_pooling取值,最后得到一个n*7*7固定大小的特征图。
1 // ------------------------------------------------------------------
2 // Fast R-CNN
3 // Copyright (c) 2015 Microsoft
4 // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
5 // Written by Ross Girshick
6 // ------------------------------------------------------------------
7
8 #include <cfloat>
9
10 #include "caffe/fast_rcnn_layers.hpp"
11
12 using std::max;
13 using std::min;
14 using std::floor;
15 using std::ceil;
16
17 namespace caffe {
18
19 template <typename Dtype>
20 void ROIPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
21 const vector<Blob<Dtype>*>& top) {
22 ROIPoolingParameter roi_pool_param = this->layer_param_.roi_pooling_param();
23 CHECK_GT(roi_pool_param.pooled_h(), 0)
24 << "pooled_h must be > 0";
25 CHECK_GT(roi_pool_param.pooled_w(), 0)
26 << "pooled_w must be > 0";
27 pooled_height_ = roi_pool_param.pooled_h(); //定义网络的大小
28 pooled_width_ = roi_pool_param.pooled_w();
29 spatial_scale_ = roi_pool_param.spatial_scale();
30 LOG(INFO) << "Spatial scale: " << spatial_scale_;
31 }
32
33 template <typename Dtype>
34 void ROIPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
35 const vector<Blob<Dtype>*>& top) {
36 channels_ = bottom[0]->channels();
37 height_ = bottom[0]->height();
38 width_ = bottom[0]->width();
39 top[0]->Reshape(bottom[1]->num(), channels_, pooled_height_,
40 pooled_width_);
41 max_idx_.Reshape(bottom[1]->num(), channels_, pooled_height_,
42 pooled_width_);
43 }
44
45 template <typename Dtype>
46 void ROIPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
47 const vector<Blob<Dtype>*>& top) {
48 const Dtype* bottom_data = bottom[0]->cpu_data();
49 const Dtype* bottom_rois = bottom[1]->cpu_data();//获取roidb信息(n,x1,y1,x2,y2)
50 // Number of ROIs
51 int num_rois = bottom[1]->num();//候选目标的个数
52 int batch_size = bottom[0]->num();//特征图的维度,vgg16的conv5之后为512
53 int top_count = top[0]->count();//需要输出的值个数
54 Dtype* top_data = top[0]->mutable_cpu_data();
55 caffe_set(top_count, Dtype(-FLT_MAX), top_data);
56 int* argmax_data = max_idx_.mutable_cpu_data();
57 caffe_set(top_count, -1, argmax_data);
58
59 // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
60 for (int n = 0; n < num_rois; ++n) {
61 int roi_batch_ind = bottom_rois[0];
62 int roi_start_w = round(bottom_rois[1] * spatial_scale_);//缩小16倍,将候选区域在原始坐标中的位置,映射到conv_5特征图上
63 int roi_start_h = round(bottom_rois[2] * spatial_scale_);
64 int roi_end_w = round(bottom_rois[3] * spatial_scale_);
65 int roi_end_h = round(bottom_rois[4] * spatial_scale_);
66 CHECK_GE(roi_batch_ind, 0);
67 CHECK_LT(roi_batch_ind, batch_size);
68
69 int roi_height = max(roi_end_h - roi_start_h + 1, 1);//得到候选区域在特征图上的大小
70 int roi_width = max(roi_end_w - roi_start_w + 1, 1);
71 const Dtype bin_size_h = static_cast<Dtype>(roi_height)
72 / static_cast<Dtype>(pooled_height_);//计算如果需要划分成(pooled_height_,pooled_weight_)这么多块,那么每一个块的大小(bin_size_w,bin_size_h);
73 const Dtype bin_size_w = static_cast<Dtype>(roi_width)
74 / static_cast<Dtype>(pooled_width_);
75
76 const Dtype* batch_data = bottom_data + bottom[0]->offset(roi_batch_ind);//获取当前维度的特征图数据,比如一共有(n,x1,x2,x3,x4)的数据,拿到第一块特征图的数据
77
78 for (int c = 0; c < channels_; ++c) {
79 for (int ph = 0; ph < pooled_height_; ++ph) {
80 for (int pw = 0; pw < pooled_width_; ++pw) {
81 // Compute pooling region for this output unit:
82 // start (included) = floor(ph * roi_height / pooled_height_)
83 // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
84 int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)
85 * bin_size_h)); //计算每一块的位置
86 int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)
87 * bin_size_w));
88 int hend = static_cast<int>(ceil(static_cast<Dtype>(ph + 1)
89 * bin_size_h));
90 int wend = static_cast<int>(ceil(static_cast<Dtype>(pw + 1)
91 * bin_size_w));
92
93 hstart = min(max(hstart + roi_start_h, 0), height_);
94 hend = min(max(hend + roi_start_h, 0), height_);
95 wstart = min(max(wstart + roi_start_w, 0), width_);
96 wend = min(max(wend + roi_start_w, 0), width_);
97
98 bool is_empty = (hend <= hstart) || (wend <= wstart);
99
100 const int pool_index = ph * pooled_width_ + pw;
101 if (is_empty) {
102 top_data[pool_index] = 0;
103 argmax_data[pool_index] = -1;
104 }
105
106 for (int h = hstart; h < hend; ++h) {
107 for (int w = wstart; w < wend; ++w) {
108 const int index = h * width_ + w;
109 if (batch_data[index] > top_data[pool_index]) {
110 top_data[pool_index] = batch_data[index]; //在取每一块中的最大值,就是max_pooling操作.
111 argmax_data[pool_index] = index;
112 }
113 }
114 }
115 }
116 }
117 // Increment all data pointers by one channel
118 batch_data += bottom[0]->offset(0, 1);
119 top_data += top[0]->offset(0, 1);
120 argmax_data += max_idx_.offset(0, 1);
121 }
122 // Increment ROI data pointer
123 bottom_rois += bottom[1]->offset(1);
124 }
125 }
126
127 template <typename Dtype>
128 void ROIPoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
129 const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
130 NOT_IMPLEMENTED;
131 }
132
133
134 #ifdef CPU_ONLY
135 STUB_GPU(ROIPoolingLayer);
136 #endif
137
138 INSTANTIATE_CLASS(ROIPoolingLayer);
139 REGISTER_LAYER_CLASS(ROIPooling);
140
141 } // namespace caffe
进过以上的操作过后,就得到了固定大小的特征图啦,然后就可以进行全连接操作了. 但愿我说明白了.
---完.
- 研究人员发现一种利用Siri窃取苹果iPhone/iPad数据的方法
- 关于Python中的__main__和编程模板
- 世界大战尽在掌控:盘点全球网络攻击实时追踪系统
- Activity数据传递
- apache反向代理一、泛解析域名二、APACHE配置
- Python学习 - 可视化变量赋值、循环、程序运行过程
- jdk源码分析红黑树——插入篇1.插入root2.父黑3.父红4.父红,叔红5.1父红,叔黑,外侧子孙5.2父红,叔黑,内侧子孙
- WIFI环境下Android手机和电脑通信
- 破解之美:利用ECB加密缺陷突破cookie加密
- 让Python猜猜你是否能约会成功
- python学习笔记之初识Python
- 从APK解密到批量获取他人信息
- React native和原生之间的通信
- 移动云存储服务平台Parse下的iOS安全分析
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Swift日常开发随笔
- vue入门003~vue项目引入element并创建一个登录页面
- vue入门002~vue项目的两种创建方式
- IntelliJ IDEA,WebStorm,PhpStorm破解到2089年
- 小程序订阅消息推送(含源码)java实现小程序推送,springboot实现微信消息推送
- 借助云开发10行代码实现短信验证码的发送
- 借助云开发实现小程序订阅消息(模板消息)推送功能
- 1小时实战入门小程序开发,历史上的今天案例讲解
- 小程序实现全屏幕高斯模糊背景图
- 小程序顶部导航栏,可滑动,可动态选中放大
- 小程序不同页面的异步回调,callback和promise的使用讲解
- java入门019~springboot批量导入excel数据到mysql
- Java点餐系统和点餐小程序新加微信消息推送功能
- Java点餐系统和点餐小程序新加排号等位功能
- IDEA上给文件添加姓名,日期,版本号