初探 TensorFlow.js
// 每日前端夜话 第409篇
// 正文共:3400 字
// 预计阅读时间:8 分钟
在本文中我们来研究怎样用 TensorFlow.js 创建基本的 AI 模型,并用更复杂的模型实现一些有趣的功能。我只是刚刚开始接触人工智能,尽管不需要深入的人工智能知识,但还是需要搞清楚一些概念才行。
什么是模型?
真实世界是很复杂的,我们需要对其进行简化才能理解,可以用通过模型来进行简化,这种模型有很多种:比如世界地图,或者图表等。
比如要建立一个用来表示房子出租价格与房屋面积关系的模型:首先要收集一些数据:
房间数量 |
价格 |
---|---|
3 |
131000 |
3 |
125000 |
4 |
235000 |
4 |
265000 |
5 |
535000 |
然后,把这些数据显示在二维图形上,把每个参数(价格,房间数量)都做为 1 个维度:
线性回归
然后我们可以画一条线,并预测 更多房间的房屋出租价格。这种模型被称为线性回归,它是机器学习中最简单的模型之一。不过这个模型还不够好:
- 只有 5 个数据,所以不够可靠。
- 只有 2 个参数(价格,房间),但是还有更多可能会影响价格的因素:比如地区、装修情况等。
可以通过添加更多的数据来解决第一个问题,比如一百万个。对于第二个问题,可以添加更多维度。在二维图表中可以很容易理解数据并画一条线,在三维图中可以使用平面:
三维图中的平面
但是当数据的维度是三维呢四维甚至是 1000000 维的时候,大脑就没有办法在图表上对其进行可视化了,但是可以在维度超过三维时通过数学来计算超平面,而神经网络就是为了解决这个问题而生的。
什么是神经网络?
要解什么是神经网络,需要知道什么是神经元。真正的神经元看上去是这样的:
神经元
神经元由以下几部分组成:
- 树突:这是数据的输入端。
- 轴突:这是输出端。
- 突触(未在图中表示):该结构允许一个神经元与另一个神经元之间进行通信。它负责在轴突的神经末梢和附近神经元的树突之间传递电信号。这些突触是学习的关键,因为它们会根据用途增减电活动。
机器学习中的神经元(简化):
机器学习中的神经元
- Inputs(输入) :输入的参数。
- Weights(权重) :像突触一样,用来通过调节神经元更好的建立线性回归。
- Linear function(线性函数) :每个神经元就像一个线性回归函数,对于线性回归模型,只需要一个神经元够了。
- Activation function(激活函数) :可以用一些激活函数来将输出从标量改为另一个非线性函数。常见的有 sigmoid、RELU 和 tanh。
- Output(输出) :应用激活函数后的计算输出。
激活函数是非常有用的,神经网络的强大主要归功于它。假如没有任何激活功能,就不可能得到智能的神经元网络。因为尽管你的神经网络中有多个神经元,但神经网络的输出始终将是线性回归。所以需要一些机制来将各个线性回归变形为非线性的来解决非线性问题。通过激活函数可以将这些线性函数转换为非线性函数:
神经网络
训练模型
正如 2D 线性回归的例子所描述的,只需要在图中画一条线就可以预测新数据了。尽管如此,“深度学习”的思想是让我们的神经网络学会画这条线。对于一条简单的线,可以用只有一个神经元的非常简单的神经网络即可,但是对于想要做更复杂事情的模型,例如对两组数据进行分类这种操作,需要通过“训练”使网络学习怎样得到下面的内容:
分类问题
这个过程并不复杂,因为它是二维的。每个模型都用来描述一个世界,但是“训练”的概念在所有模型中都非常相似。第一步是绘制一条随机线,并在算法中通过迭代对其进行改进,每次迭代中过程中修正错误。这种优化算法名为 Gradient Descent(梯度下降)(有着相同概念的算法还有更复杂的 SGD 或 ADAM 等)。每种算法(线性回归,对数回归等)都有不同的成本函数来度量误差,成本函数会始终收敛于某个点。它可以是凸函数或凹函数,但是最终要收敛在 0% 误差的点上。我们的目标就是实现这一点。
凸函数和凹函数
当使用梯度下降算法时,先从其成本函数的某个随机点开始,但是我们不知道它究竟在什么地方!这就像你被蒙着眼睛丢在一座山上,想要下山的话必须一步一步地走到最低点。如果地形是不规则的(例如凹函数),则下降会更加复杂。
在这里不会深入解释“梯度下降”算法,只需要记住这是训练 AI 模型过程中最小化预测误差的优化算法就足够了。这种算法需要大量的时间和 GPU 进行矩阵乘法。通常在第一次执行时很难达到这个收敛点,因此需要修正一些超参数,例如学习率(learning rate)或添加正则化(regularization)。在梯度下降迭代之后,当误差接近 0% 时,会接近收敛点。这样就创建了模型,接下来就能够进行预测了。
进行预测
用 TensorFlow.js 训练模型
TensorFlow.js 提供了一种创建神经网络的简便方法。首先用 trainModel
方法创建一个 LinearModel
类。我们将使用顺序模型。顺序模型是其中一层的输出是下一层的输入的模型,即当模型拓扑是简单的层级结构,没有分支或跳过。在 trainModel
方法内部定义层(我们仅使用一层,因为它足以解决线性回归问题):
import * as tf from '@tensorflow/tfjs';
/**
* 线性模型类
*/
export default class LinearModel {
/**
* 训练模型
*/
async trainModel(xs, ys){
const layers = tf.layers.dense({
units: 1, // 输出空间的纬度
inputShape: [1], // 只有一个参数
});
const lossAndOptimizer = {
loss: 'meanSquaredError',
optimizer: 'sgd', // 随机梯度下降
};
this.linearModel = tf.sequential();
this.linearModel.add(layers); // 添加一层
this.linearModel.compile(lossAndOptimizer);
// 开始模型训练
await this.linearModel.fit(
tf.tensor1d(xs),
tf.tensor1d(ys),
);
}
//...
}
使用这个类进行训练:
const model = new LinearModel()
// xs 与 ys 是 数组成员(x-axis 与 y-axis)
await model.trainModel(xs, ys)
训练结束后就可以开始预测了。
用 TensorFlow.js 进行预测
尽管在训练模型时需要事先定义一些超参数,但是进行一般的预测还是很容易的。通过下面的代码就够了:
import * as tf from '@tensorflow/tfjs';
export default class LinearModel {
... //前面训练模型的代码
predict(value){
return Array.from(
this.linearModel
.predict(tf.tensor2d([value], [1, 1]))
.dataSync()
)
}
}
现在就可以预测了:
const prediction = model.predict(500) // 预测数字 500
console.log(prediction) // => 420.423
在 TensorFlow.js 中使用预训练的模型
训练模型是最难的部分。首先对数据进行标准化来进行训练,还需要正确的设定所有超参数等等。对于咱们初学者,可以直接用那些预先训练好的模型。TensorFlow.js 可以使用很多预训练的模型,还可以导入使用 TensorFlow 或 Keras 创建的外部模型。例如可以直接用 posenet 模型(实时人体姿态评估)做一些有意思的项目:
posenet Demo
? 这个 Demo 的代码:https://github.com/aralroca/posenet-d3
它用起来很容易:
import * as posenet from '@tensorflow-models/posenet'
// 设置一些常数
const imageScaleFactor = 0.5
const outputStride = 16
const flipHorizontal = true
const weight = 0.5
// 加载模型
const net = await posenet.load(weight)
// 进行预测
const poses = await net.estimateSinglePose(
imageElement,
imageScaleFactor,
flipHorizontal,
outputStride
)
这个 JSON 是 pose 变量:
{
"score": 0.32371445304906,
"keypoints": [
{
"position": {
"y": 76.291801452637,
"x": 253.36747741699
},
"part": "nose",
"score": 0.99539834260941
},
{
"position": {
"y": 71.10383605957,
"x": 253.54365539551
},
"part": "leftEye",
"score": 0.98781454563141
}
// 后面还有: rightEye, leftEar, rightEar, leftShoulder, rightShoulder
// leftElbow, rightElbow, leftWrist, rightWrist, leftHip, rightHip,
// leftKnee, rightKnee, leftAnkle, rightAnkle...
]
}
从官方的 demo 可以看得到,用这个模型可以开发出很多有趣的项目。
体感控制 ? 的游动
? 这个项目的源代码: https://github.com/aralroca/fishFollow-posenet-tfjs
导入 Keras 模型
可以把外部模型导入 TensorFlow.js。下面是一个用 Keras 模型(h5格式)进行数字识别的程序。首先要用 tfjs_converter 对模型的格式进行转换。
pip install tensorflowjs
使用转换器:
tensorflowjs_converter --input_format keras keras/cnn.h5 src/assets
最后,把模型导入到 JS 代码中:
// 载入模型
const model = await tf.loadModel('./assets/model.json')
// 准备图片
let img = tf.fromPixels(imageData, 1)
img = img.reshape([1, 28, 28, 1])
img = tf.cast(img, 'float32')
// 进行预测
const output = model.predict(img)
只需要几行代码行就完成了。当然还可以在代码中添加更多的逻辑来实现更多功能,例如可以把数字写在 canvas 上,然后得到其图像来进行预测。
识别数字
? 这个项目的源代码: https://github.com/aralroca/MNIST_React_TensorFlowJS
为什么要用在浏览器中?
由于设备的不同,在浏览器中训练模型时效率可能很低。用 TensorFlow.js 利用 WebGL 在后台训练模型,比用 Python 版的 TensorFlow 慢 1.5 ~ 2倍。
但是在 TensorFlow.js 之前,没有能直接在浏览器中使用机器学习模型的 API,现在则可以在浏览器应用中离线训练和使用模型。而且预测速度更快,因为不需要向服务器发送请求。另一个好处是成本低,因为所有这些计算都是在客户端完成的。
总结
- 模型是表示现实世界的一种简化方式,可以使用它来进行预测。
- 可以用神经网络创建模型。
- TensorFlow.js 是创建神经网络的简便工具。
- Pycharm常用技巧
- hdu 1598 find the most comfortable road(枚举+卡鲁斯卡尔最小生成树)
- 查询IP地址归属详情
- oracle commit详解
- hdu 4315 Climbing the Hill(阶梯博弈转nim博弈)
- iftop实时网络流量监控工具的安装使用
- hdu 3908 Triple(组合计数、容斥原理)
- hdu 4034 Graph (floyd的深入理解)
- hdu 4033Regular Polygon(二分+余弦定理)
- Debian8配置SSH允许root登陆
- hdu 4405Aeroplane chess(概率DP)
- hdu 3853LOOPS (概率DP)
- cf(#div1 B. Dreamoon and Sets)(数论)
- hdu 1805Expressions(二叉树构造的后缀表达式)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 打卡群刷题总结0923——完全平方数
- 打卡群刷题总结0924——最长上升子序列
- VS2017中使用QT Chart图表
- C++核心准则T.81:不要混用继承层级和数组
- C++核心准则T.83:不要将成员函数定义为模板虚函数
- C++核心准则T.84:使用非模板核心实现提供稳定的ABI接口
- C++核心准则T.120:只在确实有需要时使用模板元编程
- C++核心准则T.121:模板元编程主要用于模仿概念
- C++核心准则T.122:使用模板在编译时计算类型
- C++核心准则T.123:使用常量表达式函数在编译时求值
- Java基础 【类之间的关系】
- MySql 学习之路-基础
- (有趣的)项目实战:Java实现计算机自动关机
- 猜生日 Java小游戏
- KDD Cup 2020多模态召回比赛亚军方案与搜索业务应用