pytorch+Unet图像分割:将图片中的盐体找出来
什么是图像分割问题呢?简单的来讲就是给一张图像,检测是用框出框出物体,而图像分割分出一个物体的准确轮廓。也这样考虑,给出一张图像 I,这个问题就是求一个函数,从I映射到Mask。至于怎么求这个函数有多种方法。我们可以看到这个图,左边是给出图像,可以看到人和摩托车,右边是分割结果。
求这个函数有很多方法,但是第一次将深度学习结合起来的是这篇文章全卷积网络(FCN),利用深度学习求这个函数。在此之前深度学习一般用在分类和检测问题上。由于用到CNN,所以最后提取的特征的尺度是变小的。和我们要求的函数不一样,我们要求的函数是输入多大,输出有多大。为了让CNN提取出来的尺度能到原图大小,FCN网络利用上采样和反卷积到原图像大小。然后做像素级的分类。
可以看图二,输入原图,经过VGG16网络,得到特征map,然后将特征map上采样回去。再将预测结果和ground truth每个像素一一对应分类,做像素级别分类。也就是说将分割问题变成分类问题,而分类问题正好是深度学习的强项。如果只将特征map直接上采样或者反卷积,明显会丢失很多信息。
FCN采取解决方法是将pool4、pool3、和特征map融合起来,由于pool3、pool4、特征map大小尺寸是不一样的,所以融合应该前上采样到同一尺寸。这里的融合是拼接在一起,不是对应元素相加。
FCN是深度学习在图像分割的开山之作,FCN优点是实现端到端分割等,缺点是分割结果细节不够好,可以看到图四,FCN8s是上面讲的pool4、pool3和特征map融合,FCN16s是pool4和特征map融合,FCN32s是只有特征map,得出结果都是细节不够好,具体可以看自行车。由于网络中只有卷积没有全连接,所以这个网络又叫全卷积网络。
本文将先简单介绍Unet的理论基础,然后使用pytorch一步一步地实现Unet图像分割。因为主要目的是提供一个baseline模型给大家,所以代码主要关注在如何构造Unet的网络结构。
Unet
图1: Unet的网络结构
Unet主要用于图像分割问题。图1是Unet论文中的网络结构图。
https://arxiv.org/abs/1505.04597
很多分割网络都是基于FCNs做改进,包括Unet。Unet包括两部分,可以看右图,第一部分,特征提取,VGG类似。第二部分上采样部分。由于网络结构像U型,所以叫Unet网络。
- 特征提取部分,每经过一个池化层就一个尺度,包括原图尺度一共有5个尺度。
- 上采样部分,每上采样一次,就和特征提取部分对应的通道数相同尺度融合,但是融合之前要将其crop。这里的融合也是拼接。
个人认为改进FCN之处有:
- 多尺度
- 适合超大图像分割,适合医学图像分割
可以看出Unet是一个对称的结构,左半边是Encoder,右半边是Decoder。图像会先经过Encoder处理,再经过Decoder处理,最终实现图像分割。它们分别的作用如下:
- Encoder:使得模型理解了图像的内容,但是丢弃了图像的位置信息。
- Decoder:使模型结合Encoder对图像内容的理解,恢复图像的位置信息。
Encoder的部分和传统的网络结构类似,可以选择图中的结构,也可以选择VGG,ResNet等。随着卷积层的加深,特征图的长宽减小,通道增加。虽然Encoder提取了图像的高级特征,但是丢弃了图像的位置信息。所以在图像识别问题中,模型只需要Encoder的部分。因为图像识别不需要位置信息,只需要提取图像的内容信息。
Decoder的部分是Unet的重点。Decoder中涉及upconvolution这个概念。关于upconvolution,这里不做详细介绍,简单来说就是convolution的反向运算。Decoder的每一层都通过upconvolution(图中绿色箭头),并且和Encoder相对应的初级特征结合(图中的灰色箭头),逐渐恢复图像的位置信息。在Decoder中,随着卷积层的加深,特征图的长宽增大,通道减少。
Unet——输入输出
医学图像是一般相当大,但是分割时候不可能将原图太小输入网络,所以必须切成一张一张的小patch,在切成小patch的时候,Unet由于网络结构原因适合有overlap的切图,可以看图,红框是要分割区域,但是在切图时要包含周围区域,overlap另一个重要原因是周围overlap部分可以为分割区域边缘部分提供文理等信息。可以看黄框的边缘,分割结果并没有受到切成小patch而造成分割情况不好。
本文用到的数据来源于Kaggle盐体分割比赛。这次比赛的问题是一个非常典型的图像分割问题。比赛中的大佬们基本上都用的Unet。
我们的目标就是将图片中的盐体找出来。盐体有一些我不太懂的经济价值,反正是很有意义的。
以下是一些图片样例:
PyTorch实现
代码 以及运行教程 获取:
关注微信公众号 datayx 然后回复 分割 即可获取。
AI项目体验地址 https://loveai.tech
Unet
本文定义的Unet网络结构和论文中的略有不同,但本质都采用的是Encoder和Decoder的结构。主要的不同点是:
- Encoder的backbone基于ResNet18
- 输入和输出图像大小一致
以下是Unet网络结构的pytorch代码,代码后附了详细的解释。
- 这里定义了两个class:
Decoder
和Unet
。Unet
是整个模型的结构,Decoder
则是模型Decoder中的单层。 - 使用pytorch构造模型时,需要基于
nn.Module
定义类。forward
函数定义前向传播的逻辑。 -
Decoder
中的up
运算定义为nn.ConvTranspose2d
,也就是upconvolution;conv_relu
则定义为nn.Conv2d
和nn.ReLU
的组合。pytorch中需要用到nn.Sequential
将多个运算组合在一起。 -
Decoder
中forward
函数定义了其前向传播的逻辑:1. 对特征图x1做upconvolution。2. 将x1和x2(encoder中对应的特征图)组合(concatenate)。3. 对组合后的特征图做卷积和relu。 - 因为
Unet
基于resnet18,所以定义运算时从torchvision.models.resnet18
取出来就可以。因为resnet18默认的是适用于RGB图片,而比赛中的图片是灰的,只有一个通道,所以layer1
中的卷基层需要自己定义。 -
layer1
到layer5
属于encoder,encode4
到encode0
属于decoder,呈对称结构。 - 下表是经过各层的处理后,特征图的长/宽和通道数:
Dataset
如果你了解keras,那么就会发现pytorch中的Dataset
和keras中的DataGenerator
类似。不同的是pytorch定义的Dataset
只返回1个样本,再通过DataLoader
定义batch_size
。
Dataset的逻辑很简单,分为以下几步:
- 读取图片
- 预处理(resize, pad, 数据增强等)
- 返回图片和Mask
Pytorch代码如下:
Optimizer
optimizer采用的是SGD,同时用到了余弦退火学习率和快照集成来提升模型效果。
结论
在没有数据增强和TTA等其他手段的情况下,本文的代码能够拿到0.76的成绩,是一个不错的baseline模型。
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 二叉树:看看这些树的最大深度
- C++核心准则SF.5: .cpp文件必须包含定义它接口的.h文件
- C++核心准则SF.6:(只)为转换,基础库或在局部作用域内部使用using namspace指令
- C++核心准则SF.7:不要在头文件中的全局作用域中使用using namespace指令
- 二叉树:看看这些树的最小深度
- Hive初体验
- Hive数据的存储以及在centos7下进行Mysql的安装
- 一个改进的数学学习工具
- 配置hive的元数据到Mysql中
- 二叉树:我有多少个节点?
- POST请求和GET请求如何传递和接收解析参数
- 二叉树:我平衡么?
- 机器学习中的常用编码方式(二)
- 个人Next主题配置文件
- 数组中出现次数超过一半的数字