如何在keras中添加自己的优化器(如adam等)
本文主要讨论windows下基于tensorflow的keras
1、找到tensorflow的根目录
如果安装时使用anaconda且使用默认安装路径,则在 C:ProgramDataAnaconda3envstensorflow-gpuLibsite-packagestensorflow处可以找到(此处为GPU版本),cpu版本可在C:ProgramDataAnaconda3Libsite-packagestensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。
2、找到keras在tensorflow下的根目录
需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:ProgramDataAnaconda3envstensorflow-gpuLibsite-packagestensorflowpythonkeras
3、找到keras目录下的optimizers.py文件并添加自己的优化器
找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类
以本文来说,我在第718行添加如下代码
@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):
def __init__(self,
lr=0.002,
beta_1=0.9,
beta_2=0.999,
epsilon=None,
schedule_decay=0.004,
**kwargs):
super(Adamsss, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.m_schedule = K.variable(1., name='m_schedule')
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.schedule_decay = schedule_decay
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [state_ops.assign_add(self.iterations, 1)]
t = math_ops.cast(self.iterations, K.floatx()) + 1
# Due to the recommendations in [2], i.e. warming momentum schedule
momentum_cache_t = self.beta_1 * (
1. - 0.5 *
(math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
momentum_cache_t_1 = self.beta_1 * (
1. - 0.5 *
(math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
m_schedule_new = self.m_schedule * momentum_cache_t
m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
self.updates.append((self.m_schedule, m_schedule_new))
shapes = [K.int_shape(p) for p in params]
ms = [K.zeros(shape) for shape in shapes]
vs = [K.zeros(shape) for shape in shapes]
self.weights = [self.iterations] + ms + vs
for p, g, m, v in zip(params, grads, ms, vs):
# the following equations given in [1]
g_prime = g / (1. - m_schedule_new)
m_t = self.beta_1 * m + (1. - self.beta_1) * g
m_t_prime = m_t / (1. - m_schedule_next)
v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
m_t_bar = (
1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
self.updates.append(state_ops.assign(m, m_t))
self.updates.append(state_ops.assign(v, v_t))
p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
new_p = p_t
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
config = {
'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'epsilon': self.epsilon,
'schedule_decay': self.schedule_decay
}
base_config = super(Adamsss, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
然后修改之后的优化器调用类添加我自己的优化器adamss
需要修改的有(下面的两处修改依旧在optimizers.py内)
# Aliases.
sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam
以及
def deserialize(config, custom_objects=None):
"""Inverse of the `serialize` function.
Arguments:
config: Optimizer configuration dictionary.
custom_objects: Optional dictionary mapping
names (strings) to custom objects
(classes and functions)
to be considered during deserialization.
Returns:
A Keras Optimizer instance.
"""
if tf2.enabled():
all_classes = {
'adadelta': adadelta_v2.Adadelta,
'adagrad': adagrad_v2.Adagrad,
'adam': adam_v2.Adam,
'adamsss': adamsss_v2.Adamsss,
'adamax': adamax_v2.Adamax,
'nadam': nadam_v2.Nadam,
'rmsprop': rmsprop_v2.RMSprop,
'sgd': gradient_descent_v2.SGD
}
else:
all_classes = {
'adadelta': Adadelta,
'adagrad': Adagrad,
'adam': Adam,
'adamax': Adamax,
'nadam': Nadam,
'adamsss': Adamsss,
'rmsprop': RMSprop,
'sgd': SGD,
'tfoptimizer': TFOptimizer
}
这里我们并没有v2版本,所以if后面的部分不改也可以。
4、调用我们的优化器对模型进行设置
model.compile(loss = ‘crossentropy’, optimizer = ‘adamss’, metrics=[‘accuracy’])
5、训练模型
train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)
补充知识:keras设置学习率–优化器的用法
优化器的用法
优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:
from keras import optimizers
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)
你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。
# 传入优化器名称: 默认参数将被采用 model.compile(loss=’mean_squared_error’, optimizer=’sgd’)
以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考。
- win7下恢复“经典任务栏”/“快速启动栏”,关闭“窗口自动最大化”
- WCF和ASP.NET Web API 接口执行时间监控
- 额的神啊:AS3中Button被disable了,也会触发Click事件!
- [原创]CI持续集成系统环境---部署gerrit环境完整记录
- CentOS设置Mono环境变量
- 分布式监控系统Zabbix-3.0.3-完整安装记录(2)-添加mysql监控
- 从APM角度上看:NoSQL和关系数据库并无不同
- 事故记录-过多进程致使CPU卡死
- Flash/Flex学习笔记(54):迷你滚动条ScrollBar
- linux下正向代理/反向代理/透明代理使用说明
- 万达网科年底集体裁员?公司回应仅是业务调整
- 两个四字母域名均以五位数被交易
- Flash/Flex学习笔记(15):FMS 3.5之远程共享对象(Remote Shared Object)
- Android Fragment完全解析
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 线上问题分析之java dump文件生成
- python基础知识
- AtCoder Beginner Contest 177 A ~ E
- 2017 年ICPC 中国大陆区域赛铜牌题解
- 搜索(DFS BFS)专题练习
- AtCoder Beginner Contest 171
- AtCoder Beginner Contest 173 A ~ F(已经补完)
- AtCoder Beginner Contest 174 A ~ E
- AtCoder Beginner Contest 170
- 【队伍训练3】Codeforces Round #661 (Div. 3)
- 购物
- 指纹锁(自定义下比较重载下set的圆括号比较)
- 糖糖别胡说,我真的不是签到题目
- 分数线划定
- 小C的记事本