TN-SCUI2020挑战赛——基于注意力机制改进的Vnet2d分割网络
今天将分享甲状腺超声结节二值分割的改进模型,为了方便大家学习理解整个流程,将整个流程步骤进行了整理,并给出详细的步骤结果。感兴趣的朋友赶紧动手试一试吧。
一、超声图像分析与预处理
(1)、3644张超声结节原始数据和标注数据及训练标签文件可以在官网上下载获取到,下载下来后如下所示。测试数据一共有910张数据,目前也可以在官网上下载了。
(2)、由于原始数据大小不一样,这里对图像做统一大小设置,都设置成512x512的大小。
(3)、使用全部的3644例数据来训练,为了增加模型鲁棒性,可以对原始数据做数据增强处理,但是在这里就不做数据增强操作了,直接在原始数据量上进行分割和分类。
(4)、原始图像和金标准Mask图像的预处理还需要做归一化操作,统一都归一化到(0,1)。
二、AGVNet2d分割网络
(1)、搭建AG模块,如下图所示。整个过程很简单,首先分别对上采样信号和编码器特征进行卷积计算,然后将两者结果进行逐元素相加,再经过relu激活函数,再经过1x1的卷积计算和sigmoid激活函数产生AG系数,最后在与编码器特征进行相乘,得到最终AG结果。
代码实现如下:
def AGModel(x, signal, kernalshape, phase, height=None, width=None, scope=None):
with tf.name_scope(scope):
# attention input
Wg = weight_xavier_init(shape=kernalshape, n_inputs=kernalshape[0] * kernalshape[1] * kernalshape[2],
n_outputs=kernalshape[-1], activefunction='relu', variable_name=str(scope) + 'Wg')
Bg = bias_variable([kernalshape[-1]], variable_name=str(scope) + 'Bg')
convg = conv2d(signal, Wg) + Bg
convg = normalizationlayer(convg, phase, height=height, width=width, norm_type='group',
scope=str(scope) + 'normg')
# input
Wf = weight_xavier_init(shape=kernalshape, n_inputs=kernalshape[0] * kernalshape[1] * kernalshape[2],
n_outputs=kernalshape[-1], activefunction='relu', variable_name=str(scope) + 'Wf')
Bf = bias_variable([kernalshape[-1]], variable_name=str(scope) + 'Bf')
convf = conv2d(x, Wf) + Bf
convf = normalizationlayer(convf, phase, height=height, width=width, norm_type='group',
scope=str(scope) + 'normf')
# add input and attention input
convadd = resnet_Add(x1=convg, x2=convf)
convadd = tf.nn.relu(convadd)
# generate attention gat coe
attencoekernalshape = (1, 1, kernalshape[-1], 1)
Wpsi = weight_xavier_init(shape=attencoekernalshape,
n_inputs=attencoekernalshape[0] * attencoekernalshape[1] * attencoekernalshape[2],
n_outputs=attencoekernalshape[-1], activefunction='sigomd',
variable_name=str(scope) + 'Wpsi')
Bpsi = bias_variable([attencoekernalshape[-1]], variable_name=str(scope) + 'Bpsi')
convpsi = conv2d(convadd, Wpsi) + Bpsi
convpsi = normalizationlayer(convpsi, phase, height=height, width=width, norm_type='group',
scope=str(scope) + 'normpsi')
convpsi = tf.nn.sigmoid(convpsi)
# generate attention gat coe
attengatx = tf.multiply(x, convpsi)
return attengatx
(2)、搭建AGVNet2d模型,网络输入大小是(512,512),主要是对原始的跳跃连接进行改进,原先是将编码器的输出与上采样的输出直接进行拼接,作为解码器的输入,改进的地方首先是将编码器的输出与上采样的输出输入到AG模块中,然后将产生的AG结果与上采样的输出进行拼接,作为解码器的输入。
(2)、loss采用的是二分类的dice函数。
代码实现如下:
def __get_cost(self, cost_name, Y_gt, Y_pred):
H, W, C = Y_gt.get_shape().as_list()[1:]
if cost_name == <span data-raw-text="" "="" data-textnode-index="27" data-index="547" class="character">"dice coefficient<span data-raw-text="" "="" data-textnode-index="27" data-index="564" class="character">":
smooth = 1e-5
pred_flat = tf.reshape(Y_pred, [-1, H * W * C])
true_flat = tf.reshape(Y_gt, [-1, H * W * C])
intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=1) + smooth
denominator = tf.reduce_sum(pred_flat, axis=1) + tf.reduce_sum(true_flat, axis=1) + smooth
loss = -tf.reduce_mean(intersection / denominator)
return loss
具体实现可以参考Tensorflow入门教程(三十四)——常用两类图像分割损失函数。
(3)、分割损失结果和精度经过如下图所示。
为了方便大家更高效地学习,我将代码进行了整理并更新到github上,点击原文链接即可访问。此外训练好的模型文件已经上传至百度云盘:链接:https://pan.baidu.com/s/15E-RvaqSLdDMBCUZoGTrDA
提取码:hszn。
如果大家觉得这个项目还不错,希望大家给个Star并Fork,可以让更多的人学习。
- spring boot 登录注册 demo (四) -- 体验小结
- jenkins 时区设置
- 什么样的密码才是安全的?
- MAC本遭遇ARP攻击的处理办法
- nodejs 语法学习(持续更新)
- Django——模板层(template)(模板语法、自定义模板过滤器及标签、模板继承)
- - Templates should only be responsible for mapping the state to the UI. Avoid placing tags with side
- Django - - - -视图层之视图函数(views)
- fiddler mock ==> AutoResponder
- 基于Node.js开发跨平台窗口程序
- Django视图层之路由配置系统(urls)
- java String时间转为时间戳
- linux 简易启动脚本
- 2017年我国大数据产业发展五大新突破
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法