TensorFlow从0到1丨 第五篇:TensorFlow轻松搞定线性回归
上一篇 第一个机器学习问题 其实是一个线性回归问题(line regression),呈现了用数据来训练模型的具体方式。本篇从平行世界返回,利用TensorFlow,重新解决一遍该问题。
TensorFlow的API有低级和高级之分。
底层的API基于TensorFlow内核,它主要用于研究或需要对模型进行完全控制的场合。如果你想使用TF来辅助实现某个特定算法、呈现和控制算法的每个细节,那么就该使用低级的API。
高级API基于TensorFlow内核构建,屏蔽了繁杂的细节,适合大多数场景下使用。如果你有一个想法要验证并快速获得结果,那么TF的高级API就是高效的构建工具。
本篇使用TF的低级API来呈现线性回归的每一个步骤。
第一个机器学习的TF实现
TensorFlow的计算分为两个阶段:
- 构建计算图
- 执行计算图
先给出“平行世界”版本,(a, b)初始值为(-1, 50),第二次尝试(-1, 40)
程序输出:
上面的python代码利用了在2 TensorFlow内核基础 介绍的基本API实现了“第一个机器学习问题”。代码通过一步步构造计算图,最后得到了loss节点。loss即4 第一个机器学习问题中定义过的损失函数,这里再次给出其定义:
构建好计算图,接下来开始执行。执行loss节点(同时提供基于tf.placeholder的训练数据),得到loss的值为50。然后开始第二次训练,修改基于tf.Variable的a和b的值,再次执行loss节点,loss的值为0,降到了最低。此时的a和b就是最佳的模型参数了。
还记得那个神秘力量吗?到底是什么让机器在第二次训练中将模型参数(a, b)的值从初始的随机值(-1, 50)迁移到最优的(-1, 40)?如果不靠运气的话,机器如何能自动的找到最优解呢?
梯度下降算法
在此之前,或许你已经想到了随机穷举的办法,因为机器不怕累。这的确是个办法,但面临的挑战也不可接受:不可控。因为即便是只有2个参数的模型训练,其枚举域也是无限大的,这和靠运气没有分别。运气差的话,等个几百年也说不定。
不绕圈子,那个神秘力量就是:梯度下降算法(gradient descent)。虽然它也是让机器一小步一小步的去尝试不同的(a, b)的组合,但是它能指导每次前进的方向,使得每尝试一组新的值,loss就能变小一点点,直到趋于稳定。
而这一切TF已经把它封装好了。 本篇先把它当个黑盒子使用。
tf.train API
代码几乎和TensorFlow Get Started官方代码一致,主要区别在于训练数据不同,以及初始值不同。
- TF官方的训练数据是x_train = [1, 2, 3, 4],y_train = [0, -1, -2, -3],而我们的训练数据是“平行世界”的观察记录x_train = [22, 25, 28, 30],y_train = [18, 15, 12, 10]。
- TF官方的(a, b)初始值是(.3, -.3), 我们的是(-1., 50.)。
- 或许你还发现在官方版本的loss函数末尾没有
/ 8
,是因为我使用均方差的缘故,8由4x2得到(4个训练数据)。
重点说下tf.train API。tf.train.GradientDescentOptimizer即封装了梯度下降算法。梯度下降在数学上属于最优化领域,从其名字Optimizater也可体现出。其参数就是“学习率”(learning rate),先记住这个名词,暂不展开,其基本的效用是决定待调整参数的调整幅度。学习率越大,调整幅度越大,学习的越快。反之亦然。可也并不是越大越好,是相对来说的。先取0.01。
另一个需要输入给梯度下降算法的就是loss,它是求最优化解的主体,通过optimizer.minimize(loss)传入,并返回train节点。接下来在循环中执行train节点即可,循环的次数,即训练的步数。
执行计算图,程序输出:
这个结果令人崩溃,仅仅换了下TF官方get started中例子中模型的训练数据和初始值,它就不工作了。
先来看看问题在哪。一个调试的小技巧就是打印每次训练的情况,并调整loop的次数。
程序输出:
TF实际是工作的,并没有撂挑子。只是它训练时每次调整(a, b)都幅度很大,接下来又矫枉过正且幅度越来越大,导致最终承载a和b的tf.float32溢出而产生了nan。这不是TF的一个bug,而是算法本身、训练数据、学习率、训练次数共同导致的(它们有个共同的名字:超参数。)。可见,训练是一门艺术。
直觉上,初始值或许有优劣之分,或许是离最优值越近的初始值越容易找到。可是训练数据则应该是无差别的吧?实则不然。但是现在我还不打算把它解释清楚,等后面分析完梯度下降算法后再回来看这个问题。
遇到该问题的也不再少数,Stack Overflow上已经很好的回答了。我们先通过调整学习率和训练次数来得到一个完美的Ending。
把学习率从0.01调制0.0028,然后将训练次数从1000调整至70000。
程序输出:
最终代码如下:
TensorBoard
TF的另一个强大之处就是可视化算法的TensorBoard,把构造的计算图显示出来。图中显示,每一个基本运算都被独立成了一个节点。除了图中我标注的Rank节点、range节点,start节点、delta节点外,其他节点都是由所写代码构建出来的。
TensorBoard
词汇表
- derivative; 导数;
- estimator: 估计;
- gradient descent: 梯度下降;
- inference: 推理;
- line regression:线性回归;
- loss function: 损失函数;
- magnitude: 量;
- optimal: 最优的;
- optimizers: 优化器;
- 机器人产业链分析-中国机器人产业的发展机遇和挑战
- 如何与深度学习服务器优雅的交互?
- 比特币大跌又反弹30%,区块链技术与企业级有着怎样的关系?
- 十个实用MySQL函数
- 使用Apprenda和R分析应用程序工作负载数据
- 实现微信朋友圈所有动态点赞的自动化用例
- 后台设计的一些总结
- 2017年区块链当中的黑客大事件
- 5个云安全解决方案的注意事项
- 深入剖析ASP.NET的编译原理之二:预编译(Precompilation)
- 深入剖析ASP.NET的编译原理之二:预编译(Precompilation)
- Nodejs学习笔记(十六)--- Pomelo介绍&入门
- 美团再出幺蛾子,启动美团打车项目,滴滴感到威胁了吗?
- 深入剖析ASP.NET的编译原理之一:动态编译(Dynamical Compilation)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 详细对比php中类继承和接口继承
- 解决Keras的自定义lambda层去reshape张量时model保存出错问题
- 解决Keras中Embedding层masking与Concatenate层不可调和的问题
- 浅谈keras使用预训练模型vgg16分类,损失和准确度不变
- 关于tf.matmul() 和tf.multiply() 的区别说明
- python中执行smtplib失败的处理方法
- PHP+Ajax简单get验证操作示例
- Python matplotlib读取excel数据并用for循环画多个子图subplot操作
- python转化excel数字日期为标准日期操作
- thinkPHP框架通过Redis实现增删改查操作的方法详解
- PHP中引用类型和值类型功能与用法示例
- PHP文件上传小程序 适合初学者学习!
- php的扩展写法总结
- 实例介绍PHP删除数组中的重复元素
- Python迭代器协议及for循环工作机制详解