机器学习-简单线性回归教程
线性回归(Linear regression)虽然是一种非常简单的方法,但在很多情况下已被证明非常有用。
在这篇文章中,您将逐步发现线性回归(Linear regression)是如何工作的。阅读完这篇文章后,你会学习到在线性回归算法中:
- 如何一步一步地计算一个简单的线性回归。
- 如何使用电子表格执行所有计算。
- 如何使用你的模型预测新的数据。
- 一个能大大简化计算的捷径。
这是一份为开发者所写的教程,读者不需具备数学或统计学背景。
同时,在本教程中,你将使用自己的电子表格,这将有助于你对概念的理解。
更新#1:修正均方误差根(RMSE)计算中的一个错误。
上图作者:Catface27, 保留部分权利
教程数据集
我们正在使用的数据集是完全虚构的。
以下是原始数据
x y
1 1
2 3
4 3
3 2
5 5
属性x是输入变量,y是我们试图预测的输出变量。如果我们得到足够多的数据,我们只通过x值,就能预测得到y值。
下面是x对y的简单散点图。
我们可以看到x和y之间的关系看起来有点线性。如图所示,我们可以从图的左下角向右上角对角地画一条线,以便描述数据之间的关系。
这是一个很好的迹象,表明使用线性回归可能适合于这个小数据集。
简单的线性回归(Simple Linear Regression)
当我们有一个单一的输入属性(x),我们想要使用线性回归,这就是所谓的简单线性回归。
如果我们有多个输入属性(如x1, x2, x3等)这就叫做多元线性回归。简单线性回归的过程与多元线性回归的过程是不同的,但比多元线性回归更简单,因此首先学习简单线性回归是一个很好的起点。
在本节中,我们将根据我们的训练数据创建一个简单线性回归模型,然后对我们的训练数据进行预测,以了解模型如何在数据中学习从而得到函数关系。
通过简单线性回归,我们想要如下模拟我们的数据:
y = B0 + B1 * x
上式是一条直线,其中y是我们想要预测的输出变量,x是我们知道的输入变量,B0和B1是我们需要估计的系数。
从数学上讲,B0被称为截距,因为它决定了直线截取y轴的位置。在机器学习中,我们可以称之为偏差,因为它被添加来抵消我们所做的所有预测。B1项称为斜率,因为它定义了直线的斜率,或者说在我们加上偏差之前x如何转化为y值,就是通过B1。
现在,我们的目标是找到系数的最佳估计,以最小化从x预测y的误差。
简单线性回归是很好的,因为不用通过反复试验来搜索值,或者使用更高级的线性代数来分析它们,我们可以直接从我们的数据中估计它们。
我们可以通过估算B1的值来开始:
B1 = sum((xi-mean(x))*(yi-mean(y)))/ sum((xi-mean(x))^ 2)
其中,mean()是我们数据集中变量的平均值,xi和yi指的是我们需要在数据集中的所有值上重复这些计算,而i指的是x或y的第i个值。
我们可以使用B1和我们的数据集中的一些统计数据来计算B0,如下所示:
B0 = mean(y) – B1 * mean(x)
没那么糟糕吧?我们可以在电子表格(例如Excel)中计算这些。
估计斜率(B1)
让我们从分子的顶部开始。
首先我们需要计算x和y的平均值。平均值计算如下:
1 / n * sum(x)
其中n是值的数量(在这种情况下是5)。您可以在电子表格中使用AVERAGE()函数。我们来计算我们的x和y变量的平均值:
mean(x) = 3
mean(y) = 2.8
现在我们需要从平均值中计算每个变量的误差。先用x来做这个事情:
x mean(x) x - mean(x)
1 3 -2
2 3 -1
4 3 1
3 3 0
5 3 2
然后让我们来做这个y变量
y mean(y) y - mean(y)
1 2.8 -1.8
3 2.8 0.2
3 2.8 0.2
2 2.8 -0.8
5 2.8 2.2
我们现在有计算分子的部分。我们所要做的就是将每个x的误差与每个y的误差相乘,并计算这些乘积的和。
x - mean(x) y - mean(y) Multiplication
-2 -1.8 3.6
-1 0.2 -0.2
1 0.2 0.2
0 -0.8 0
2 2.2 4.4
计算最后一行,我们计算出的分子为8。
现在我们需要计算方程的底部计算B1或分母。这被计算为平均值的每个x值的平方差的总和。
我们已经从平均值中计算了每个x值的差值,我们所要做的就是将每个值平方并计算总和。
x - mean(x) squared
-2 4
-1 1
1 1
0 0
2 4
计算这些平方值的总和可以得出10的分母
现在我们可以计算出我们的斜率值。
B1 = 8 / 10
B1 = 0.8
估计截距(B0)
这是很容易的,因为我们已经知道所有涉及的术语的价值。
B0 = mean(y) – B1 * mean(x)
or
B0 = 2.8 – 0.8 * 3
or
B0 = 0.4
进行预测
现在我们有简单线性回归方程的系数。
y = B0 + B1 * x
or
y = 0.4 + 0.8 * x
让我们通过对训练数据的预测来检验模型。
x y predicted y
1 1 1.2
2 3 2
4 3 3.6
3 2 2.8
5 5 4.4
我们可以将这些预测与我们的数据作为一条线。这给我们提供了一个直观的概念,即我们的数据是如何建立的。
估算误差
我们可以计算一个称为均方根误差或RMSE的预测误差。
RMSE = sqrt(sum((pi-yi)^ 2)/ n)
其中sqrt()是平方根函数,p是预测值,y是实际值,i是特定实例的指数,n是预测的数量,因为我们必须计算所有预测值的误差。
首先,我们必须计算每个模型预测与实际y值之间的差异。
pred-y y error
1.2 1 0.2
2 3 -1
3.6 3 0.6
2.8 2 0.8
4.4 5 -0.6
我们可以很容易地计算出每个误差值的平方(error * error或error ^ 2)。
error squared error
0.2 0.04
-1 1
0.6 0.36
0.8 0.64
-0.6 0.36
这些误差的总和是2.4单位,除以n,取平方根给我们:
RMSE = 0.692
即,每个预测平均误差大约0.692个单位。
估计B0和B1的快捷方法
在我们结束之前,我想向您展示计算系数的快捷方式。
简单线性回归是最简单的回归形式,也是研究最多的形式。您可以使用一个快捷方法来快速估计B0和B1的值。
针对计算B1的捷径。B1的计算可以重写为:
B1 = corr(x,y)* stdev(y)/ stdev(x)
其中corr(x)是x和y之间的相关性,stdev()是一个变量的标准偏差的计算。。
相关性(也称为Pearson相关系数)是一种衡量相关的两个变量在-1到1之间的关系。1的值表示这两个变量是完全正相关的,它们都朝同一个方向运动,但当一个值向一个方向移动,而另一个值向其他方向移动,-1表示它们完全负相关。
标准差是衡量平均数据的平均值。
您可以在电子表格中使用函数PEARSON()计算x和y的相关性为0.852(高度相关)和STDEV()函数计算x的标准偏差为1.5811,y的标准偏差为1.4832。
将这些值代入我们有:
B1 = 0.852 * 1.4832 / 1.5811
B1 = 0.799
可以看到,B1=0.799足够接近0.8的上述值。请注意,如果我们在电子表格(如excel)中为相关和标准偏差方程使用更全面的精度,我们将得到0.8。
总结
在这篇文章中,您发现并学会了如何在电子表格中逐步实现线性回归。你可以了解到:
- 如何根据您的训练数据估计简单线性回归模型的系数。
- 如何使用您的学习模型进行预测。
如果你对这个帖子或者线性回归有任何疑问?留下评论,问你的问题,我会尽我所能来回答。
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 手把手教你使用Python实现常用的假设检验 !
- Oracle 每日一题系列合集
- Arrow更好用的python时间序列处理库,你用过吗?
- 死信队列监听补充
- 手把手教你用Python查询你的物流信息
- Selenium自动登录淘宝,我无意间发现了登录漏洞!
- 【DB宝20】在Docker中分分钟即可拥有OGG Director环境
- mq监听死信队列后如何处理
- 【小白学PyTorch】7 最新版本torchvision.transforms常用API翻译与讲解
- 小白学PyTorch | 8 实战之MNIST小试牛刀
- 干货:用好VSCode这13款插件和8个快捷键,工作效率提升10倍
- 使用dplyr包对表格整理
- 安利 5 个拍案叫绝的 Matplotlib 骚操作!
- 多媒体程序开发
- 本地 IDE 已废!编辑器大结局!GitHub 的云 VSCode 实测