MATLAB借助openai gym环境训练强化学习模型
时间:2022-07-22
本文章向大家介绍MATLAB借助openai gym环境训练强化学习模型,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
虽然openai的gym强化学习环境底层绘图库是pyglet,不太方便自定义,但是已有的环境还是很好用的,有了前面的python环境准备之后,只需要安装gym就可以
pip install gym
这样就可以使用这三个大类的环境了
algorithmic
toy_text
classic_control
我们感兴趣的是classic_control,涉及物理环境,不需要在MATLAB中重新建模
这里我们在gym的MountainCar环境中训练
首先建立环境
classdef MountainCarEnv < rl.env.MATLABEnvironment
%MountainCarEnv: matlab的MountainCar环境.
%% 属性设置
properties
show=true;
% pygame环境对象
p
% 初始状态
State
end
properties(Access = protected)
% 结束标记
IsDone = false
end
%% 必须的方法
methods
% 构造方法
function this = MountainCarEnv()
% 初始设置观察状态
ObservationInfo = rlNumericSpec([1 2]);
% 设置动作
ActionInfo = rlFiniteSetSpec(1:3);
% 继承系统环境
this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);
% 初始化、设置
this.State=[0 0];
this.p=py.gym.make('MountainCar-v0');
this.p.reset();
notifyEnvUpdated(this);
end
% 一次动作的效果
function [Observation,Reward,IsDone,LoggedSignals] = step(this,action)
LoggedSignals = [];
act = py.int(action-1);
% 计算reward
temp = cell(this.p.step(act));
Observation = double(temp{1,1});
IsDone = temp{1,3};
Reward=(1+Observation(1))^2;
if Observation(1)>=0.5
Reward=1000;
end
this.State = Observation;
this.IsDone = IsDone;
notifyEnvUpdated(this);
end
% 环境重置
function InitialObservation = reset(this)
this.p.reset();
InitialObservation =[0 0];
this.State = InitialObservation;
notifyEnvUpdated(this);
end
end
%% 可选函数、为了方便自行添加的
methods
% 收到绘图通知开始绘图的方法
function isDone=is_done(this)
% 设置是否需要绘图
isDone = this.IsDone;
end
end
methods (Access = protected)
% 收到绘图通知开始绘图的方法
function envUpdatedCallback(this)
% 设置是否需要绘图
if this.show
this.p.render();
end
end
end
end
接下来就是建立强化学习网络模型
%% 读取环境
ccc
env = MountainCarEnv;
% 获取可观察的状态
obsInfo = getObservationInfo(env);
% 获取可观察的状态维度
numObservations = obsInfo.Dimension(2);
% 获取可执行的动作
actInfo = getActionInfo(env);
% 获取可执行的动作维度
numActions = actInfo.Dimension(1);
rng(0)
%% 初始化agent
statePath = [
imageInputLayer([1 numObservations 1],'Normalization','none','Name','state')
fullyConnectedLayer(24,'Name','CriticStateFC1')
reluLayer('Name','CriticRelu1')
fullyConnectedLayer(24,'Name','CriticStateFC3')];
actionPath = [
imageInputLayer([numActions 1 1],'Normalization','none','Name','action')
fullyConnectedLayer(24,'Name','CriticActionFC1')];
commonPath = [
additionLayer(2,'Name','add')
reluLayer('Name','CriticCommonRelu')
fullyConnectedLayer(1,'Name','output')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC3','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
% figure
% plot(criticNetwork)
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'state'},'Action',{'action'},criticOpts);
agentOpts = rlDQNAgentOptions(...
'UseDoubleDQN',false, ...
'TargetUpdateMethod',"periodic", ...
'TargetUpdateFrequency',4, ...
'ExperienceBufferLength',10000, ...
'DiscountFactor',0.99, ...
'MiniBatchSize',128);
agent = rlDQNAgent(critic,agentOpts);
%% 设置训练参数
trainOpts = rlTrainingOptions(...
'MaxEpisodes', 500, ...
'MaxStepsPerEpisode', 200, ...
'Verbose', false, ...
'Plots','training-progress',...
'StopTrainingCriteria','AverageReward',...
'StopTrainingValue',1000);
%% 训练
% env.show=false;
trainingStats = train(agent,env,trainOpts);
%% 结果展示
env.show=true;
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);
- Java基础-20(01)总结,递归,IO流
- 一个Oracle bug的手工修复(r6笔记第59天)
- 由drop datafile导致的oracle bug(r6笔记第56天)
- Java中static关键字的作用
- Java基础-20(02)总结,递归,IO流
- Hive四种数据导入方式
- 34c3 部分Web Writeup
- 原来Oracle也不喜欢“蜀黍"(r6笔记第54天)
- Java基础19(01)总结IO流,异常try…catch,throws,File类
- 使用shell生成orabbix自动化配置脚本(r6笔记第53天)
- 现在 tensorflow 和 mxnet 很火,是否还有必要学习 scikit-learn 等框架?
- 数据的标准化与中心化以及R语言中的scale详解
- Java基础19(02)总结IO流,异常try…catch,throws,File类
- HTML5 — header
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 不要被kafka的异步模式欺骗了
- 给你总结几个ES下最容易踩的坑
- ES系列之利用filter让你的查询效率飞起来
- ES主分片和副本数据大小不一样的情况
- 关于kibana的可视化可能都在这篇文章里了
- ES分页看这篇就够了
- ES系列之原来查看文档数量有这么多姿势
- ES系列之嵌套文档和父子文档
- ES系列之一文带你避开日期类型存在的坑
- ES系列之原来ES的聚合统计不准确啊
- fastjson远程代码执行漏洞问题分析
- 数据库连接池的原理没你想得这么复杂
- 你真的会用volatile吗
- 你真的了解LinkedBlockingQueue的put,add和offer的区别吗
- 关于Java使用groupingBy分组数据乱序问题