高级API用法示例
tf.contrib.learn Quickstart
TensorFlow的机器学习高级API(tf.contrib.learn)使配置、训练、评估不同的学习模型变得更加容易。在这个教程里,你将使用tf.contrib.learn在Iris data set上构建一个神经网络分类器。代码有一下5个步骤:
- 在TensorFlow数据集上加载Iris
- 构建神经网络
- 用训练数据拟合
- 评估模型的准确性
- 在新样本上分类
Complete Neural Network Source Code
这里是神经网络的源代码:
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import urllib import numpy as np import tensorflow as tf # Data sets IRIS_TRAINING = "iris_training.csv" IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv" IRIS_TEST = "iris_test.csv" IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv" def main(): # If the training and test sets aren't stored locally, download them. if not os.path.exists(IRIS_TRAINING): raw = urllib.urlopen(IRIS_TRAINING_URL).read() with open(IRIS_TRAINING, "w") as f: f.write(raw) if not os.path.exists(IRIS_TEST): raw = urllib.urlopen(IRIS_TEST_URL).read() with open(IRIS_TEST, "w") as f: f.write(raw) # Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model") # Define the training inputs def get_train_inputs(): x = tf.constant(training_set.data) y = tf.constant(training_set.target) return x, y # Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000) # Define the test inputs def get_test_inputs(): x = tf.constant(test_set.data) y = tf.constant(test_set.target) return x, y # Evaluate accuracy. accuracy_score = classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"] print("nTest Accuracy: {0:f}n".format(accuracy_score)) # Classify two new flower samples. def new_samples():
return np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) predictions = list(classifier.predict(input_fn=new_samples)) print( "New Samples, Class Predictions: {}n" .format(predictions) )if __name__ == "__main__": main()
Load the Iris CSV data to TensorFlow
Iris data set包含了150行数据,3个种类:Iris setosa, Iris virginica, and Iris versicolor.
每一行包括了以下的数据:花萼的宽度,长度,花瓣的宽度,花的种类。花的种类有整数表示,0表示Iris setosa, 1表示Iris virginica, 2表示Iris versicolor.
Sepal Length |
Sepal Width |
Petal Length |
Petal Width |
Species |
---|---|---|---|---|
5.1 |
3.5 |
1.4 |
0.2 |
0 |
4.9 |
3.0 |
1.4 |
0.2 |
0 |
4.7 |
3.2 |
1.3 |
0.2 |
0 |
… |
… |
… |
… |
… |
7.0 |
3.2 |
4.7 |
1.4 |
1 |
6.4 |
3.2 |
4.5 |
1.5 |
1 |
6.9 |
3.1 |
4.9 |
1.5 |
1 |
… |
… |
… |
… |
… |
6.5 |
3.0 |
5.2 |
2.0 |
2 |
6.2 |
3.4 |
5.4 |
2.3 |
2 |
5.9 |
3.0 |
5.1 |
1.8 |
2 |
这里,Iris数据随机分割成了两组不同的CSV文件:
- 120个样本的训练数据(iris_training.csv)
- 30个样本的测试数据(iris_test.csv).
开始时,首先引进所有必要的模块,然后定义下载存储数据集的路径:
from __future__ import absolute_ import from __future__ import division from __future__ import print_function import os import urllib import tensorflow as tf import numpy as np IRIS_TRAINING = "iris_training.csv" IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv" IRIS_TEST = "iris_test.csv" IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
然后,如果训练和测试集没有在本地存储,下载:
if not os.path.exists(IRIS_TRAINING): raw = urllib.urlopen(IRIS_TRAINING_URL).read() with open(IRIS_TRAINING,'w') as f: f.write(raw) if not os.path.exists(IRIS_TEST): raw = urllib.urlopen(IRIS_TEST_URL).read() with open(IRIS_TEST,'w') as f: f.write(raw)
然后,用learn.datasets.base的load_csv_with_header()方法加载训练集和测试集成Dataset S,load_csv_with_header()包涵一下三个参数:
- filename,CSV文件的路径
- target_dtype,数据集目标值的numpy数据类型
- features_dtype,数据集特征值的numpy数据类型
这里,目标是花的种类,是0-2的整数,所以数据类型是np.int:
# Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)
tf.contrib.learn中的Dataset S是tuple,你可以通过data,target来访问特征值和目标值,比如,training_set.data,training_set.target
Construct a Deep Neural Network Classifier
tf.contrib.learn提供了多种预定义的模型,称为 Estimator S,你可以用“黑盒子”在你的数据上来训练和评估节点。这里,你讲配置深度神经网络分类器来拟合Iris数据,你可以用tf.contrib.learn.DNNClassifier作为示例:
# Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model")
首先定义特征所在的列,有4个特征,所以dimension设定为4.
然后,构建了DNNClassifier,包含以下参数:
- feature_columns=feature_columns.上面定义的特征的列
- hidden_units=[10, 20, 10]. 三个隐层,分别包含10,20,10个神经元
- n_classes=3.三个目标
- model_dir=/tmp/iris_model.训练模型时保存的断点数据
Describe the training input pipeline
tf.contrib.learn API使用输入函数,创建TensorFlow节点来生成模型数据。这里,数据比较小,可以放在tf.constant。
# Define the test inputs def get_train_inputs():
x = tf.constant(training_set.data) y = tf.constant(training_set.target) return x, y
Fit the DNNClassifier to the Iris Training Data
配置了DNN分类器,你可以用fit方法来拟合数据,传递get_train_inputs到input_fn参数中,循环训练2000次:
# Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000)
等效于:
classifier.fit(x=training_set.data, y=training_set.target, steps=1000) classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
如果你想追踪训练模型,你可以用TensorFlow monitor来执行节点的日志。
“Logging and Monitoring Basics with tf.contrib.learn”
Evaluate Model Accuracy
你已经用训练数据拟合了模型,现在,你可以用evaluate方法在测试集上评估准确性。像fit一样,evaluate也需要一个输入函数来构建输入的通道,并返回评估结果的字典。
# Define the test inputs def get_test_inputs(): x = tf.constant(test_set.data) y = tf.constant(test_set.target) return x, y # Evaluate accuracy. accuracy_score = classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"] print("nTest Accuracy: {0:f}n".format(accuracy_score))
运行整个脚本,打印:
Test Accuracy: 0.966667
Classify New Samples
用predict()方法来分类新的样本,比如,你有下面的两个新样本:
Sepal Length |
Sepal Width |
Petal Length |
Petal Width |
---|---|---|---|
6.4 |
3.2 |
4.5 |
1.5 |
5.8 |
3.1 |
5.0 |
1.7 |
predict方法返回一个generator,可以转换成list
# Classify two new flower samples.
def new_samples():
return np.array(
[[6.4, 3.2, 4.5, 1.5],
[5.8, 3.1, 5.0, 1.7]],dtype=np.float32)
predictions = list(classifier.predict(input_fn=new_samples))
print(
"New Samples, Class Predictions:
{}n" .format(predictions))
结果大致如下:
New Samples, Class Predictions: [1 2]
- 不使用反射的实体类方案
- matlab GUI基础1
- Why to do,What to do,Where to do 与 Lambda表达式!
- Cloak ; Dagger攻击:一种可针对所有版本Android的攻击技术(含演示视频)
- 实例探究字符编码:unicode,utf-8,default,gb2312 的区别
- 分布式计算,WCF+JSON+实体对象与WebService+DataSet效率大比拼
- 【自然框架】 页面里的父类—— 改进和想法、解释
- 线性神经网络
- 【数据可视化】深度解析大数据可视化设计案例分析
- 使用IE6看老赵的博客——比较完美版(可以在线查看、回复)
- 【Python环境】R vs Python:硬碰硬的数据分析
- 使用IE6看老赵的博客——jQuery初探
- matlab GUI基础8
- 见到了“公司”定义一个Company类,那么见到了“字段”是不是也可定义一个Column类?
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法