TensorFlow小入门
同步微博端,代码混乱,请查看原文如下:
这篇文章是TensorFlow的入门教程。在开始阅读本文之前,请确保你会Python,并且对矩阵有一定的了解,除此之外,最好能懂一点机器学习的知识,不过如果你对机器学习一无所知也没关系,你可以从阅读这篇文章开始学起。
TensorFlow提供了丰富的接口供调用。TensorFlow的内核尽可能开放了最完备的接口,它允许你在此基础上从最底层开始开发。我们建议一般开发者可以不用从这么底层开始开发,这些底层接口更适合科研人员。TensorFlow的上层接口都是在此基础上搭建的。上层接口比底层更容易使用。像tf.contrib.learn这样的高层接口帮助你去管理数据集、估算、训练和推理。请注意,有一些高层的接口的方法名包含contrib,这些是正在开发的接口,它们在接下来的版本当中,有可能会被修改甚至删掉。
这篇文章会先讲一下TensorFloat的基础知识,然后,我们会带着大家一起学习一下如何用tf.contrib.learn实现之前提到的模型。
Tensors
TensorFlow的核心就是tensor,tensor是由一系列任意维度的矩阵构成。
3 # 这是一个维度为0的tensor,这是一个标量,它的大小是[]。 [1. ,2., 3.] # 这是一个维度为1的tensor,这是一个矢量,它的大小是[3] [[1., 2., 3.], [4., 5., 6.]] # 这是一个维度为2的tensor。这是一个矩阵,它的大小是[2, 3] [[[1., 2., 3.]], [[7., 8., 9.]]] # 这是一个维度为3的tensor,它的大小是[2, 1, 3]
TensorFlow的内核
导入TensorFlow
TensorFlow的导入方式如下:
import tensorflow as tf
没有这个包的话安装:
$ pip install tensorflow
导入了TensorFlow了之后,就可以通过Python来访问TensorFlow的里面的类、方法和符号。后续所有的代码执行前都必须先导入TensorFlow。
算法图
- TensorFlow的内核分成两部分:构建算法图和运行算法图。
算法图是由一系列算法作为节点形成的一幅图。让我们一起构建一个简单的算法图。每一个节点都有任意数量的tensor作为输入,一个tensor作为输出。常量是一个特殊的节点。所有的常量都没有输入,它的输出来自内部存储的数据。如果想要创建两个浮点数类型的tensor,比如node1和node2,我们可以这样写:
node1 = tf.constant(3.0, dtype=tf.float32)
node2 = tf.constant(4.0) # 隐式指定了tensor内部数据的类型是tf.float32
print(node1, node2)
最后一行打印出来的结果是:
Tensor("Const:0", shape=(), dtype=float32) Tensor("Const_1:0", shape=(), dtype=float32)
请注意,这里并没有像我们期望的那样打印出3.0和4.0。这是因为node1和node2是节点,只有运算的时候,才会分别生成3.0和4.0。想要对这两个节点进行运算,我们必须启动一个会话(Session)。
下面的代码创建 了一个会话,并且调用了它的run方法来执行这个算法图。求出了node1和node2的值。
sess = tf.Session()
print(sess.run([node1, node2]))
这样我们就可以看到期望的结果:
[3.0, 4.0]
我们还可以构建更加复杂的图。比如,我们可以把刚才的两个节点加起来生成一个新的节点:
node3 = tf.add(node1, node2)
print("node3: ", node3)
print("sess.run(node3): ",sess.run(node3))
最后两行打印的代码打印出来的结果如下:
node3: Tensor("Add:0", shape=(), dtype=float32)
sess.run(node3): 7.0
- TensorFlow提供了一个工具,叫TensorBoard,可以用来查看算法图的结构。这就是刚才的代码对应的算法图。
这张图非常简单,因为我们的算法产生的永远是一个定值。一张图如果想要有外部输入,我们就需要用到占位符(placeholder)。一个占位符表示一定会提供一个输入。
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
adder_node = a + b # 这里面的加号和调用tf.add(a, b)等效
上面的三行像是一个函数。定义了两个输入参数a和b。然后把它们加了起来。我们要执行这个算法图,就必须要传参数:
print(sess.run(adder_node, {a: 3, b:4.5}))
print(sess.run(adder_node, {a: [1,3], b: [2, 4]}))
运行的结果是:
7.5
[ 3. 7.]
这张算法图是这样的:
我们还可以把图变得更复杂,比如我们可以再添加一个节点:
add_and_triple = adder_node * 3.
print(sess.run(add_and_triple, {a: 3, b:4.5}))
执行的结果是:
22.5
对应的图是这样的:
- 在机器学习中,我们希望一个模型可以接受任何参数。为了让模型可以被训练,我们希望可以通过修改图,使得同样的输入会得到新的输出。变量(Variables)允许我们把一个可以训练的参数加入到图中。创建变量的时候,需要指定它们的类型和初始值。
w = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b
当你调用tf.constant的时候,常量就被初始化了,它的值永远不会改变。但是当你调用tf.Variable的时候,变量并没有被初始化,要初始化变量,你必须显式地执行如下操作:
init = tf.global_variables_initializer()
sess.run(init)
有一点很重要,init是一个引用。它引用的是一个子图。这个子图初始化了所有全局变量。一直到执行sess.run的时候,这些变量才真正被初始化。
因为x是一个占位符,所以我们可以一次性计算醋linear_model的四个值。
print(sess.run(linear_model, {x:[1,2,3,4]}))
运行的结果是:
[ 0. 0.3 0.6 0.90000004]
我们创建了一个模型,但是我们不知道这个模型好不好。想要对这个模型做一个评估,我们需要y占位符去提供想要的值,然后我们需要写一个损失函数(loss function)。
损失函数表征了当前模型和所提供的数据之间的差别。我们将会用一个标准的损失函数来做线性回归。线性回归就是把所求的值和所提供数据的差的平方加起来。linear_model - y是一个向量,向量的值就是所求的值和提供的数据之间的误差。我们调用tf.square去把这个值求平方。然后我们用tf.reduce_sum把所有的值加起来,得到一个数字,通过这个方法得到一个衡量这个样本错误的数据。
y = tf.placeholder(tf.float32)
squared_deltas = tf.square(linear_model - y)
loss = tf.reduce_sum(squared_deltas)
print(sess.run(loss, {x:[1,2,3,4], y:[0,-1,-2,-3]}))
求出的损耗是:
23.66
我们可以把W和b设置为正确答案-1和1。一个变量可以用tf.Variable来初始化,如果要修改它的值,也可以用tf.assign。比如,W = -1 并且 b = 1 就是我们这个模型最理想的参数:
fixW = tf.assign(w, [-1.])
fixb = tf.assign(b, [1.])
sess.run([fixW, fixb])
print(sess.run(loss, {x:[1,2,3,4], y:[0,-1,-2,-3]}))
当我们这么设置了之后,我们发现损耗就变成了:
0.0
- 我们这里是事先知道w和b的正确值,而我们现在研究机器学习的目标是要机器自己找到这个合适的参数。接下来,我们就来训练我们的机器找到这个参数。
tf.train
- TensorFlow提供了一个优化器,缓慢的改变每个变量,来最小化损耗。最简单的优化器就是gradient descent。这个优化器调整参数的方式是根据变量和损耗之间的导数的大小。简单起见,一般都帮你代劳。
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
sess.run(init) # reset values to incorrect defaults.
for i in range(1000):
sess.run(train, {x:[1,2,3,4], y:[0,-1,-2,-3]})
print(sess.run([W, b]))
最终求得模型的参数是:
[array([-0.9999969], dtype=float32), array([ 0.99999082],
dtype=float32)]
到此为止,我们让机器完成了一次学习的过程。尽管这仅仅是一个简单的线性回归的问题,根本不需要用TensorFlow大费周章的来实现。TensorFlow为常见的设计模式,数据结构和功能提供了很好的抽象。
完整的代码
一个完整的训练线性回归模型的代码如下:
# coding:utf8
import tensorflow as tf
import numpy as np
w = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32)
linear_model = w * x + b
y = tf.placeholder(tf.float32)
loss = tf.reduce_sum(tf.square(linear_model - y))
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
x_train = [1, 2, 3, 4]
y_train = [0, -1, -2, -3]
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(1000):
sess.run(train, {x: x_train, y: y_train})
curr_w, curr_b, curr_loss = sess.run([w, b, loss], {x: x_train, y: y_train})
print("w:%s b:%s loss: %s" % (curr_w, curr_b, curr_loss))
运行的结果是:
W: [-0.9999969] b: [ 0.99999082] loss: 5.69997e-11
注意这里的损耗是一个非常接近于0的数字,如果你运行同样的代码,得到的结果不一定和这个一模一样,因为我们是用随机值来训练这个模型的。
- 最后给出这个算法图的图形:
- 小猪农场获百万天使轮,六声域名源自运营主体
- Intellij idea 的maven项目自动下载jar包
- python3和python2共存
- 揭密微信跳一跳小游戏那些外挂
- 特斯拉出现人才流失潮,竟因为一些工程师认为Autopilot自动驾驶技术并不安全
- 微信又更新了,这次放出年度大招!新变化让不少人拍手叫好!
- “JINAN”:未来电动汽车边跑边充电
- Bagging算法
- 基于Region Proposal的深度学习目标检测简述(一)
- 10大数据挖掘算法及其简介
- SpringMVC返回图片的几种方式
- 区块链技术3.0来了,靠谱吗,看看区块链技术3.0能干啥
- SpringMVC支持跨域的几种姿势
- 万达回应网科裁员:系局部调整
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法