理解keras中的sequential模型
keras中的主要数据结构是model(模型),它提供定义完整计算图的方法。通过将图层添加到现有模型/计算图,我们可以构建出复杂的神经网络。
Keras有两种不同的构建模型的方法:
- Sequential models
- Functional API
本文将要讨论的就是keras中的Sequential模型。
理解Sequential模型
Sequential模型字面上的翻译是顺序模型,给人的第一感觉是那种简单的线性模型,但实际上Sequential模型可以构建非常复杂的神经网络,包括全连接神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、等等。这里的Sequential更准确的应该理解为堆叠,通过堆叠许多层,构建出深度神经网络。
如下代码向模型添加一个带有64个大小为3 * 3的过滤器的卷积层:
from keras.models import Sequential
from keras.layers import Dense, Activation,Conv2D,MaxPooling2D,Flatten,Dropoutmodel = Sequential()
model.add(Conv2D(64, (3, 3), activation='relu'))
Sequential模型的核心操作是添加layers(图层),以下展示如何将一些最流行的图层添加到模型中:
- 卷积层
model.add(Conv2D(64, (3, 3), activation='relu'))
- 最大池化层
model.add(MaxPooling2D(pool_size=(2, 2)))
- 全连接层
model.add(Dense(256, activation='relu'))
- dropout
model.add(Dropout(0.5))
- Flattening layer(展平层)
model.add(Flatten())
基本的Sequential模型开发流程
从我们所学习到的机器学习知识可以知道,机器学习通常包括定义模型、定义优化目标、输入数据、训练模型,最后通常还需要使用测试数据评估模型的性能。keras中的Sequential模型构建也包含这些步骤。
首先,网络的第一层是输入层,读取训练数据。为此,我们需要指定为网络提供的训练数据的大小,这里input_shape参数用于指定输入数据的形状:
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
上面的代码中,输入层是卷积层,其获取224 224 3的输入图像。
接下来就是为模型添加中间层和输出层,请参考上面一节的内容,这里不赘述。
然后,进入最重要的部分: 选择优化器(如rmsprop或adagrad)并指定损失函数(如categorical_crossentropy)来指定反向传播的计算方法。在keras中,Sequential模型的compile方法用来完成这一操作。例如,在下面的这一行代码中,我们使用’rmsprop’优化器,损失函数为’binary_crossentropy’。
model.compile(loss='binary_crossentropy',
optimizer='rmsprop')
到这一步,我们创建了模型,接下来就是调用fit函数将数据提供给模型。这里还可以指定批次大小(batch size)、迭代次数、验证数据集等等。其中批次大小、迭代次数需要根据数据规模来确定,并没有一个固定的最优值。
model.fit(x_train, y_train, batch_size=32, epochs=10,validation_data=(x_val, y_val))
最后,使用evaluate方法来评估模型:
score = model.evaluate(x_test,y_test,batch_size = 32)
以上就是在Keras中使用Sequential模型的基本构建块,相对于tensorflow,keras的代码更少,接口更加清晰,更重要的是,keras的后端框架切(比如从tensorflow切换到Theano)换后,我们的代码不需要做任何修改。
使用Sequential模型解决线性回归问题
谈到tensorflow、keras之类的框架,我们的第一反应通常是深度学习,其实大部分的问题并不需要深度学习,特别是在数据规模较小的情况下,一些机器学习算法就可以解决问题。除了构建深度神经网络,keras也可以构建一些简单的算法模型,下面以线性学习为例,说明使用keras解决线性回归问题。
线性回归中,我们根据一些数据点,试图找出最拟合各数据点的直线。为了说明这一问题,我们创建100个数据点,然后通过回归找出拟合这100个数据点的直线。
- 创建训练数据
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as nptrX = np.linspace(-1, 1, 101)
trY = 3 * trX + np.random.randn(*trX.shape) * 0.33
上面这段代码创中,TrainX的值在-1和1之间均匀分布,而TrainY的值为TrainX的三倍,但增加了一些随机扰动。
- 创建模型
model = Sequential()
model.add(Dense(input_dim=1, output_dim=1, init='uniform', activation='linear'))
代码创建一个Sequential模型,这里使用了一个采用线性激活的全连接(Dense)层。它实际上封装了输入值x乘以权重w,加上偏置(bias)b,然后进行线性激活以产生输出。
我们可以查看默认初始化的权重和偏置值:
weights = model.layers[0].get_weights()
w_init = weights[0][0][0]
b_init = weights[1][0]
print('Linear regression model is initialized with weights w: %.2f, b: %.2f' % (w_init, b_init))
- 选择优化器和损失函数
model.compile(optimizer='sgd', loss='mse')
选择简单的梯度递减优化算法,损失函数选择均方差(mean squared error, mse)。
- 训练模型
model.fit(trX, trY, nb_epoch=200, verbose=1)
训练完毕之后,我们可以再看看权重值和偏置值
weights = model.layers[0].get_weights()
w_final = weights[0][0][0]
b_final = weights[1][0]
print('Linear regression model is trained to have weight w: %.2f, b: %.2f' % (w_final, b_final))
最后的结果为
Linear regression model is trained to have weight w: 2.94, b: 0.08
可以看到,进行200次迭代之后,权重值现在非常接近3。我们可以尝试修改迭代次数,看看不同迭代次数下得到的权重值。
这段例子仅仅作为一个简单的示例,所以没有做模型评估,有兴趣的同学可以构建测试数据自己尝试一下。
总结
keras中的Sequential模型其实非常强大,而且接口简单易懂,大部分情况下,我们只需要使用Sequential模型即可满足需求。在某些特别的场合,可能需要更复杂的模型结构,这时就需要Functional API,在后面的教程中,我将探讨Functional API。
参考
- Keras tutorial: Practical guide from getting started to developing complex deep neural network
- Getting started with the Keras Sequential model
- 应用自然语言处理(NLP)解码电影
- 不引入新的数组,实现数组元素交换位置函数
- (30) 剖析StringBuilder / 计算机程序的思维逻辑
- Java初始化顺序
- ConcurrentHashMap使用示例
- (40) 剖析HashMap / 计算机程序的思维逻辑
- nginx配置https(亲测可用)
- linux中无 conio.h的解决办法
- 运用适配器模式应对项目中的变化
- 开车啦!小爬虫抓取今日头条街拍美女图
- C语言中随机数相关问题
- 算法决策兴起:人工智能时代的若干伦理问题及策略|AI观察
- Win10配置人工智能学习平台Tensorflow的正确姿势
- mysql left( right ) join使用on 与where 筛选的差异
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- keras实现图像预处理并生成一个generator的案例
- Django+RestFramework API接口及接口文档并返回json数据操作
- Yii2框架实现利用mpdf创建pdf文件功能示例
- PHP超低内存遍历目录文件和读取超大文件的方法
- PHP bin2hex()函数基础实例讲解
- Kears 使用:通过回调函数保存最佳准确率下的模型操作
- django form和field具体方法和属性说明
- 总结PHP中初始化空数组的最佳方法
- tensorflow使用CNN分析mnist手写体数字数据集
- PHP7 mongoDB扩展使用的方法分享
- 主流开源分布式图数据库 Benchmark
- PHP封装的简单连接MongoDB类示例
- 基于Tensorflow的MNIST手写数字识别分类
- Yii框架ACF(accessController)简单权限控制操作示例
- tensorflow 动态获取 BatchSzie 的大小实例