python实现线性回归之简单回归
时间:2022-07-23
本文章向大家介绍python实现线性回归之简单回归,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
代码来源:https://github.com/eriklindernoren/ML-From-Scratch
首先定义一个基本的回归类,作为各种回归方法的基类:
class Regression(object):
""" Base regression model. Models the relationship between a scalar dependent variable y and the independent
variables X.
Parameters:
-----------
n_iterations: float
The number of training iterations the algorithm will tune the weights for.
learning_rate: float
The step length that will be used when updating the weights.
"""
def __init__(self, n_iterations, learning_rate):
self.n_iterations = n_iterations
self.learning_rate = learning_rate
def initialize_wights(self, n_features):
""" Initialize weights randomly [-1/N, 1/N] """
limit = 1 / math.sqrt(n_features)
self.w = np.random.uniform(-limit, limit, (n_features, ))
def fit(self, X, y):
# Insert constant ones for bias weights
X = np.insert(X, 0, 1, axis=1)
self.training_errors = []
self.initialize_weights(n_features=X.shape[1])
# Do gradient descent for n_iterations
for i in range(self.n_iterations):
y_pred = X.dot(self.w)
# Calculate l2 loss
mse = np.mean(0.5 * (y - y_pred)**2 + self.regularization(self.w))
self.training_errors.append(mse)
# Gradient of l2 loss w.r.t w
grad_w = -(y - y_pred).dot(X) + self.regularization.grad(self.w)
# Update the weights
self.w -= self.learning_rate * grad_w
def predict(self, X):
# Insert constant ones for bias weights
X = np.insert(X, 0, 1, axis=1)
y_pred = X.dot(self.w)
return y_pred
说明:初始化时传入两个参数,一个是迭代次数,另一个是学习率。initialize_weights()用于初始化权重。fit()用于训练。需要注意的是,对于原始的输入X,需要将其最前面添加一项为偏置项。predict()用于输出预测值。
接下来是简单线性回归,继承上面的基类:
class LinearRegression(Regression):
"""Linear model.
Parameters:
-----------
n_iterations: float
The number of training iterations the algorithm will tune the weights for.
learning_rate: float
The step length that will be used when updating the weights.
gradient_descent: boolean
True or false depending if gradient descent should be used when training. If
false then we use batch optimization by least squares.
"""
def __init__(self, n_iterations=100, learning_rate=0.001, gradient_descent=True):
self.gradient_descent = gradient_descent
# No regularization
self.regularization = lambda x: 0
self.regularization.grad = lambda x: 0
super(LinearRegression, self).__init__(n_iterations=n_iterations,
learning_rate=learning_rate)
def fit(self, X, y):
# If not gradient descent => Least squares approximation of w
if not self.gradient_descent:
# Insert constant ones for bias weights
X = np.insert(X, 0, 1, axis=1)
# Calculate weights by least squares (using Moore-Penrose pseudoinverse)
U, S, V = np.linalg.svd(X.T.dot(X))
S = np.diag(S)
X_sq_reg_inv = V.dot(np.linalg.pinv(S)).dot(U.T)
self.w = X_sq_reg_inv.dot(X.T).dot(y)
else:
super(LinearRegression, self).fit(X, y)
这里使用两种方式进行计算。如果规定gradient_descent=True,那么使用随机梯度下降算法进行训练,否则使用标准方程法进行训练。
最后是使用:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
import sys
sys.path.append("/content/drive/My Drive/learn/ML-From-Scratch/")
from mlfromscratch.utils import train_test_split, polynomial_features
from mlfromscratch.utils import mean_squared_error, Plot
from mlfromscratch.supervised_learning import LinearRegression
def main():
X, y = make_regression(n_samples=100, n_features=1, noise=20)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
n_samples, n_features = np.shape(X)
model = LinearRegression(n_iterations=100)
model.fit(X_train, y_train)
# Training error plot
n = len(model.training_errors)
training, = plt.plot(range(n), model.training_errors, label="Training Error")
plt.legend(handles=[training])
plt.title("Error Plot")
plt.ylabel('Mean Squared Error')
plt.xlabel('Iterations')
plt.savefig("test1.png")
plt.show()
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print ("Mean squared error: %s" % (mse))
y_pred_line = model.predict(X)
# Color map
cmap = plt.get_cmap('viridis')
# Plot the results
m1 = plt.scatter(366 * X_train, y_train, color=cmap(0.9), s=10)
m2 = plt.scatter(366 * X_test, y_test, color=cmap(0.5), s=10)
plt.plot(366 * X, y_pred_line, color='black', linewidth=2, label="Prediction")
plt.suptitle("Linear Regression")
plt.title("MSE: %.2f" % mse, fontsize=10)
plt.xlabel('Day')
plt.ylabel('Temperature in Celcius')
plt.legend((m1, m2), ("Training data", "Test data"), loc='lower right')
plt.savefig("test2.png")
plt.show()
if __name__ == "__main__":
main()
利用sklearn库生成线性回归数据,然后将其拆分为训练集和测试集。
utils下的mean_squared_error():
def mean_squared_error(y_true, y_pred):
""" Returns the mean squared error between y_true and y_pred """
mse = np.mean(np.power(y_true - y_pred, 2))
return mse
结果:
Mean squared error: 532.3321383700828
- 挑战数据结构与算法面试题——统计上排数在下排出现的次数
- Go语言的 10 个实用技术--转
- MySQL反连接的优化总结(r10笔记第51天)
- python基础知识——内置数据结构(列表)
- 【Go 语言社区】Go语言Slice去重
- 【Go 语言社区】Golang 语言再谈接口
- 【Go 语言社区】Golang 语言再谈常量
- 【Go 语言社区】HTML5 Canvas+JS控制电脑或手机上的摄像头实例
- MySQL Profile在5.7的简单测试(r10笔记第50天)
- 【Go 语言社区】Golang中interface判断nil问题
- 有趣的rownum测试(r10笔记第49天)
- 【Go 语言社区】关于Golang 数据缓存到redis内存数据库遇到的问题
- go中的读写锁RWMutex
- Centos7.4 版本环境下安装Mysql5.7操作记录
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 1,Jupyter NoteBook 常用魔法命令
- 60行代码徒手实现深度神经网络
- 30行代码徒手实现logistic回归
- 8,模型的训练
- 在腾讯云上部署科学计算软件Amber
- 手把手教你搭建一个灰度发布环境
- Kibana: 如何使用 Search Bar
- 「PHP」以nginx、php-cgi为例,把nginx、php-cgi安装为Windows系统服务
- 聊聊dubbo-go的GenericFilter
- 知新 | koa框架入门到熟练第二章
- JVM学习二
- 微信小程序对接云开发录音文件识别nodejs sdk
- 利用python读取WORD文档中的创建者信息
- LeetCode-2.两数相加 使用链表加法实现
- Spring学习(2):Spring Bean管理(上)