简单易学的机器学习算法——Softmax Regression
时间:2022-05-04
本文章向大家介绍简单易学的机器学习算法——Softmax Regression,主要内容包括一、Softmax Regression简介、二、Logistic回归的回顾、三、Logistic回归的推广——Softmax Regression、2、对、四、实验、2、测试数据、3、Matlab源码、基本概念、基础应用、原理机制和需要注意的事项等,并结合实例形式分析了其使用技巧,希望通过本文能帮助到大家理解应用这部分内容。
一、Softmax Regression简介
Softmax Regression是Logistic回归的推广,Logistic回归是处理二分类问题的,而Softmax Regression是处理多分类问题的。Logistic回归是处理二分类问题的比较好的算法,具有很多的应用场合,如广告计算等。Logistic回归利用的是后验概率最大化的方式去计算权重。
二、Logistic回归的回顾
在Logistic回归中比较重要的有两个公式,一个是阶跃函数:
另一个是对应的损失函数
最终,Logistic回归需要求出的是两个概率:
和
具体的Logistic回归的过程可参见“简单易学的机器学习算法——Logistic回归”。
三、Logistic回归的推广——Softmax Regression
在Logistic回归需要求解的是两个概率:
和
,而在Softmax Regression中将不是两个概率,而是
个概率,
表示的是分类的个数。我们需要求出以下的概率值:
此时的损失函数为
其中
是一个指示性函数,意思是大括号里的值为真时,该函数的结果为1,否则为0。下面就这几个公式做个解释:
1、损失函数的由来
概率函数可以表示为
其似然函数为
似然为
我们要最大化似然函数,即求
。再转化成损失函数。
2、对
似然(或者是损失函数)求偏导
为了简单,我们仅取一个样本,则可简单表示为
下面对
求偏导:
其中,
表示第
维。如Logistic回归中一样,可以使用基于梯度的方法来求解这样的最大化问题。基于梯度的方法可以参见“优化算法——梯度下降法”。
四、实验
1、训练数据
从图上我们可以看到分为4类。
2、测试数据
在区间上随机生成了4000个点,这样比较直观地看到分类边界。
3、Matlab源码
主程序
clear all;
clc;
%% 导入数据
data = load('SoftInput.txt');
[m,n] = size(data);
labels = unique(data(:,3));
labelLen = length(labels);%划分的种类
dataMat(:,2:3) = data(:,1:2);
dataMat(:,1) = 1;%做好数据集,添加一列为1
labelMat(:,1) = data(:,3)+1;%分类的标签
%% 画图
figure;
hold on
for i = 1:m
if labelMat(i,:) == 1
plot(data(i,1),data(i,2),'.m');%粉红色
elseif labelMat(i,:) == 2
plot(data(i,1),data(i,2),'.b');%蓝色
elseif labelMat(i,:) == 3
plot(data(i,1),data(i,2),'.r');%红色
else
plot(data(i,1),data(i,2),'.k');%黑色
end
end
title('原始数据集');
hold off
%% 初始化一些参数
M = m;%数据集的行
N = n;%数据集的列
K = labelLen;%划分的种类
alpha = 0.001;%学习率
weights = ones(N, K);%初始化权重
%% 利用随机梯度修改权重
weights = stochasticGradientAscent(dataMat, labelMat, M, weights, alpha);
%% 测试数据集(主要在区间里随机生成)
size = 4000;
[testDataSet, testLabelSet] = testData(weights, size, N);
%% 画出最终的分类图
figure;
hold on
for i = 1:size
if testLabelSet(i,:) == 1
plot(testDataSet(i,2),testDataSet(i,3),'.m');
elseif testLabelSet(i,:) == 2
plot(testDataSet(i,2),testDataSet(i,3),'.b');
elseif testLabelSet(i,:) == 3
plot(testDataSet(i,2),testDataSet(i,3),'.r');
else
plot(testDataSet(i,2),testDataSet(i,3),'.k');
end
end
title('测试数据集');
hold off
随机梯度法
%% 随机梯度下降法(这里要用上升法)
function [ weights ] = stochasticGradientAscent( dataMat, labelMat, M, weights, alpha )
for step = 1:500
for i = 1:M%对每一个样本
pop = exp(dataMat(i,:)*weights);%计算概率
popSum = sum(pop);%分母
pop = -pop/popSum;%求好概率
pop(:,labelMat(i)) = pop(:,labelMat(i))+1;%加1的操作
weights = weights + alpha*dataMat(i,:)'*pop;
end
end
end
生成测试数据
%% 计算测试数据集
function [ testDataSet, testLabelSet ] = testData( weights, m, n)
testDataSet = ones(m,n);%构建了全1的矩阵
testLabelSet = zeros(m,1);
for i = 1:m
testDataSet(i,2) = rand()*6-3;
testDataSet(i,3) = rand()*15;
end
%% 计算测试数据的所属分类
for i = 1:m
testResult = testDataSet(i,:)*weights;
[C,I] = max(testResult);
testLabelSet(i,:) = I;
end
end
- AngularJS 用 $http.jsonp 跨域SyntaxError问题
- 简单的java socket 示例
- Hadoop二次开发环境构建
- Android EditText 获得输入焦点 以及requestfocus()失效的问题
- 【直播】我的基因组68:看看哪些基因的突变较多,哪些较少
- GDI+编程
- GDI编程
- ADO访问数据库
- 【直播】我的基因组76:用krona对血液全基因组的菌比例可视化
- 【直播】我的基因组74:快速给测序reads比对到物种
- 用ADO操作数据库的方法步骤
- VC如何获取对话框中控件的坐标
- 【直播】我的基因组72:把基因检测芯片数据转为vcf格式
- 【直播】我的基因组78:简单解析一下蛋白编码基因的测序深度及覆盖度
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- ASP.NET Core 使用 Google 验证码(reCAPTCHA v3)代替传统验证码
- Centos7 安装 Tomcat8 伪集群 的正确姿势 并设置开机自启 实践笔记
- 问题合集,持续更新
- ASP.NET Core Swagger接入使用IdentityServer4 的 WebApi
- 基于IdentityServer4的OIDC实现单点登录(SSO)原理简析
- OpenGL ES 变量、结构体、语句、函数、精度
- OpenGL ES for Android 绘制矩形和正方形
- OpenGL ES for Android 绘制立方体
- 服务化最佳实践
- OpenGL ES for Android 深度测试
- OpenGL ES for Android 绘制旋转的地球
- [Hei.Captcha] Asp.Net Core 跨平台图形验证码实现
- Asp.Net Core 3.1 获取不到Post、Put请求的内容 System.NotSupportedException Specified method is not supported
- OpenGL ES for Android 播放视频
- Centos 7 在线安装 离线安装 最新 Docker-compose 的正确姿势 实践笔记