
Visualize policy in Grid World

GCats on 14 Feb 2022
Answered: Aman on 8 Feb 2024
Hi all!
I would like to visualize the policy followed by my Q-agent in the Grid World environment, in a way similar to the one in this post (https://www.datascienceblog.net/post/reinforcement-learning/mdps_dynamic_programming/).
Does anyone have any idea how to do this?
Here is my code:
GW = createGridWorld(9,9);
GW.CurrentState = '[1,9]';
GW.TerminalStates = '[8,3]';
GW.ObstacleStates = ["[2,4]";"[2,5]";"[2,6]";"[2,7]"; "[2,8]"; "[4,2]"; "[4,3]"; "[5,2]"; "[5,3]" ;"[5,9]"; "[5,6]"; "[5,7]"; "[5,8]"; "[8,2]"; "[7,2]"; "[7,3]"; "[7,4]"; "[8,4]"];
updateStateTranstionForObstacles(GW)
% GW.T(state2idx(GW,"[2,4]"),:,:) = 0;
% GW.T(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 1;
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
% GW.R(:,state2idx(GW,"[4,4]"),:) = -5;
% GW.R(state2idx(GW,"[4,4]"),state2idx(GW,"[4,5]"),:) = -5;
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;
env = rlMDPEnv(GW)
env.ResetFcn = @() 73; %begins at 73rd cell of the grid
rng(0)
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);
%training
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 500;
trainOpts.StopTrainingCriteria = "AverageReward";
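trainOpts.StopTrainingValue = 11; % unreachable here: rewards are -1 per step and +10 at the goal, so an episode totals at most 10 and all MaxEpisodes run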
trainOpts.ScoreAveragingWindowLength = 30;
doTraining = true;
if doTraining
    % Train the agent.
    trainingStats = train(qAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWQAgent.mat','qAgent')
end
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)
Thanks in advance!
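With the agent trained by the code above, a first step toward a policy plot is to take the greedy (highest-Q) action in every state and lay the results out on the 9x9 grid. The sketch below is a minimal illustration under stated assumptions: createGridWorld numbers states column-major (this agrees with ResetFcn = @() 73 starting at "[1,9]"), the action order is N, S, E, W (the 'Standard' moves), and Q is a random placeholder. With the trained agent, the nS-by-nA table may be retrievable with `p = getLearnableParameters(getCritic(qAgent)); Q = p{1};`, but please verify that on your release.

```matlab
% Sketch: turn an nS-by-nA Q-table into a grid of greedy actions.
% Assumptions (not from the original post):
%   - states are numbered column-major, as in createGridWorld
%   - action order is N, S, E, W
%   - Q below is a random placeholder for the trained agent's table
nRows = 9;
nCols = 9;
actionNames = ["N" "S" "E" "W"];

rng(0);
Q = rand(nRows*nCols, numel(actionNames));   % placeholder Q-table

% Greedy (argmax) action index for every state
[~, bestAction] = max(Q, [], 2);

% Column-major reshape: element (r,c) corresponds to state "[r,c]"
policyIdx = reshape(bestAction, nRows, nCols);
policyGrid = actionNames(policyIdx);         % same shape as policyIdx

% Check the index convention against the reset state in the question
[r, c] = ind2sub([nRows nCols], 73);         % state 73 -> row 1, column 9
disp(policyGrid)
```

policyGrid(r,c) then holds the letter of the greedy action in cell [r,c]. The rows of Q for obstacle and terminal cells are presumably never updated during training, so their entries can be ignored.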

Answers (1)

Aman on 8 Feb 2024
Hi GCats,
As I understand it, you want to create a custom plot of the "GridWorld" environment and are having trouble doing so.
The built-in "GridWorld" plot is made of patched rectangles and circles, so it would be difficult to modify it into the customized plot you want.
Instead, you can draw your own grid and fill the cells with the colors and symbols you need. The code below creates a grid and puts an 'X' in a red cell, the target cell [8,3] from your environment; you can extend it to get the result you want.
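n = 9;                 % grid size (9x9)
figure;
hold on;
% Target cell [8,3]: row 8, column 3 (row 1 is drawn at the top)
targetXPos = 8;        % row index
targetYPos = 3;        % column index
rectangle('Position',[targetYPos,n-targetXPos+1,1,1],'FaceColor','r','EdgeColor','k');
% Outline every cell; cell (i,j) spans x = j..j+1 and y = n-i+1..n-i+2
for i = 1:n
    for j = 1:n
        rectangle('Position', [j, n-i+1, 1, 1], 'EdgeColor', 'k');
    end
end
% Mark the target cell with an 'X' at its center
text(targetYPos+0.5,n-targetXPos+1.5,'X','HorizontalAlignment', 'center');
hold off;
axis([1 n+1 1 n+1]);
axis equal;
axis off;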
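One way to extend this grid toward the policy plot in the question is to draw, in each free cell, an arrow pointing in the direction of the greedy action. This is a minimal sketch, assuming the N, S, E, W action order and column-major state numbering of createGridWorld, and using a random placeholder Q-table instead of the trained agent's (obtaining the real table, e.g. through getCritic and getLearnableParameters, should be checked against your release). Obstacles and the target are copied from the question's code.

```matlab
% Sketch: greedy-policy arrows on a hand-drawn grid (placeholder Q-table).
% Assumed: action order N, S, E, W and column-major state numbering.
n = 9;
rng(0);
Q = rand(n*n, 4);                          % placeholder for the agent's Q-table
[~, bestAction] = max(Q, [], 2);           % greedy action index per state
policyIdx = reshape(bestAction, n, n);     % policyIdx(i,j) -> action in cell [i,j]

% Unit step per action in plot coordinates (y grows upward, row 1 on top)
dx = [0; 0; 1; -1];                        % N, S, E, W
dy = [1; -1; 0; 0];

% Obstacles and terminal state copied from the question
obstacles = [2 4; 2 5; 2 6; 2 7; 2 8; 4 2; 4 3; 5 2; 5 3; 5 9; ...
             5 6; 5 7; 5 8; 8 2; 7 2; 7 3; 7 4; 8 4];
terminal = [8 3];

blocked = false(n);
blocked(sub2ind([n n], obstacles(:,1), obstacles(:,2))) = true;
blocked(terminal(1), terminal(2)) = true;
free = ~blocked;                           % cells that get an arrow

figure;
hold on;
for i = 1:n
    for j = 1:n
        if isequal([i j], terminal)
            face = 'r';                    % target cell
        elseif blocked(i, j)
            face = [0.4 0.4 0.4];          % obstacle
        else
            face = 'none';
        end
        rectangle('Position', [j, n-i+1, 1, 1], 'FaceColor', face, 'EdgeColor', 'k');
    end
end
text(terminal(2)+0.5, n-terminal(1)+1.5, 'X', 'HorizontalAlignment', 'center');

% Arrow centers and components for every free cell
[colIdx, rowIdx] = meshgrid(1:n, 1:n);     % colIdx(i,j) = j, rowIdx(i,j) = i
cx = colIdx(free) + 0.5;
cy = n - rowIdx(free) + 1.5;
a = policyIdx(free);
u = 0.6 * dx(a);
v = 0.6 * dy(a);
quiver(cx - u/2, cy - v/2, u, v, 0, 'k');  % scale 0: no automatic rescaling
hold off;
axis([1 n+1 1 n+1]);
axis equal;
axis off;
```

Drawing all arrows with a single quiver call keeps them in one graphics object, which is easier to restyle than one annotation per cell.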
Hopefully, this will help you!

Version

R2021b
