PyTorch8:损失函数
1. 损失函数总览
PyTorch 的 Loss Function(损失函数)都在 torch.nn.functional
里,也提供了封装好的类在 torch.nn
里。
因为 torch.nn
可以记录导数信息,在使用时一般不使用 torch.nn.functional
。
PyTorch 里一共有 18 个损失函数,常用的有 6 个,分别是:
回归损失函数:
torch.nn.L1Loss
torch.nn.MSELoss
分类损失函数:
torch.nn.BCELoss
torch.nn.BCEWithLogitsLoss
torch.nn.CrossEntropyLoss
torch.nn.NLLLoss
损失函数是用来衡量模型输出的每个预测值与真实值的差异的:
还有额外的两个概念:
Cost Function(代价函数)是 N 个预测值的损失函数平均值:
Objective Function(目标函数)是最终需要优化的函数:
2. 回归损失函数
回归模型有两种方法进行评估:MAE(mean absolute error) 和 MSE(mean squared error)。
-
torch.nn.L1Loss(reduction='mean')
这个类对应了 MAE 损失函数;
-
torch.nn.MSELoss(reduction='mean')
这个类对应了 MSE 损失函数;
上面两个类中的 reduction
规定了获得 后的行为,有 none
、sum
和 mean
三个。none
表示不对 进行任何处理;sum
表示对 进行求和;mean
表示对 进行平均。默认为 mean
。
>>> y = torch.tensor([1.1, 1.2, 1.3])
>>> y_hat = torch.tensor([1., 1., 1.])
>>> criterion_none = nn.L1Loss(reduction='none') # 什么都不做
>>> criterion_none(y_hat, y)
tensor([0.1000, 0.2000, 0.3000])
>>> criterion_mean = nn.L1Loss(reduction='mean') # 求平均
>>> criterion_mean(y_hat, y)
tensor(0.2000)
>>> criterion_sum = nn.L1Loss(reduction='sum') # 求和
>>> criterion_sum(y_hat, y)
tensor(0.6000)
3. 分类损失函数
3.1 交叉熵
自信息是一个事件发生的概率的负对数:
信息熵用来描述一个事件的不确定性公式为:
一个确定的事件的信息熵为 0;一个事件越不确定,信息熵就越大。
交叉熵,用来衡量在给定的真实分布下,使用非真实分布指定的策略消除系统的不确定性所需要付出努力的大小,表达式为
相对熵又叫 “K-L 散度”,用来描述预测事件对真实事件的概率偏差。
而交叉熵的表达式为
可见H(P,Q) ,即交叉熵是信息熵和相对熵的和。上面的P是事件的真实分布, Q是预测出来的分布。所以优化H(P,Q)等价于优化H(Q) ,因为H(P)是已知不变的。
3.2 分类损失函数
下面我们来了解最常用的四个分类损失函数。
torch.nn.BCELoss(weight=None, reduction='mean')
这个类实现了二分类交叉熵。
使用这个类时要注意,输入值(不是分类)的范围要在 之间,否则会报错。
>>> inputs = torch.tensor([[1, 2], [2, 2], [3, 4], [4, 5]], dtype=torch.float)
>>> target = torch.tensor([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=torch.float)
>>> criterion = nn.BCELoss()
>>> criterion(inputs, target)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
...
RuntimeError: all elements of input should be between 0 and 1
通常可以先使用 F.sigmoid
处理一下数据。
torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
等价于 F.sigmoid + torch.nn.BCELoss
,就是 先使用了 sigmoid 处理了一下,这样就不需要手动使用 sigmoid 的了。
torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')
NLLLoss 的全称为 “negative log likelihood loss”,其作用是实现负对数似然函数中的负号。
torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')
这个类结合了 nn.LogSoftmax
和 nn.NLLLoss
。
torch.nn.KLDivLoss(reduction='mean')
这个类就是上面提到的相对熵。
这几个类的参数类似,除了上面提到的 reduction
,还有一个 weight
,就是每一个类别的权重。下面用例子来解释交叉熵和 weight
是如何运作的。我们先定义一组数据,使用 numpy 推演一下:
inputs = torch.tensor([[1, 1], [1, 2], [3, 3]], dtype=torch.float)
target = torch.tensor([0, 0, 1],dtype=torch.long)
idx = target[0]
input_ = inputs.detach().numpy()[idx] # [1, 1]
target_ = target.numpy()[idx] # [0]
# 第一项
x_class = input_[target_]
# 第二项
sigma_exp_x = np.sum(list(map(np.exp, input_)))
log_sigma_exp_x = np.log(sigma_exp_x)
# 输出 loss
loss_1 = -x_class + log_sigma_exp_x
结果为
>>> print("第一个样本 loss 为: ", loss_1)
第一个样本 loss 为: 0.6931473
现在我们再使用 PyTorch 来计算:
>>> criterion_ce = nn.CrossEntropyLoss(reduction='none')
>>> criterion_ce(inputs, target)
tensor([0.6931, 1.3133, 0.6931])
可以看到,结果是一致的。现在我们再看看 weight
:
>>> weight = torch.tensor([0.1, 0.9], dtype=torch.float)
>>> criterion_ce = nn.CrossEntropyLoss(weight=weight, reduction='none')
>>> criterion_ce(inputs, target)
tensor([0.0693, 0.1313, 0.6238])
与没有权重的交叉熵进行比较后可以发现,每一个值都乘以了 。当 reduction
为 sum
和 mean
的时候,交叉熵的加权总和或者平均值再除以权重的和。
3.3 总结
-
F.sigmoid
(激活函数)+nn.BCELoss
(损失函数)=torch.nn.BCEWithLogitsLoss
(损失函数) -
nn.LogSoftmax
(激活函数)+nn.NLLLoss
(损失函数)=torch.nn.CrossEntropyLoss
(损失函数)
- 优化算法——拟牛顿法之L-BFGS算法
- 一次性能突发情况的紧急修复(r9笔记第18天)
- Java基础-day02-基础题
- 简单易学的机器学习算法——AdaBoost
- 用脚本来讲一个技术生活的故事 (r9笔记第32天)
- 优化算法——拟牛顿法之BFGS算法
- 对于tnsping的连接超时的功能补充(二)(r9笔记第22天)
- 用深度学习每次得到的结果都不一样,怎么办?
- 优化算法——拟牛顿法之DFP算法
- python SVM 案例,sklearn.svm.SVC 参数说明
- 利用Theano理解深度学习——Auto Encoder
- sudo 出现unable to resolve host 解决方法
- Hadoop学习笔记——Hadoop常用命令
- 可扩展机器学习——Spark分布式处理
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- OpenCV 处理中文路径、绘制中文文字的烦恼,这里通通帮你解决!
- 如何快速分析大型系统架构?
- Linux小技巧、文件查找、修改、读取
- 我在赏金计划中发现的RACE条件漏洞
- 哦!数组还能这么用,学到了!
- 【C++简明教程】随机数生成
- Pytest标记预期失败得测试用例@pytest.mark.xfail()
- IAT HOOK
- 形式化分析工具(六):HLPSL Tutorial
- 推荐一款技术人必备的接口测试神器:Apifox
- GO 文档笔记
- 魔改npm私有仓库 | Verdaccio教程
- 【Vulnhub】AI Web 2.0
- Python迭代器和生成器
- Python深层解析json数据之JsonPath