使用tensorflow搭建线性回归模型
时间:2022-07-28
本文章向大家介绍使用tensorflow搭建线性回归模型,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
上一阶段的数据分析学习因为工作原因耽误了,今天忙里偷个闲,重新开始了。 @猴子 求个第二关门票。
先分享一下最近学到的东西吧……
以前买过一本《tensorflow实战谷歌深度学习框架》,看了一半就留在家里吃灰了,最近重新翻开发现这本书已经跟不上现在的版本了,所以从网上找了一点代码学习。
tensorflow不止能用于深度学习,也能用来实现传统机器学习算法。比如实现线性回归。
tensorflow的线性回归代码当然不如scikit learn的简洁,在scikit learn中只需要几行代码:
from sklearn.linear_model import LinearRegression
clf = LinearRegression()
clf.fit(x,y)
而在tensorflow中很多功能需要自己实现。看起来麻烦,其实是提供了更加个性化的解决方案,比如可以自定义误差函数,达到个性化的模型效果。
而像梯度下降优化器这种写起来麻烦的功能,tensorflow已经实现好了。
要说tensorflow有什么优势的话,那就是如果你数据特别特别大的话,用tensorflow能分布计算吧。
下面是用tensorflow实现线性回归的完整代码。
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_boston
#读取数据
def read_boston_data():
boston = load_boston()
x = np.array(boston.data)
y = np.array(boston.target)
return x,y
#转化为正态分布
def feature_normalize(dataset):
mu = np.mean(dataset,axis = 0)
sigma = np.std(dataset,axis=0)
return (dataset-mu)/sigma
#添加一个常数项,用于训练偏差
def append_bias_reshape(x,y):
n_training_samples = x.shape[0]#训练样本数量
n_dim = x.shape[1]#特征数量
f = np.reshape(np.c_[np.ones(n_training_samples),x],[n_training_samples,n_dim+1])
l = np.reshape(y,[n_training_samples,1])
return f,l
if __name__ == '__main__':
#处理数据
x,y = read_boston_data()
norm_features = feature_normalize(x)
f,l = append_bias_reshape(norm_features,y)
n_dim = f.shape[1]
rnd_indices = np.random.rand(len(f)) < 0.80#生成一个随机的布尔数组
x_train = f[rnd_indices]
y_train = l[rnd_indices]
x_test = f[~rnd_indices]
y_test = l[~rnd_indices]
#tensorflow模型
learning_rate = 0.01#步长
training_epochs = 6000#训练次数
cost_history = []#记录训练误差
test_history = []#记录测试误差
X = tf.placeholder(tf.float32,[None,n_dim])
Y = tf.placeholder(tf.float32,[None,1])
W = tf.Variable(tf.ones([n_dim,1]))
init = tf.initialize_all_variables()
y_ = tf.matmul(X,W)
cost = tf.reduce_mean(tf.abs(y_-Y))#选择了绝对平均误差作为衡量指标
training_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
sess = tf.Session()
sess.run(init)
for epoch in range(training_epochs):
sess.run(training_step,feed_dict={X:x_train,Y:y_train})#训练模型
c = sess.run(cost,feed_dict={X:x_train,Y:y_train})#计算误差
print(c)
t = sess.run(cost,feed_dict={X:x_test,Y:y_test})#计算测试误差
print(t)
cost_history.append(c)
test_history.append(t)
plt.plot(range(len(test_history)),test_history,color = 'green')
plt.plot(range(len(cost_history)),cost_history,color = 'red')
plt.axis([0,training_epochs,0,np.max(cost_history)])
plt.show()
训练误差(红色)与测试误差(绿色)
效果还算不错,比自己写的梯度下降靠谱多了……
- 剖析Go编写的Socket服务器模块解耦及基础模块的设计
- Golang中的sync.WaitGroup用法实例
- Go 语言实现 MapReduce 框架
- Performance Schema使用简介(一)
- golang 垃圾回收 gc
- Go语言服务器开发之简易TCP客户端与服务端实现方法
- 移动搜索SEO分享:PHP自动生成百度开放适配及360移动适配专用的Sitemap文件
- 分享两种外链跳转方法,可避免权重流失。
- go语言十大排序算法总结
- 网站安全检测提示“页面异常导致本地路径泄漏”的解决办法
- Go语言归并排序算法实现
- 超強统计插件:My Visitors在知更鸟主题下的修改教程
- 让知更鸟主题的分类图标支持二级分类
- nwui —— 又一个go语言图形界面解决方案
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Android线程池控制并发数多线程下载
- Android progressbar实现带底部指示器和文字的进度条
- js 调用栈机制与ES6尾调用优化介绍
- Android Fragment实现列表和内容联动
- 前端中等算法-无重复字符的最长子串
- Android自定义动态壁纸开发(时钟)
- 手摸手教你写个ESLint 插件以及了解ESLint的运行原理
- 填满Github的绿色格子用我做的VSCode插件-Auto Commit
- Android多国语言转换Excel及Excel转换为string详解
- python上传时包含boundary时的解决方法
- 4行Python代码生成图像验证码(2种)
- Python 输出详细的异常信息(traceback)方式
- 我开发了一个一键添加佛祖保佑永无BUG、神兽护体等注释图形的工具
- Django实现whoosh搜索引擎使用jieba分词
- VMware下ubuntu与Windows实现文件共享的方法