利用python中的matplotlib打印混淆矩阵实例
前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看。。
代码:
import itertools
import matplotlib.pyplot as plt
import numpy as np
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
# plt.savefig('confusion_matrix',dpi=200)
cnf_matrix = np.array([
[4101, 2, 5, 24, 0],
[50, 3930, 6, 14, 5],
[29, 3, 3973, 4, 0],
[45, 7, 1, 3878, 119],
[31, 1, 8, 28, 3936],
])
class_names = ['Buildings', 'Farmland', 'Greenbelt', 'Wasteland', 'Water']
# plt.figure()
# plot_confusion_matrix(cnf_matrix, classes=class_names,
# title='Confusion matrix, without normalization')
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
title='Normalized confusion matrix')
在放矩阵位置,放一下你的混淆矩阵就可以,当然可视化混淆矩阵这一步也可以直接在模型运行中完成。
补充知识:混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)
原理
在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.
混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 “actual” 和 预测值 “predicted” ), 这两维都具有相同的类(“classes”)的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布.
使用混淆矩阵( scikit-learn 和 Tensorflow)
下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.
scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口
skearn.metrics.confusion_matrix(
y_true, # array, Gound true (correct) target values
y_pred, # array, Estimated targets as returned by a classifier
labels=None, # array, List of labels to index the matrix.
sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)
在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.
按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.
如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序.
Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口
tf.confusion_matrix(
labels, # 1-D Tensor of real labels for the classification task
predictions, # 1-D Tensor of predictions for a givenclassification
num_classes=None, # The possible number of labels the classification task can have
dtype=tf.int32, # Data type of the confusion matrix
name=None, # Scope name
weights=None, # An optional Tensor whose shape matches predictions
)
Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为num_classes 参数值.
tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.
使用示例
#!/usr/bin/env python
# -*- coding: utf8 -*-
"""
Author: klchang
Description:
A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
Date: 2018.9.8
"""
from __future__ import print_function
import tensorflow as tf
import sklearn.metrics
y_true = [1, 2, 4]
y_pred = [2, 2, 4]
# Build graph with tf.confusion_matrix operation
sess = tf.InteractiveSession()
op = tf.confusion_matrix(y_true, y_pred)
op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3]))
# Execute the graph
print ("confusion matrix in tensorflow: ")
print ("1. default: n", op.eval())
print ("2. customed: n", sess.run(op2))
sess.close()
# Use sklearn.metrics.confusion_matrix function
print ("nconfusion matrix in scikit-learn: ")
print ("1. default: n", sklearn.metrics.confusion_matrix(y_true, y_pred))
print ("2. customed: n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))
以上这篇利用python中的matplotlib打印混淆矩阵实例就是小编分享给大家的全部内容了,希望能给大家一个参考。
- 绘图: matplotlib Basemap简介
- GridView实战二:使用ObjectDataSource数据源控件(自定义缓存机制实现Sort)
- 把孩子打造成为码农
- 分享基于Qt5开发的一款故障波形模拟软件
- 剑指OFFER之打印1到最大的N位数(九度OJ1515)
- GridView实战二:使用ObjectDataSource数据源控件
- javascript实例:逐条记录停顿的走马灯
- Python标准库05 存储对象 (pickle包,cPickle包)
- macOS平台下虚拟摄像头的研发总结
- 网页优化系列三:使用压缩后置viewstate
- 网页优化系列三:使用压缩后置viewstate
- macOS下利用dSYM文件将crash文件中的内存地址转换为可读符号
- 微信小程序的大动作
- Python标准库04 文件管理 (部分os包,shutil包)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 面试题系列第5篇:JDK的运行时常量池、字符串常量池、静态常量池,还傻傻分不清?
- Go by Example 中文版: 互斥锁
- Idea初始化配置大全,以后重装再也不用各种百度了
- 使用这种技巧,可以大大地提高前端布局效率
- Element-UI表格组件实现行拖拽排序
- Vue自定义指令实现拖拽功能
- 小程序 Canvas 层级问题
- JDK 8 新特性之函数式编程 → Stream API
- golang 单元测试框架实践
- 想要成为前端Star 吗?一首歌时间将React/Vue 应用Docker 化
- 60亿次for循环,原来这么多东西
- 不要再问我 in,exists 走不走索引了...
- 知乎太可恶了,一言不合就封号?
- 5年Java开发经验,面试挂在MySQL InnoDB上!大厂究竟多看重MySQL?
- 是你们的力量,让知乎看见了!