Reinforcement Learning Grid World multi-figures
Reinforcement Learning on 14 Feb 2021
Commented: Reinforcement Learning on 16 Feb 2021
Hello,
I made my own version of Grid World with my own obstacles (see the code below).
My question is: how can I simulate the trained agent in the environment in multiple figures?
I am using:
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(agent,env)
That gives me a single figure with one run. I tried using:
for i = 1:3
    figure(i)
    plot(env)
    env.Model.Viewer.ShowTrace = true;
    env.Model.Viewer.clearTrace;
    sim(agent,env)
end
But it didn't work as planned.
Here is my code. For some reason I am getting spikes in the reward plot, even though training has already converged. I tried tuning variables like LearnRate, Epsilon, and DiscountFactor, but this is the best result I can get:
GitterWelt = createGridWorld(7,7);
GitterWelt.CurrentState = '[1,1]';
GitterWelt.ObstacleStates = ["[5,3]";"[5,4]";"[5,5]";"[4,5]";"[3,5]"];
GitterWelt.TerminalStates = '[6,6]';
updateStateTranstionForObstacles(GitterWelt)
nS = numel(GitterWelt.States);
nA = numel(GitterWelt.Actions);
GitterWelt.R = -1*ones(nS,nS,nA);
GitterWelt.R(:,state2idx(GitterWelt,GitterWelt.TerminalStates),:) = 10;
env = rlMDPEnv(GitterWelt);
Obs_Info = getObservationInfo(env);
Act_Info = getActionInfo(env);
qTable = rlTable(Obs_Info, Act_Info);
qRep = rlQValueRepresentation(qTable, Obs_Info, Act_Info);
%% All trivial until here
qRep.Options.LearnRate = 0.2; % Alpha: this was 1 in the example, but that doesn't make sense
Ag_Opts = rlQAgentOptions;
Ag_Opts.DiscountFactor = 0.9; % Gamma
Ag_Opts.EpsilonGreedyExploration.Epsilon = 0.02;
agent = rlQAgent(qRep,Ag_Opts);
Train_Opts = rlTrainingOptions;
Train_Opts.MaxEpisodes = 1000;
Train_Opts.MaxStepsPerEpisode = 40;
Train_Opts.StopTrainingCriteria = "AverageReward";
Train_Opts.StopTrainingValue = 10;
Train_Opts.Verbose = 1;
Train_Opts.ScoreAveragingWindowLength = 30;
Train_Opts.Plots = "training-progress";
Train_Info = train(agent,env,Train_Opts);
Accepted Answer
Emmanouil Tzorakoleftherakis on 16 Feb 2021
Hello,
I wouldn't worry about the spikes as long as the average reward has converged. It could just be the agent exploring.
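If you do want to damp those spikes, one thing you could try (the values here are just illustrative) is letting epsilon decay toward a small floor instead of keeping it fixed, via the EpsilonGreedyExploration options:
Ag_Opts = rlQAgentOptions;
Ag_Opts.DiscountFactor = 0.9;
% Start with more exploration and decay it each step toward a floor,
% so a converged agent rarely takes random (spike-causing) actions.
Ag_Opts.EpsilonGreedyExploration.Epsilon = 0.9;        % initial epsilon
Ag_Opts.EpsilonGreedyExploration.EpsilonDecay = 0.005; % per-step decay
Ag_Opts.EpsilonGreedyExploration.EpsilonMin = 0.01;    % exploration floor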
For your plotting question, the plot function for the gridworld environments has been set up with a listener callback so that it can be updated on the fly every time you call step. This means that you can only have one plot per grid world environment.
A quick workaround would be to create separate environment objects for the same grid world you created and call plot for each one. So:
function env = MyGridWorld
    GitterWelt = createGridWorld(7,7);
    GitterWelt.CurrentState = '[1,1]';
    GitterWelt.ObstacleStates = ["[5,3]";"[5,4]";"[5,5]";"[4,5]";"[3,5]"];
    GitterWelt.TerminalStates = '[6,6]';
    updateStateTranstionForObstacles(GitterWelt)
    nS = numel(GitterWelt.States);
    nA = numel(GitterWelt.Actions);
    GitterWelt.R = -1*ones(nS,nS,nA);
    GitterWelt.R(:,state2idx(GitterWelt,GitterWelt.TerminalStates),:) = 10;
    env = rlMDPEnv(GitterWelt);
end
and then
env1 = MyGridWorld;
env2 = MyGridWorld;
plot(env1)
plot(env2)
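From there, a sketch of the full multi-figure simulation loop (assuming agent is the trained agent from your script):
nRuns = 3;
envs = cell(nRuns,1);
for i = 1:nRuns
    envs{i} = MyGridWorld;                  % fresh environment per figure
    plot(envs{i})                           % each copy gets its own viewer
    envs{i}.Model.Viewer.ShowTrace = true;
    envs{i}.Model.Viewer.clearTrace;
    sim(agent,envs{i});                     % simulate the trained agent
end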
Hope that helps