基于Tensorflow框架的BP神经网络回归小案例--预测跳高
(案例):我们将14组国内男子跳高运动员各项素质指标作为输入,即(30m行进跑,立定三级跳远,助跑摸高,助跑4-6步跳高,负重深蹲杠铃,杠铃半蹲系数,100m,抓举),将对应的跳高成绩作为输出,通过对14位选手的数据训练建立模型,预测第15位选手的跳高成绩。
待预测样本a=[[3.0,9.3,3.3,2.05,100,2.8,11.2,50]]
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
构造数据:14个样本,8个特征1个标签
x=[[3.2,3.2,3,3.2,3.2,3.4,3.2,3,3.2,3.2,3.2,3.9,3.1,3.2],
[9.6,10.3,9,10.3,10.1,10,9.6,9,9.6,9.2,9.5,9,9.5,9.7],
[3.45,3.75,3.5,3.65,3.5,3.4,3.55,3.5,3.55,3.5,3.4,3.1,3.6,3.45],
[2.15,2.2,2.2,2.2,2,2.15,2.14,2.1,2.1,2.1,2.15,2,2.1,2.15],
[140,120,140,150,80,130,130,100,130,140,115,80,90,130],
[2.8,3.4,3.5,2.8,1.5,3.2,3.5,1.8,3.5,2.5,2.8,2.2,2.7,4.6],
[11,10.9,11.4,10.8,11.3,11.5,11.8,11.3,11.8,11,11.9,13,11.1,10.85],
[50,70,50,80,50,60,65,40,65,50,50,50,70,70]]
y=[[2.24],[2.33],[2.24],[2.32],[2.2],[2.27],[2.2],[2.26],[2.2],[2.24],[2.24],[2.2],
[2.2],[2.35]]
获取数据集
x_t=np.array(x,dtype=‘float32’).T #[148]
y_true=np.array(y,dtype=‘float32’) #[141]
将特征数据最值归一,范围在(0,1)之间
mm=MinMaxScaler() #实例化
std=mm.fit(x_t) #训练模型
x_true=std.transform(x_t) #转化
print(x_true)
print(y_true)
通过占位符,预定义输入X,输出Y
即输入层8*1个神经元,输出层1个神经元
X=tf.placeholder(tf.float32,[None,8])
Y=tf.placeholder(tf.float32,[None,1])
随机数列生成,创建隐含层的神经网络,隐含层4个神经元
truncated_normal:选取位于正态分布方差在0.1附近的随机数据
w1=tf.Variable(tf.truncated_normal([8,4],stddev=0.1))
b1=tf.Variable(tf.zeros([4]))
w2=tf.Variable(tf.zeros([4,1]))
b2=tf.Variable(tf.zeros([1]))
relu,为激活函数,增加非线性关系,隐藏层和输出层的计算
L1=tf.nn.relu(tf.matmul(X,w1)+b1)
y_pre=tf.matmul(L1,w2)+b2
计算损失函数:均方误差
loss=tf.reduce_mean(tf.cast(tf.square(Y-y_pre),tf.float32))
梯度下降优化损失函数,学习率过大容易导致权重非常大,会出现nan值
train_op=tf.train.GradientDescentOptimizer(0.01).minimize(loss)
初始化变量
init_op=tf.global_variables_initializer()
创建一个saver,用来保存训练模型
saver=tf.train.Saver()
开启回话
with tf.Session() as sess:
sess.run(init_op)
训练模型15次
for i in range(1,300): #控制训练批次
for j in range (len(y_true)):#控制每批次训练的样本数
sess.run(train_op,feed_dict={X:[x_true[j,:]],Y:[y_true[j,:]]})#[[]]是为了匹配占位的类型
输出每次训练的损失
print(‘第%s批次第%s个样本训练的损失为:%s,真实值为:%s,预测值为:%s’% (i,j+1,
sess.run(loss, feed_dict={X:[x_true[j,:]],Y:[y_true[j,:]]}),
y_true[j,:],
sess.run(y_pre,feed_dict={X:[x_true[j,:]],Y:[y_true[j,:]]})))
保存模型:需要在会话里完成(注意缩进代码)
saver.save(sess,’./BP_demo/BP_model’)
加载模型,预测15号选手的跳高成绩
saver.restore(sess,’./BP_demo/BP_model’)
样本原始数据
a = [[3.0,9.3,3.3,2.05,100,2.8,11.2,50]]
获取测试样本
x_test=np.array(a,dtype=‘float32’)
将数据最值归一
x_test=std.transform(x_test)
print(‘15号选手的跳高成绩预测值为:’, sess.run(y_pre,feed_dict={X:x_test})
结果:
…
第299批次第11个样本训练的损失为:0.00016767633,真实值为:[2.24],预测值为:[[2.227051]]
第299批次第12个样本训练的损失为:1.5376372e-06,真实值为:[2.2],预测值为:[[2.19876]]
第299批次第13个样本训练的损失为:0.00062711653,真实值为:[2.2],预测值为:[[2.2250423]]
第299批次第14个样本训练的损失为:0.008744326,真实值为:[2.35],预测值为:[[2.2564888]]
15号选手的跳高成绩预测值为: [[2.1450984]]
问题:
只是简单实现数据预测,误差还是较大,有更好的优化方法,欢迎大家一起来分享哦!
- 每周.NET前沿技术文章摘要(2017-05-24)
- ruby学习笔记(10)-puts,p,print的区别
- Linux下的Mongodb部署应用梳理
- Ocelot API网关的实现剖析
- ruby学习笔记(9)-别名(alias)与方法取消(undef,remove_method)
- Pupet自动化管理环境部署记录
- ruby学习笔记(8)-"静态方法的4种写法"与"单例方法的2种写法"
- Puppet常识梳理
- linux下增加磁盘改变指定文件路径分区挂载点和迁移数据
- 手动编写的几个简单的puppet管理配置
- 选择一款适合自己的ruby on rails IDE开发工具
- 微信的两种用途
- Sqlite快速上手使用指南
- 自动类型安全的.NET标准REST库refit
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- macOS 安装软件已损坏无法打开解决办法 (真好用!)
- nginx 配置反向代理
- ES6新特性速查表
- React-Native Android打包
- React-Native iOS打包
- Webpack+Babel手把手带你搭建开发环境(内附配置文件)
- Redux 异步解决方案2. Redux-Saga中间件
- Redux异步解决方案 1. Redux-Thunk中间件
- 深度学习Pytorch检测实战 - Notes - 第1&2章 基础知识
- Java多线程编程在JMeter中应用
- Kubernetes 升级填坑指南(一)
- 根据 PID 获取 K8S Pod名称 - 反之 POD名称 获取 PID
- 用python实现一个verilog网表Parser
- 经典 | Python实例小挑战—Part eight
- python的数字与字符串相互转换