【tensorflow2.0】评价指标metrics
损失函数除了作为模型训练时候的优化目标,也能够作为模型好坏的一种评价指标。但通常人们还会从其它角度评估模型的好坏。
这就是评估指标。通常损失函数都可以作为评估指标,如MAE,MSE,CategoricalCrossentropy等也是常用的评估指标。
但评估指标不一定可以作为损失函数,例如AUC,Accuracy,Precision。因为评估指标不要求连续可导,而损失函数通常要求连续可导。
编译模型时,可以通过列表形式指定多个评估指标。
如果有需要,也可以自定义评估指标。
自定义评估指标需要接收两个张量y_true,y_pred作为输入参数,并输出一个标量作为评估值。
也可以对tf.keras.metrics.Metric进行子类化,重写初始化方法, update_state方法, result方法实现评估指标的计算逻辑,从而得到评估指标的类的实现形式。
由于训练的过程通常是分批次训练的,而评估指标要跑完一个epoch才能够得到整体的指标结果。因此,类形式的评估指标更为常见。即需要编写初始化方法以创建与计算指标结果相关的一些中间变量,编写update_state方法在每个batch后更新相关中间变量的状态,编写result方法输出最终指标结果。
如果编写函数形式的评估指标,则只能取epoch中各个batch计算的评估指标结果的平均值作为整个epoch上的评估指标结果,这个结果通常会偏离拿整个epoch数据一次计算的结果。
一,常用的内置评估指标
- MeanSquaredError(平方差误差,用于回归,可以简写为MSE,函数形式为mse)
- MeanAbsoluteError (绝对值误差,用于回归,可以简写为MAE,函数形式为mae)
- MeanAbsolutePercentageError (平均百分比误差,用于回归,可以简写为MAPE,函数形式为mape)
- RootMeanSquaredError (均方根误差,用于回归)
- Accuracy (准确率,用于分类,可以用字符串"Accuracy"表示,Accuracy=(TP+TN)/(TP+TN+FP+FN),要求y_true和y_pred都为类别序号编码)
- Precision (精确率,用于二分类,Precision = TP/(TP+FP))
- Recall (召回率,用于二分类,Recall = TP/(TP+FN))
- TruePositives (真正例,用于二分类)
- TrueNegatives (真负例,用于二分类)
- FalsePositives (假正例,用于二分类)
- FalseNegatives (假负例,用于二分类)
- AUC(ROC曲线(TPR vs FPR)下的面积,用于二分类,直观解释为随机抽取一个正样本和一个负样本,正样本的预测值大于负样本的概率)
- CategoricalAccuracy(分类准确率,与Accuracy含义相同,要求y_true(label)为onehot编码形式)
- SparseCategoricalAccuracy (稀疏分类准确率,与Accuracy含义相同,要求y_true(label)为序号编码形式)
- MeanIoU (Intersection-Over-Union,常用于图像分割)
- TopKCategoricalAccuracy (多分类TopK准确率,要求y_true(label)为onehot编码形式)
- SparseTopKCategoricalAccuracy (稀疏多分类TopK准确率,要求y_true(label)为序号编码形式)
- Mean (平均值)
- Sum (求和)
二, 自定义评估指标
我们以金融风控领域常用的KS指标为例,示范自定义评估指标。
KS指标适合二分类问题,其计算方式为 KS=max(TPR-FPR).
其中TPR=TP/(TP+FN) , FPR = FP/(FP+TN)
TPR曲线实际上就是正样本的累积分布曲线(CDF),FPR曲线实际上就是负样本的累积分布曲线(CDF)。
KS指标就是正样本和负样本累积分布曲线差值的最大值。
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,metrics
# 函数形式的自定义评估指标
@tf.function
def ks(y_true,y_pred):
y_true = tf.reshape(y_true,(-1,))
y_pred = tf.reshape(y_pred,(-1,))
length = tf.shape(y_true)[0]
t = tf.math.top_k(y_pred,k = length,sorted = False)
y_pred_sorted = tf.gather(y_pred,t.indices)
y_true_sorted = tf.gather(y_true,t.indices)
cum_positive_ratio = tf.truediv(
tf.cumsum(y_true_sorted),tf.reduce_sum(y_true_sorted))
cum_negative_ratio = tf.truediv(
tf.cumsum(1 - y_true_sorted),tf.reduce_sum(1 - y_true_sorted))
ks_value = tf.reduce_max(tf.abs(cum_positive_ratio - cum_negative_ratio))
return ks_value
y_true = tf.constant([[1],[1],[1],[0],[1],[1],[1],[0],[0],[0],[1],[0],[1],[0]])
y_pred = tf.constant([[0.6],[0.1],[0.4],[0.5],[0.7],[0.7],[0.7],
[0.4],[0.4],[0.5],[0.8],[0.3],[0.5],[0.3]])
tf.print(ks(y_true,y_pred))
0.625
# 类形式的自定义评估指标
class KS(metrics.Metric):
def __init__(self, name = "ks", **kwargs):
super(KS,self).__init__(name=name,**kwargs)
self.true_positives = self.add_weight(
name = "tp",shape = (101,), initializer = "zeros")
self.false_positives = self.add_weight(
name = "fp",shape = (101,), initializer = "zeros")
@tf.function
def update_state(self,y_true,y_pred):
y_true = tf.cast(tf.reshape(y_true,(-1,)),tf.bool)
y_pred = tf.cast(100*tf.reshape(y_pred,(-1,)),tf.int32)
for i in tf.range(0,tf.shape(y_true)[0]):
if y_true[i]:
self.true_positives[y_pred[i]].assign(
self.true_positives[y_pred[i]]+1.0)
else:
self.false_positives[y_pred[i]].assign(
self.false_positives[y_pred[i]]+1.0)
return (self.true_positives,self.false_positives)
@tf.function
def result(self):
cum_positive_ratio = tf.truediv(
tf.cumsum(self.true_positives),tf.reduce_sum(self.true_positives))
cum_negative_ratio = tf.truediv(
tf.cumsum(self.false_positives),tf.reduce_sum(self.false_positives))
ks_value = tf.reduce_max(tf.abs(cum_positive_ratio - cum_negative_ratio))
return ks_value
y_true = tf.constant([[1],[1],[1],[0],[1],[1],[1],[0],[0],[0],[1],[0],[1],[0]])
y_pred = tf.constant([[0.6],[0.1],[0.4],[0.5],[0.7],[0.7],
[0.7],[0.4],[0.4],[0.5],[0.8],[0.3],[0.5],[0.3]])
myks = KS()
myks.update_state(y_true,y_pred)
tf.print(myks.result())
0.625
参考:
开源电子书地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/
GitHub 项目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Android 8.0 中如何实现视频通话的画中画模式的示例
- Android7.0开发实现Launcher3去掉应用抽屉的方法详解
- Android利用Paint自定义View实现进度条控件方法示例
- 前端科普系列(5):ESLint - 守住优雅的护城河
- 图的储存方式,链式前向星最简单实现方式 (边集数组)
- 技术前刊:PostgreSQL12 COPY和bulkloading提升
- 疯子的算法总结(八) 最短路算法+模板
- POJ - 2387 Til the Cows Come Home (最短路入门)
- POJ - 3074 Sudoku (搜索)剪枝+位运算优化
- C语言rand随机函数问题
- HDU - 1253 胜利大逃亡(搜索)
- Android7.0版本影响开发的改进分析
- POJ - 2251 Dungeon Master(搜索)
- An Overview of PostgreSQL & MySQL Cross Replication
- POJ - 1321 棋盘问题