机器学习之最小二乘法
1.背景:
1801年,意大利天文学家朱赛普·皮亚齐发现了第一颗小行星谷神星。经过40天的跟踪观测后,由于谷神星运行至太阳背后,使得皮亚齐失去了谷神星的位置。随后全世界的科学家利用皮亚齐的观测数据开始寻找谷神星,但是根据大多数人计算的结果来寻找谷神星都没有结果。时年24岁的高斯也计算了谷神星的轨道。奥地利天文学家海因里希·奥伯斯根据高斯计算出来的轨道重新发现了谷神星。
高斯使用的最小二乘法的方法发表于1809年他的著作《天体运动论》中,而法国科学家勒让德于1806年独立发现“最小二乘法”,但因不为世人所知而默默无闻。两人曾为谁最早创立最小二乘法原理发生争执。
1829年,高斯提供了最小二乘法的优化效果强于其他方法的证明,见高斯-马尔可夫定理。
----维基百科
2. 最小二乘法在机器学习中被用来
3. 高中关于最小二乘法估计
概括:
假设有若干个样本点,(x1,y1),(x2,y2),(x3,y3),(x4,y4),(x5,y5),求解直线y=kx+b,是的这些样本点到直线的距离最小.
我们高中的求解方式也是这样的:
展开为:
min_sum = [y1- (kx1+b)]^2+[y2- (kx2+b)]^2+[y3- (kx3+b)]^2+[y4- (kx4+b)]^2+[y5- (kx5+b)]^2
就是各个点到我们设定的直线的欧式距离
化简为:
以上就是我们高中对于最小二乘法的最初认知. 这个求解的过程,我们称之为最小二乘法,而求解的这条直线,我们称之为线性回归,线性回归用来近似的预测数据的真是情况.
举个例子:(此题来自:北师大版高中数学)
从某所高中随机抽取一些可爱的萌妹子,就比如6个女生好了,测出她们的体重和身高如下表,现在来了一个60kg的女生,求问它的身高会有多高?
女生ID |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
---|---|---|---|---|---|---|---|---|
身高 |
165 |
165 |
157 |
170 |
175 |
165 |
155 |
170 |
体重 |
48 |
57 |
50 |
54 |
64 |
61 |
43 |
59 |
用python画图来表示这些数据好了:
1 # encoding: utf8
2 import matplotlib
3 import matplotlib.pyplot as plt
4 from matplotlib.font_manager import FontProperties
5 from sklearn.linear_model import LinearRegression
6 from scipy import sparse
7
8 print matplotlib.matplotlib_fname() # 将会获得matplotlib包所在文件夹
9 font = FontProperties()
10 plt.rcParams['font.sans-serif'] = ['Droid Sans Fallback'] # 指定默认字体
11 plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题
12
13 plt.figure()
14 plt.title(u' 可爱女生的数据 ')
15 plt.xlabel(u'x 体重')
16 plt.ylabel(u'y 身高')
17 plt.axis([40, 80, 140, 200])
18 plt.grid(True)
19 x = [[48], [57], [50], [54], [64], [61], [43], [59]]
20 y = [[165], [165], [157], [170], [175], [165], [155], [170]]
21 plt.plot(x, y, 'k.')
22 model = LinearRegression()
23 model.fit(x, y)
24 # y2 = model.predict(x)
25 # plt.plot(x, y2, 'g-')
26 plt.show()
散点图:
对于这个例子,我们可以使用上面的公式,求解出回归方程,并可以得到方程拟合的该女生的身高值,但是这太麻烦了 , 毕竟高中还是太too yong too simple了~
4. 大学关于最小二乘法
基于上面的那个问题,我们大学有没有更好的一点的求解方式 ?
4.1 大学对于最小二乘法的概括:
找到那样一条函数曲线使得观测值的残差平方之和最小. 通俗的讲:见高中部分概括
4.2 继续上面的这个问题思路:
我们已知这些数据:
f(x,y) = [y1- (kx1+b)]^2+[y2- (kx2+b)]^2+[y3- (kx3+b)]^2+[y4- (kx4+b)]^2+[y5- (kx5+b)]^2+[y6- (kx6+b)]^2+[y7- (kx7+b)]^2+[y7- (kx7+b)]^2
如果存在最大值,那么只需要满足f(x,y)对于x,y的一阶偏导数均为0
求解得:
k= 0.849 , b =85.172
所以预测值为:
y = 0.849x - 85.172 将y = 60kg 代入求解得: x = 170.99175
我们再使用Python求解一次:
1 # encoding: utf8
2 import matplotlib
3 import matplotlib.pyplot as plt
4 from matplotlib.font_manager import FontProperties
5 from scipy.optimize import leastsq
6 from sklearn.linear_model import LinearRegression
7 from scipy import sparse
8 import numpy as np
9
10 # 拟合函数
11 def func(a, x):
12 k, b = a
13 return k * x + b
14
15
16 # 残差
17 def dist(a, x, y):
18 return func(a, x) - y
19
20
21 font = FontProperties()
22 plt.rcParams['font.sans-serif'] = ['Droid Sans Fallback'] # 指定默认字体
23 plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题
24
25 plt.figure()
26 plt.title(u' 可爱女生的数据 ')
27 plt.xlabel(u'x 体重')
28 plt.ylabel(u'y 身高')
29 plt.axis([40, 80, 140, 200])
30 plt.grid(True)
31 x = np.array([48.0, 57.0, 50.0,54.0, 64.0, 61.0, 43.0, 59.0])
32 y = np.array([165.0, 165.0,157.0, 170.0, 175.0, 165.0, 155.0, 170.0])
33 plt.plot(x, y, 'k.')
34
35 param = [0, 0]
36
37 var= leastsq(dist, param, args=(x, y))
38 k, b = var[0]
39 print k, b
40
41 plt.plot(x, k*x+b, 'o-')
42
43 plt.show()
从图中,可以发现结果大致相符.
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Go语言(二十)日志采集项目(二)Etcd的使用
- prometheus入门(一)
- Go语言(十九)日志采集项目之logagent开发(一)
- Go语言(十 八)context&日志项目
- 使用梯度上升欺骗神经网络,让网络进行错误的分类
- Go语言(十七) 配置文件库项目
- Python 相对路径问题:“No such file or directory“
- 基于etcd服务发现的overlay跨多宿主机容器网络
- Go语言(十六) 日志项目升级
- PyQt5 技术篇-设置窗口相对桌面位置,按屏幕比例
- Go语言(十五) 反射
- SpringBoot应用跨域访问解决方案
- Spring Boot 2.2都有哪些新变化
- Go语言(十四)日志项目
- 如何在Spring Boot中使用Cookies