python实现多变量线性回归(Linear Regression with Multiple Variables)
本文介绍如何使用python实现多变量线性回归,文章参考NG的视频和黄海广博士的笔记
现在对房价模型增加更多的特征,例如房间数楼层等,构成一个含有多个变量的模型,模型中的特征为( x1,x2,...,xn)
表示为:
引入 x0=1,则公式 转化为:
1、加载训练数据
数据格式为:
X1,X2,Y
2104,3,399900
1600,3,329900
2400,3,369000
1416,2,232000
将数据逐行读取,用逗号切分,并放入np.array
#加载数据
#加载数据
def load_exdata(filename):
data = []
with open(filename, 'r') as f:
for line in f.readlines():
line = line.split(',')
current = [int(item) for item in line]
#5.5277,9.1302
data.append(current)
return data
data = load_exdata('ex1data2.txt');
data = np.array(data,np.int64)
x = data[:,(0,1)].reshape((-1,2))
y = data[:,2].reshape((-1,1))
m = y.shape[0]
# Print out some data points
print('First 10 examples from the dataset: n')
print(' x = ',x[range(10),:],'ny=',y[range(10),:])
First 10 examples from the dataset:
x = [[2104 3]
[1600 3]
[2400 3]
[1416 2]
[3000 4]
[1985 4]
[1534 3]
[1427 3]
[1380 3]
[1494 3]]
y= [[399900]
[329900]
[369000]
[232000]
[539900]
[299900]
[314900]
[198999]
[212000]
[242500]]
2、通过梯度下降求解theta
(1)在多维特征问题的时候,要保证特征具有相近的尺度,这将帮助梯度下降算法更快地收敛。
解决的方法是尝试将所有特征的尺度都尽量缩放到-1 到 1 之间,最简单的方法就是(X - mu) / sigma,其中mu是平均值, sigma 是标准差。
(2)损失函数和单变量一样,依然计算损失平方和均值
我们的目标和单变量线性回归问题中一样,是要找出使得代价函数最小的一系列参数。多变量线性回归的批量梯度下降算法为:
求导数后得到:
(3)向量化计算
向量化计算可以加快计算速度,怎么转化为向量化计算呢?
在多变量情况下,损失函数可以写为:
对theta求导后得到:
(1/2*m) * (X.T.dot(X.dot(theta) - y))
因此,theta迭代公式为:
theta = theta - (alpha/m) * (X.T.dot(X.dot(theta) - y))
(4)完整代码如下:
#特征缩放
def featureNormalize(X):
X_norm = X;
mu = np.zeros((1,X.shape[1]))
sigma = np.zeros((1,X.shape[1]))
for i in range(X.shape[1]):
mu[0,i] = np.mean(X[:,i]) # 均值
sigma[0,i] = np.std(X[:,i]) # 标准差
# print(mu)
# print(sigma)
X_norm = (X - mu) / sigma
return X_norm,mu,sigma
#计算损失
def computeCost(X, y, theta):
m = y.shape[0]
# J = (np.sum((X.dot(theta) - y)**2)) / (2*m)
C = X.dot(theta) - y
J2 = (C.T.dot(C))/ (2*m)
return J2
#梯度下降
def gradientDescent(X, y, theta, alpha, num_iters):
m = y.shape[0]
#print(m)
# 存储历史误差
J_history = np.zeros((num_iters, 1))
for iter in range(num_iters):
# 对J求导,得到 alpha/m * (WX - Y)*x(i), (3,m)*(m,1) X (m,3)*(3,1) = (m,1)
theta = theta - (alpha/m) * (X.T.dot(X.dot(theta) - y))
J_history[iter] = computeCost(X, y, theta)
return J_history,theta
iterations = 10000 #迭代次数
alpha = 0.01 #学习率
x = data[:,(0,1)].reshape((-1,2))
y = data[:,2].reshape((-1,1))
m = y.shape[0]
x,mu,sigma = featureNormalize(x)
X = np.hstack([x,np.ones((x.shape[0], 1))])
# X = X[range(2),:]
# y = y[range(2),:]
theta = np.zeros((3, 1))
j = computeCost(X,y,theta)
J_history,theta = gradientDescent(X, y, theta, alpha, iterations)
print('Theta found by gradient descent',theta)
Theta found by gradient descent [[ 109447.79646964]
[ -6578.35485416]
[ 340412.65957447]]
绘制迭代收敛图
plt.plot(J_history)
plt.ylabel('lost');
plt.xlabel('iter count')
plt.title('convergence graph')
使用模型预测结果
def predict(data):
testx = np.array(data)
testx = ((testx - mu) / sigma)
testx = np.hstack([testx,np.ones((testx.shape[0], 1))])
price = testx.dot(theta)
print('price is %d ' % (price))
predict([1650,3])
price is 293081
no bb,上代码,代码下载
- Oracle dbms_random随机函数包
- volatile和Synchronized区别
- Oracle 快速插入1000万条数据的实现方式
- HashMap实现原理分析
- Oracle TM锁和TX锁
- Oracle给Select结果集加锁,Skip Locked(跳过加锁行获得可以加锁的结果集)
- select for update和select for update wait和select for update nowait的区别
- Android入门之动画
- Java 读写大文本文件
- 高级聚类
- Oracle 数据库名、实例名、Oracle_SID
- Oracle 多行、多列子查询
- Android入门介绍
- Oracle 数据表的管理
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 排序算法之快速排序
- 干货 | Oracle数据库操作命令大全,满满的案例供你理解,收藏!
- 【2万字长文】深入浅出主流的几款小程序跨端框架原理
- 关于动态规划的练习题
- Linux笔记
- 使用OpenCV和Python计算视频中的总帧数
- HDOJ 1087 (JAVA实现 最大上升子序列和dp)
- JavaSE笔记
- [译]Gas 优化 - 如何优化存储
- Codeforces Round #613 (Div. 2) C. Fadi and LCM
- N皇后问题(DFS)
- [译]区块链民主 - 如何开发通过投票运行的合约
- java安全编码指南之:异常处理
- java安全编码指南之:死锁dead lock
- java安全编码指南之:方法编写指南