TensorFlow-5: 用 tf.contrib.learn 来构建输入函数
学习资料: https://www.tensorflow.org/get_started/input_fn
今天学习用 tf.contrib.learn 来建立 input funciton, 并用 DNN 对 Boston Housing 数据集进行回归预测。
问题:
- 给一组波士顿房屋价格数据,要用神经网络回归模型来预测房屋价格的中位数
- 数据集可以从官网教程下载: https://www.tensorflow.org/get_started/input_fn
- 它包括以下特征:
- 我们需要预测的是MEDV这个标签,以每一千美元为单位
一共有 5 步:
- 导入 CSV 格式的数据集
- 建立神经网络回归模型
- 用训练数据集训练模型
- 评价模型的准确率
- 对新样本数据进行分类
"""DNNRegressor with custom input_fn for Housing dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import pandas as pd
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age",
"dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm",
"age", "dis", "tax", "ptratio"]
LABEL = "medv"
def input_fn(data_set):
feature_cols = {k: tf.constant(data_set[k].values) for k in FEATURES}
labels = tf.constant(data_set[LABEL].values)
return feature_cols, labels
def main(unused_argv):
# Load datasets
training_set = pd.read_csv("boston_train.csv", skipinitialspace=True,
skiprows=1, names=COLUMNS)
test_set = pd.read_csv("boston_test.csv", skipinitialspace=True,
skiprows=1, names=COLUMNS)
# Set of 6 examples for which to predict median house values
prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True,
skiprows=1, names=COLUMNS)
# Feature cols
feature_cols = [tf.contrib.layers.real_valued_column(k)
for k in FEATURES]
# Build 2 layer fully connected DNN with 10, 10 units respectively.
regressor = tf.contrib.learn.DNNRegressor(feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir="/tmp/boston_model")
# Fit
regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000)
# Score accuracy
ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1)
loss_score = ev["loss"]
print("Loss: {0:f}".format(loss_score))
# Print out predictions
y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
# .predict() returns an iterator; convert to a list and print predictions
predictions = list(itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))
if __name__ == "__main__":
tf.app.run()
今天主要的知识点就是输入函数
在上面的代码中我们可以看到,输入数据时用的是 pandas,可以直接读取 CSV 文件 为了识别数据集中哪些是列,哪些是特征,哪些是预测标签,需要把这三者定义出来
在定义神经网络回归模型时,我们建立一个具有两层隐藏层的神经网络,每一层具有 10 个神经元节点, 接下来就是建立输入函数,它的作用就是把输入数据传递给回归模型,它可以接受 pandas 的 Dataframe 结构,并将特征和标签列作为 Tensors 返回
在训练时,只需要把训练数据集传递给输入函数,用 fit 迭代5000步 评价模型时,也是将测试数据集传递给输入函数,再用 evaluate 预测时,同样将预测数据集传递给输入函数
关于 输入函数:
昨天学到读取 CSV 文件的方法适用于不需要对原来的数据有什么操作的时候 但是当需要对数据进行特征工程时,我们就需要有一个输入函数来把数据的预处理给封装起来,再传递给模型
输入函数的基本框架:
def my_input_fn():
# Preprocess your data here...
# ...then return 1) a mapping of feature columns to Tensors with
# the corresponding feature data, and 2) a Tensor containing labels
return feature_cols, labels
输入函数必须返回下面两种值:
feature_cols
:是一个字典,key 就是特征列的名字,value 就是 tensor,包含了相应的数据
labels
:返回包含标签数据的 tensor,即所想要预测的目标
如果特征/标签数据存在pandas数据帧中或numpy数组中,那么需要将其转换为Tensor,然后从 input_fn 中返回。
对于稀疏数据 大多数值为0的数据,应该填充一个 SparseTensor,
下面例子,就是定义了一个具有3行和5列的二维 SparseTensor。在 [0,1] 上的元素的值为 6,[2,4] 上的元素值为 0.5,其他值为 0:
sparse_tensor = tf.SparseTensor(indices=[[0,1], [2,4]],
values=[6, 0.5],
dense_shape=[3, 5])
[[0, 6, 0, 0, 0]
[0, 0, 0, 0, 0]
[0, 0, 0, 0, 0.5]]
- FFLIB之FFLUA——C++嵌入Lua&扩展Lua利器
- Python之匿名函数
- H2Engine游戏服务器设计之属性管理器
- linux epoll 开发指南-【ffrpc源码解析】
- Python之递归函数
- 你不得不会的MarkDown
- 状态机的实现探讨
- Docker入门实战(二)——Docker镜像操作
- 使用强大的 Mockito 来测试你的代码
- java学习手册-CentOS 6.3(x86_32)下安装Oracle 10g R2
- Docker入门实战(三)——用Dockerfile构建镜像
- C++中消息自动派发之二 About IDL解析器
- C++中消息自动派发之三 About JSON Encode
- Linux管道命令
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 全网最详细的一篇Flutter 尺寸限制类容器总结
- 一篇带你看懂Flutter叠加组件Stack
- Flutter 拖拽排序组件 ReorderableListView
- 女神节 | 程序员如何低调而又不失逼格
- Flutter 拖拽控件Draggable看这一篇就够了
- 面试官:你精通多少种语言的Hello World?
- Flutter 裁剪类组件 最全总结
- Flutter Form表单控件超全总结
- 你知道吗,Flutter内置了10多种Button控件
- Flutter 日期时间DatePicker控件及国际化
- 强大的Flutter App升级功能
- 你知道吗,Flutter内置了10多种show
- 还记得第一个看到的Flutter组件吗?
- 150多个Flutter组件详细介绍送给你
- Flutter 学习路线图