DQN三大改进(一)-Double DQN
Double-DQN原文:https://arxiv.org/pdf/1509.06461v3.pdf 代码地址:https://github.com/princewen/tensorflow_practice/tree/master/Double-DQN-demo
1、背景
这篇文章我们会默认大家已经了解了DQN的相关知识,如果大家对于DQN还不是很了解,可以参考文章https://www.jianshu.com/p/10930c371cac。
我们简单回顾一下DQN的过程(这里是2015版的DQN):
DQN中有两个关键的技术,叫做经验回放和双网络结构。
DQN中的损失函数定义为:
其中,yi也被我们称为q-target值,而后面的Q(s,a)我们称为q-eval值,我们希望q-target和q-eval值越接近越好。
q-target如何计算呢?根据下面的公式:
上面的两个公式分别截取自两篇不同的文章,所以可能有些出入。我们之前说到过,我们有经验池存储的历史经验,经验池中每一条的结构是(s,a,r,s'),我们的q-target值根据该轮的奖励r以及将s'输入到target-net网络中得到的Q(s',a')的最大值决定。
我们进一步展开我们的q-target计算公式:
也就是说,我们根据状态s'选择动作a'的过程,以及估计Q(s',a')使用的是同一张Q值表,或者说使用的同一个网络参数,这可能导致选择过高的估计值,从而导致过于乐观的值估计。为了避免这种情况的出现,我们可以对选择和衡量进行解耦,从而就有了双Q学习,在Double DQN中,q-target的计算基于如下的公式:
我们根据一张Q表或者网络参数来选择我们的动作a',再用另一张Q值表活着网络参数来衡量Q(s',a')的值。
2、代码实现
本文的代码还是根据莫烦大神的代码,它的github地址为:https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow
这里我们想要实现的效果类似于寻宝。
其中,红色的方块代表寻宝人,黑色的方块代表陷阱,黄色的方块代表宝藏,我们的目标就是让寻宝人找到最终的宝藏。
这里,我们的状态可以用横纵坐标表示,而动作有上下左右四个动作。使用tkinter来做这样一个动画效果。宝藏的奖励是1,陷阱的奖励是-1,而其他时候的奖励都为0。
接下来,我们重点看一下我们Double-DQN相关的代码。
定义输入
# ------------------------input---------------------------self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s')self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q-target')self.s_ = tf.placeholder(tf.float32,[None,self.n_features],name='s_')
定义双网络结构
这里我们的双网络结构都简单的采用简单的全链接神经网络,包含一个隐藏层。这里我们得到的输出是一个向量,表示该状态才取每个动作可以获得的Q值:
def build_layers(s,c_name,n_l1,w_initializer,b_initializer):
with tf.variable_scope('l1'):
w1 = tf.get_variable(name='w1',shape=[self.n_features,n_l1],initializer=w_initializer,collections=c_name)
b1 = tf.get_variable(name='b1',shape=[1,n_l1],initializer=b_initializer,collections=c_name)
l1 = tf.nn.relu(tf.matmul(s,w1)+b1)
with tf.variable_scope('l2'):
w2 = tf.get_variable(name='w2',shape=[n_l1,self.n_actions],initializer=w_initializer,collections=c_name)
b2 = tf.get_variable(name='b2',shape=[1,self.n_actions],initializer=b_initializer,collections=c_name)
out = tf.matmul(l1,w2) + b2 return out
接下来,我们定义两个网络:
# ------------------ build evaluate_net ------------------with tf.variable_scope('eval_net'):
c_names = ['eval_net_params',tf.GraphKeys.GLOBAL_VARIABLES]
n_l1 = 20
w_initializer = tf.random_normal_initializer(0,0.3)
b_initializer =tf.constant_initializer(0.1) self.q_eval = build_layers(self.s,c_names,n_l1,w_initializer,b_initializer)# ------------------ build target_net ------------------with tf.variable_scope('target_net'):
c_names = ['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES] self.q_next = build_layers(self.s_, c_names, n_l1, w_initializer, b_initializer)
定义损失和优化器 接下来,我们定义我们的损失,和DQN一样,我们使用的是平方损失:
with tf.variable_scope('loss'): self.loss = tf.reduce_mean(tf.squared_difference(self.q_target,self.q_eval))
with tf.variable_scope('train'): self.train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
定义我们的经验池 我们使用一个函数定义我们的经验池,经验池每一行的长度为 状态feature * 2 + 2。
def store_transition(self,s,a,r,s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0
transition = np.hstack((s, [a, r], s_))
index = self.memory_counter % self.memory_size self.memory[index, :] = transition self.memory_counter += 1
选择action 我们仍然使用的是e-greedy的选择动作策略,即以e的概率选择随机动作,以1-e的概率通过贪心算法选择能得到最多奖励的动作a。
def choose_action(self,observation):
observation = observation[np.newaxis,:]
actions_value = self.sess.run(self.q_eval,feed_dict={self.s:observation})
action = np.argmax(actions_value) if np.random.random() > self.epsilon:
action = np.random.randint(0,self.n_actions) return action
选择数据batch 我们从经验池中选择我们训练要使用的数据。
if self.memory_counter > self.memory_size:
sample_index = np.random.choice(self.memory_size, size=self.batch_size)else:
sample_index = np.random.choice(self.memory_counter, size=self.batch_size)
batch_memory = self.memory[sample_index,:]
更新target-net 这里,每个一定的步数,我们就更新target-net中的参数:
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]if self.learn_step_counter % self.replace_target_iter == 0: self.sess.run(self.replace_target_op) print('ntarget_params_replacedn')
更新网络参数 根据Double DQN的做法,我们需要用两个网络的来计算我们的q-target值,同时通过最小化损失来更新网络参数。这里的做法是,根据eval-net的值来选择动作,然后根据target-net的值来计算Q值。
q_next,q_eval4next = self.sess.run([self.q_next, self.q_eval],
feed_dict={self.s_: batch_memory[:, -self.n_features:], self.s: batch_memory[:, -self.n_features:]})
q_next是根据经验池中下一时刻状态输入到target-net计算得到的q值,而q_eval4next是根据经验池中下一时刻状态s'输入到eval-net计算得到的q值,这个q值主要用来选择动作。
下面的动作用来得到我们batch中的实际动作和奖励
batch_index = np.arange(self.batch_size, dtype=np.int32)
eval_act_index = batch_memory[:, self.n_features].astype(int)
reward = batch_memory[:, self.n_features + 1]
接下来,我们就要来选择动作并计算该动作的q值了,如果是double dqn的话,我们是根据刚刚计算的q_eval4next来选择动作,然后根据q_next来得到q值的。而原始的dqn直接通过最大的q_next来得到q值:
if self.double_q:
max_act4next = np.argmax(q_eval4next, axis=1) # the action that brings the highest value is evaluated by q_eval
selected_q_next = q_next[batch_index, max_act4next] # Double DQN, select q_next depending on above actionselse:
selected_q_next = np.max(q_next, axis=1) # the natural DQN
那么我们的q-target值就可以计算得到了:
q_target = q_eval.copy()
q_target[batch_index, eval_act_index] = reward + self.gamma * selected_q_next
有了q-target值,我们就可以结合eval-net计算的q-eval值来更新网络参数了:
_, self.cost = self.sess.run([self.train_op, self.loss],
feed_dict={self.s: batch_memory[:, :self.n_features], self.q_target: q_target})self.cost_his.append(self.cost)self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_maxself.learn_step_counter += 1
3、参考文献
1、Double-DQN原文:https://arxiv.org/pdf/1509.06461v3.pdf 2、解析 DeepMind 采用双 Q 学习 (Double Q-Learning) 深度强化学习技术:https://www.jianshu.com/p/193ca0106aa5
- 经典Java面试题收集(二)
- 关于表联结方法(一)(r3笔记第57天)
- Go 语言读写 Excel 文档
- 关于索引的使用模式(r3笔记56天)
- 关于oracle中的半连接(r3笔记55天)
- 关于正则表达式第三篇(r3笔记第52天)
- 关于正则表达式第四篇(r3笔记第53天)
- 外部表简单总结(r3笔记第51天)
- 通过shell脚本监控sql执行频率(r3笔记第50天)
- 和Null有关的函数(r3笔记第48天)
- 关于查询转换的一些简单分析(二) (r3笔记第68天)
- 跨网络拷贝文件的简单实践(r3笔记第67天)
- 关于enq: TX - allocate ITL entry的问题分析(r3笔记第66天)
- Tensorflow学习:使用Tensorflow搭建深层网络分类器
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Android集成zxing扫码框架功能
- Android 实现抖音小游戏潜艇大挑战的思路详解
- 修改Android Studio 的 Logcat 缓冲区大小操作
- Android自定义View验证码输入框
- PHP生成随机字符串实例代码(字母+数字)
- Laravel框架中缓存的使用方法分析
- android实现搜索功能并将搜索结果保存到SQLite中(实例代码)
- Android实现全局右滑返回
- Android实现打地鼠小游戏
- android实现手机传感器调用
- Android实现接近传感器
- PHP 模拟登陆功能实例详解
- PHP判断一个变量是否为整数、正整数的方法示例
- android实现打地鼠游戏
- Yii框架连表查询操作示例