三分钟读懂Softmax函数
Softmax是一种激活函数,它可以将一个数值向量归一化为一个概率分布向量,且各个概率之和为1。Softmax可以用来作为神经网络的最后一层,用于多分类问题的输出。Softmax层常常和交叉熵损失函数一起结合使用。
从二分类到多分类
对于二分类问题,我们可以使用Sigmod函数(又称Logistic函数)。将
范围内的数值映射成为一个
区间的数值,一个
区间的数值恰好可以用来表示概率。
比如,在互联网广告和推荐系统中,曾广泛使用Sigmod函数来预测某项内容是否有可能被点击。Sigmoid函数输出值越大,说明这项内容被用户点击的可能性越大,越应该将该内容放置到更加醒目的位置。
除了二分类,现实世界往往有其他类型的问题。比如我们想识别手写的阿拉伯数字0-9,这就是一个多分类问题,需要从10个数字中选择一个概率最高的作为预测结果。
手写体识别数据集mnist
对于多分类问题,一种常用的方法是Softmax函数,它可以预测每个类别的概率。对于阿拉伯数字预测问题,选择预测值最高的类别作为结果即可。Softmax的公式如下,其中
是一个向量,
和
是其中的一个元素。
下图中,我们看到,Softmax将一个
的向量转化为了
,而且各项之和为1。
Softmax可以将数值向量转换为概率分布
Softmax函数可以将上一层的原始数据进行归一化,转化为一个
之间的数值,这些数值可以被当做概率分布,用来作为多分类的目标预测值。Softmax函数一般作为神经网络的最后一层,接受来自上一层网络的输入值,然后将其转化为概率。
下图为VGG16网络,是一个图像分类网络,原始图像中的数据经过卷积层、池化层、全连接层后,最终经过Softmax层输出成概率。
VGG16是一个图像分类网络,Softmax是VGG16的最后一层,Softmax层的前面是全连接层,Softmax层也是整个VGG16神经网络的输出,输出的是多分类的概率分布
实际上,Sigmod函数是Softmax函数的一个特例,Sigmod函数只能用于预测值为0或1的二元分类。
指数函数
Softmax函数使用了指数,对于每个输入
,需要计算
的指数。在深度学习进行反向传播时,我们经常需要求导,指数函数求导比较方便:
。
我们可以用NumPy实现一个简单的Softmax:
def softmax(x):
return np.exp(x) / np.sum(np.exp(x), axis=0)
对于下面的输入,可以得到:
a = np.asarray([2, 3, 5])
softmax(a)
array([0.04201007, 0.1141952 , 0.84379473])
如果不使用指数,单纯计算百分比:
def percentile(x):
return x / np.sum(x, axis=0)
得到的结果为:
percentile(a)
array([0.2, 0.3, 0.5])
指数函数在x轴正轴的变化非常明显,斜率越来越大。x轴上一个很小的变化都会导致y轴非常大的变化。相比求和计算百分比的方式:
,指数能把一些数值差距拉大。
指数函数
但正因为指数在x轴正轴爆炸式地快速增长,如果
比较大,
也会非常大,得到的数值可能会溢出。溢出又分为下溢出(Underflow)和上溢出(Overflow)。计算机用一定长度的二进制表示数值,数值又被称为浮点数。当数值过小的时候,被四舍五入为0,这就是下溢出;当数值过大,超出了最大界限,就是上溢出。
比如,仍然用刚才那个NumPy实现的简单的Softmax:
b = np.array([20, 300, 5000])
softmax(b)
会报错:
RuntimeWarning: overflow encountered in exp return np.exp(x) / np.sum(np.exp(x), axis=0)
一个简单的办法是,先求得输入向量的最大值,然后所有向量都减去这个最大值:
参考资料
- https://medium.com/data-science-bootcamp/understand-the-softmax-function-in-minutes-f3a59641e86d
- https://zhuanlan.zhihu.com/p/105722023
- https://en.wikipedia.org/wiki/Softmax_function
- http://deanhan.com/2018/07/26/vgg16/
- 关于 xargs 参数被截断,tar 文件被覆盖的问题
- 一些sql用法例子【Updating】
- 关于腾讯的一道字符串匹配的面试题
- Sort Map by Value in Java
- java 利用反射模拟动态语言的 eval 函数
- Spark函数讲解: combineByKey
- hadoop 里执行 MapReduce 任务的几种常见方式
- Pig、Hive、MapReduce 解决分组 Top K 问题
- Pig、Hive 自定义输入输出分隔符以及Map、Array嵌套分隔符冲突问题
- 新手教程:局域网DNS劫持实战
- 自定义 java 日期、时间 处理函数集
- MapReduce 中的两表 join 几种方案简介
- MapReduce中的自定义多目录/文件名输出HDFS
- 通过hiveserver远程服务构建hive web查询分析工具
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- php解决约瑟夫环算法实例分析
- 浅谈laravel-admin的sortable和orderby使用问题
- 使用composer安装使用thinkphp6.0框架问题【视频教程】
- 基于laravel-admin 后台 列表标签背景的使用方法
- 解决laravel-admin 自己新建页面里 js 需要刷新一次的问题
- laravel-admin 中列表筛选方法
- Laravel框架控制器的middleware中间件用法分析
- laravel-admin的图片删除实例
- 在laravel-admin中列表中禁止某行编辑、删除的方法
- Laravel的Auth验证Token验证使用自定义Redis的例子
- laravel-admin解决表单select联动时,编辑默认没选上的问题
- laravel-admin 后台表格筛选设置默认的查询日期方法
- Laravel框架控制器的request与response用法示例
- laravel 字段格式化 modle 字段类型转换方法
- laravel-admin 在列表页添加自定义按钮的例子