热力图与原始图像融合
时间:2022-07-22
本文章向大家介绍热力图与原始图像融合,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
使用神经网络进行预测时,一个明显的缺陷就是缺少可解释性,我们不能通过一些简单的方法来知道网络做出决策或者预测的理由,这在很多方面就使得它的应用受限。 虽然不能通过一些数学方法来证明模型的有效性,但我们仍能够通过一些可视化热力图的方法来观测一下原始数据中的哪些部分对我们网络影响较大。 实现热力图绘制的方法有很多,如:CAM, Grad-CAM, Contrastive EBP等。在热力图生成之后,因为没有原始数据信息,所以我们并不能很直观地观测到模型到底重点关注了图像的哪些区域。这时将热力图叠加到原始图像上的想法就会很自然的产生。这里存在的一个问题是原始图像的色域空间可能和产生的热力图的色域空间是不一致的,当二者叠加的时候,会产生颜色的遮挡。并且因为产生的热力图的尺寸应该与原始图像尺寸一致或者调整到与原始尺寸一致,这样当二者直接简单地叠加的话,产生的图像可能并不是我们想要的,因此,我们需要先对热力图数据进行一些简单的像素处理,然后在考虑与原始图像的融合。以下部分的安排为:1. 热力图的产生 2. 热力图与原始图的叠加 3. 热力图与原始图融合优化
1. 热力图产生
在这里使用3D-Grad-CAM的方法来实现热力图绘制的方法,使用的图像尺寸为144, 168, 152 代码如下:
def cam(img_path, model_path, relu=True, sigmoid=False):
# grad-cam
img_data = np.load(img_path)
img_data = img_data[np.newaxis, :, :, :, np.newaxis]
max_ = np.max(img_data)
min_ = np.min(img_data)
img_data = (img_data - min_) / (max_ - min_)
model = load_model(model_path)
model.summary()
index = 0
pred = model.predict(img_data)
if sigmoid:
if pred >= 0.5:
index = 1
else:
max_ = np.max(pred)
for i in range(4):
if pred[0][i] == max_:
index = i
break
print(pred)
print("index: ", index)
pre_output = model.output[:, index]
last_conv_layer = model.get_layer('conv3d_7')
grads = K.gradients(pre_output, last_conv_layer.output)[0]
pooled_grads = K.mean(grads, axis=(0, 1, 2, 3))
iterate = K.function([model.input], [pooled_grads, last_conv_layer.output[0]])
pooled_grads_value, conv_layer_output_value = iterate([img_data])
if relu:
conv_layer_output_value[np.where(conv_layer_output_value < 0)] = 0
conv_max = np.max(conv_layer_output_value)
conv_min = np.min(conv_layer_output_value)
conv_layer_output_value = (conv_layer_output_value - conv_min) / (conv_max - conv_min)
pool_max = np.max(pooled_grads_value)
pool_min = np.min(pooled_grads_value)
pooled_grads_value = (pooled_grads_value - pool_min) / (pool_max - pool_min)
layer_number = len(pooled_grads_value)
for i in range(layer_number):
conv_layer_output_value[:, :, :, i] *= pooled_grads_value[i]
# along the last dim calculate the mean value
heatmap = np.mean(conv_layer_output_value, axis=-1)
# remove the value which less than 0
heatmap = np.maximum(heatmap, 0)
# uniformization
min_ = np.min(heatmap)
max_ = np.max(heatmap)
heatmap = (heatmap - min_) / (max_ - min_)
return heatmap
2. 热力图与原始图的叠加
通过以下代码获取热力图,并将其尺寸放缩到与原图一致:
heatmap = cam(img_path, model_path)
heatmap = resize(heatmap, (144, 168, 152))
加载数据:
img_data = np.load(img_path)
热力图与原图简单叠加:
def easy_show(data, heatmap):
plt.figure()
plt.subplot(221)
plt.axis('off')
plt.imshow(data, cmap='bone')
plt.subplot(222)
plt.axis('off')
plt.imshow(heatmap, cmap='rainbow')
plt.subplot(223)
plt.axis('off')
plt.imshow(data, cmap='bone')
plt.imshow(heatmap, cmap='rainbow', alpha=0.7)
plt.subplot(224)
plt.axis('off')
plt.imshow(data, cmap='bone')
plt.imshow(heatmap, cmap='rainbow', alpha=0.3)
plt.savefig(r'E:study研究生笔记studyNoteothersimgstmp.png')
# 使用
heatmap = np.load("CNcam.npy")
img_data = np.load(img_path)
easy_show(img_data[:, 84, :], heatmap[:, 84, :])
图像融合结果:
3. 热力图与原始图融合优化
上面图像融合之后存在的问题是,前景热力图完全遮挡了原图,使得最终的展示图中,原图结构存在模糊。首先对热力图进行优化,使背景颜色变为白色且去掉一些权重过小热力。然后将热力图剩余的部分叠加到原图上。
def img_fusion(img1, img2, save_path):
dpi = 100
save_fig(img1, dpi, "cam.png")
img = Image.open("cam.png")
img = np.array(img)
for i in range(len(img)):
for j in range(len(img[0])):
if img[i][j][0] == 127 and img[i][j][1] == 0 and img[i][j][2] == 255
and img[i][j][3] == 255:
img[i][j][:] = 255
save_fig(img2, dpi, "data.png", "bone")
cam_img = cv2.imread("cam.png")
data_img = cv2.imread("data.png")
cam_gray = cv2.cvtColor(cam_img, cv2.COLOR_BGR2GRAY)
rest, mask = cv2.threshold(cam_gray, 80, 255, cv2.THRESH_BINARY)
cam_fg = cv2.bitwise_and(cam_img, cam_img, mask=mask)
dst = cv2.addWeighted(cam_fg, 0.4, data_img, 1, 0)
add_cubic = cv2.resize(dst, (dst.shape[1] * 4, dst.shape[0] * 4), cv2.INTER_CUBIC)
cv2.imwrite(save_path, add_cubic)
使用上面的函数(上面的图像不正,首先向左旋转90°,之后再进行融合):
heatmap = np.load("CNcam.npy")
img_data = np.load(img_path)
heatmap = np.where(heatmap < 0.3, 0, heatmap) * 255
img_data = np.rot90(img_data[:, 84, :], 1) # 向左旋转90度
heatmap = np.rot90(heatmap[:, 84, :], 1)
img_fusion(heatmap, img_data, r'tmp.png')
绘制结果:
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 腾讯云-轻量应用服务器SaaS交付Discuz! Q
- LeetCode 刷题技巧与学习方法
- IntelliJ IDEA 2020.2正式发布,诸多亮点总有几款能助你提效
- SQL之单表查询
- Ubuntu19.10 编译运行C语言程序
- Linux 中杀死指定端口的进程
- Python:将给定字符串中的大写英文字母按以下对应规则替换
- 数据库原理02——关系数据库
- 计算机网络02——物理层
- 面试官你好,我已经掌握了MySQL主从配置和读写分离,你看我还有机会吗?
- 虚拟机中安装双系统
- 一张图记住 Vim 常用命令
- 数据库原理01——概述
- Spring 注解开发之 @Bean 及其相关注解
- JDK配置详细教程