神经网络训练中回调函数的实用教程
磐创AI分享
作者 | Andre Ye
编译 | VK
来源 | Towards Data Science
❝回调函数是神经网络训练的重要组成部分 ❞
回调操作可以在训练的各个阶段执行,可能是在epoch之间,在处理一个batch之后,甚至在满足某个条件的情况下。回调可以利用许多创造性的方法来改进训练和性能,节省计算资源,并提供有关神经网络内部发生的事情的结论。
本文将详细介绍重要回调的基本原理和代码,以及创建自定义回调的过程。
ReduceLROnPlateau是Keras中默认包含的回调。神经网络的学习率决定了梯度的比例因子,因此过高的学习率会导致优化器超过最优值,而学习率过低则会导致训练时间过长。很难找到一个静态的、效果很好的、不变的学习率。
顾名思义,“降低高原学习率”就是在损失指标停止改善或达到稳定时降低学习率。一般学习率减少2到10倍,这有助于磨练参数的最佳值。
要使用ReduceLROnPlateau,必须首先创建回调对象。有四个参数很重要:
- monitor,它用来监视指标
- factor,它是新的学习率将被降低(乘以)的因子
- persistence,回调激活之前等待的停滞epoch数
- min_lr,它可以降低到的最小学习率。这可以防止不必要和不有益的减少。
from keras.callbacks import ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])
当使用model.fit时,可以指定回调参数。注意,这可以接受一个列表,因此可以安排多个回调。
LearningRateScheduler是ReduceLROnPlateau的另一种选择,它允许用户根据epoch来安排学习率。如果你知道(可能来自以前的研究或实验)网络的学习率在从epochs 1-10时应该是x, 在epochs 10-20是应该是y,LearningRateScheduler可以帮助实现这些变化。以上epoch的数字可以任意变化。
创建学习率调度器需要一个用户定义的函数,该函数将epoch和learning rate作为参数。返回对象应该是新的学习率。
from keras.callbacks import LearningRateScheduler
def scheduler(epoch, lr): #定义回调schedule
if lr < 0.001: return lr * 1.5 #如果lr太小,增加lr
elif epoch < 5: return lr #前五个epoch,不改变lr
elif epoch < 10: return lr * tf.math.exp(-0.1) #第五到第十个epoch,减少lr
else: return lr * tf.math.exp(-0.05) #第10个epoch之后,减少量更少
callback = LearningRateScheduler(scheduler) #创建回调对象
model.fit(X_train, y_train, epochs=15, callbacks=[callback])
然后,将其转换为Keras回调后,就可以用于模型的训练。这些调度程序非常有用,允许对网络进行控制,但建议在第一次训练网络时使用ReduceLROnPlateau,因为它更具适应性。然后,可以进行可视化模型,看是否能提供关于如何构造一个适当的LR调度器的相关想法。
此外,你可以同时使用ReduceLROnPlateau和LearningRateScheduler,例如,使用调度程序硬编码一些学习速率(例如在前10个epoch不更改),同时利用自适应能力,在高原上降低学习率以提升性能。
「EarlyStopping」可以非常有助于防止在训练模型时产生额外的冗余运行。冗余运行会导致高昂的计算成本。当网络在给定的时间段内没有得到改善时,网络完成训练并停止使用计算资源。与ReduceLROnPlateau类似,「EarlyStopping」需要monitor。
from keras.callbacks import EarlyStopping
callback = EarlyStopping(monitor='loss', patience=5)
model.fit(X_train, y_train, epochs=15, callbacks=[callback])
TerminateOnNaN有助于防止在训练中产生梯度爆炸问题,因为输入NaN会导致网络的其他部分发生爆炸。如果不采用TerminateOnNaN,Keras并不阻止网络的训练。另外,nan会导致对计算能力的需求增加。为了防止这些情况发生,添加TerminateOnNaN是一个很好的安全检查。
rom keras.callbacks import TerminateOnNaN
model.fit(X_train, y_train, epochs=15, callbacks = [TerminateOnNaN()])
由于许多原因,ModelCheckpoint可以以某种频率(也许每隔10个左右的epoch)保存模型的权重,因此它非常有用。
- 如果训练模型时突然中断,则不需要完全重新训练模型。
- 如果,比如说,在第30个epoch,模型开始显示出过拟合的迹象或其他问题,比如梯度爆炸,我们可以用最近保存的权重重新加载模型(比如在第25个epoch),并调整参数以避免该问题,而无需重新进行大部分训练。
- 能够提取某个epoch的权重并将其重新加载到另一个模型中有利于迁移学习。
在下面的场景中,ModelCheckpoint用于存储具有最佳性能的模型的权重。在每个epoch,如果模型比其他记录的epoch表现更好,则其权重存储在一个文件中(覆盖前一个的权重)。在训练结束时,我们使用model.load_weights进行加载.
from keras.callbacks import ModelCheckpoint
callback = ModelCheckpoint( #创建回调
filepath='/filepath/checkpoint', #告诉回调要存储权重的filepath在哪
save_weights_only=True, #只保留权重(更有效),而不是整个模型
monitor='val_acc', #度量
mode='max', #找出使度量最大化的模型权重
save_best_only=True #只保留最佳模型的权重(更有效),而不是所有的权重
)
model.fit(X_train, y_train, epochs=15, callbacks=[callback])
model.load_weights(checkpoint_filepath) #将最佳权重装入模型中。
或者,如果需要基于频率的保存(每5个epoch保存一次),请将save_freq设置为5
编写自定义回调是Keras包含的最好的特性之一,它允许执行高度特定的操作。但是,请注意,构造它比使用默认回调要复杂得多。
我们的自定义回调将采用类的形式。类似于在PyTorch中构建神经网络,我们可以继承keras.callbacks.Callback回调,它是一个基类。
我们的类可以有许多函数,这些函数必须具有下面列出的给定名称以及这些函数将在何时运行。例如,将在每个epoch开始时运行on_epoch_begin函数。下面是Keras将从自定义回调中读取的所有函数,但是可以添加其他“helper”函数。
class CustomCallback(keras.callbacks.Callback): #继承keras的基类
def on_train_begin(self, logs=None):
#日志是某些度量的字典,例如键可以是 ['loss', 'mean_absolute_error']
def on_train_end(self, logs=None): ...
def on_epoch_begin(self, epoch, logs=None): ...
def on_epoch_end(self, epoch, logs=None): ...
def on_test_begin(self, logs=None): ...
def on_test_end(self, logs=None): ...
def on_predict_begin(self, logs=None): ...
def on_predict_end(self, logs=None): ...
def on_train_batch_begin(self, batch, logs=None): ...
def on_train_batch_end(self, batch, logs=None): ...
def on_test_batch_begin(self, batch, logs=None): ...
def on_test_batch_end(self, batch, logs=None): ...
def on_predict_batch_begin(self, batch, logs=None): ...
def on_predict_batch_end(self, batch, logs=None): ...
根据函数的不同,你可以访问不同的变量。例如,在函数on_epoch_begin中,该函数既可以访问epoch编号,也可以访问当前度量、日志的字典。如果需要其他信息,比如学习率,可以使用keras.backend.get_value.
然后,可以像对待其他回调函数一样对待你自定义的回调函数。
model.fit(X_train, y_train, epochs=15, callbacks=[CustomCallback()])
自定义回调的一些常见想法:
- 在JSON或CSV文件中记录训练结果。
- 每10个epoch就通过电子邮件发送训练结果。
- 在决定何时保存模型权重或者添加更复杂的功能。
- 训练一个简单的机器学习模型(例如使用sklearn),通过将其设置为类变量并以(x: action, y: change)的形式获取数据,来学习何时提高或降低学习率。
当在神经网络中使用回调函数时,你的控制力增强,神经网络变得更容易拟合。
原文链接:https://towardsdatascience.com/a-short-practical-guide-to-callbacks-in-neural-network-training-3a4d69568aef
- spring cloud:config-server中@RefreshScope的"陷阱"
- JavaWeb(六)之MVC与三层架构设计
- 纸上谈兵: 最短路径与贪婪
- Java魔法堂:枚举类型详解
- 机器学习笔记(5):多类逻辑回归-手动添加隐藏层
- JavaWeb(五)之JSTL标签库
- spring cloud:Edgware.RELEASE版本中zuul回退方法的变化
- js中几种实用的跨域方法原理详解
- spring cloud:Edgware.RELEASE版本hystrix超时新坑
- JS魔法堂:再识ASCII实体、符号实体和字符实体
- MyBatis之传入参数——parameterType
- 被解放的姜戈07 马不停蹄
- 机器学习笔记(6):多类逻辑回归-使用gluon
- JS魔法堂:被玩坏的innerHTML、innerText、textContent和value属性
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- C语言的原子操作
- 【Linux】tmux命令使用教程
- 【Android 音视频开发打怪升级:FFmpeg音视频编解码篇】七、Android FFmpeg 视频编码
- 2020已经过去五分之四了,你确定还不来了解一下JS的rAF?
- c#任意进制转换
- 缺失值的处理方法(基于sklearn)
- 中国核酸数据库GSA数据提交指南
- 基于sklearn建立机器学习的pipeline
- Sentinle集群流控【源码笔记】
- ggplot坐标轴方向设置
- Scaling与Normalization的区别
- Android 音乐APP(一)扫描本地音乐
- Semaphore:如何快速实现一个限流器?
- 延迟初始化Spring Bean:延迟初始化的Bean会影响依赖注入吗?
- Android 音乐APP(二)启动白屏优化、定位当前播放歌曲