Custom DQN Algorithm Not Learning or Converging
Hello MathWorks community,
I am currently implementing a custom Deep Q-Network (DQN) algorithm and testing it on the predefined CartPole-Discrete environment, but the algorithm does not seem to learn or converge.
My code is below. I would appreciate it if someone could take a look and explain why the algorithm might not be performing as expected. I am also open to suggestions for improvements or modifications that could help it learn.
```matlab
clc, clear, close all

% environment:
rngSeed = 1;
rng(rngSeed);
env = rlPredefinedEnv("CartPole-Discrete");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
numObservations = obsInfo.Dimension(1);

% build network:
QNetwork = [
    featureInputLayer(obsInfo.Dimension(1))
    fullyConnectedLayer(20)
    reluLayer
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(length(actInfo.Elements))];
QNetwork = dlnetwork(QNetwork);

% buffer
myBuffer.bufferSize = 1e5;
myBuffer.bufferIndex = 0;
myBuffer.currentBufferSize = 0;
myBuffer.observation = zeros(numObservations, myBuffer.bufferSize);
myBuffer.nextObservation = zeros(numObservations, myBuffer.bufferSize);
myBuffer.action = zeros(1, myBuffer.bufferSize);
myBuffer.reward = zeros(1, myBuffer.bufferSize);
myBuffer.isDone = zeros(1, myBuffer.bufferSize);

% parameters
num_episodes = 100;
max_steps = 500;
batch_size = 256;
discountFactor = 0.99;
learnRate = 1e-3;          % step size of the plain SGD update below
epsilon = 1;
epsilonMin = 0.01;
epsilonDecay = 0.005;
totalSteps = 0;
numGradientSteps = 5;
targetUpdateFrequency = 4;
target_QNetwork = QNetwork;
iteration = 0;

% Plot
monitor = trainingProgressMonitor(Metrics="Loss",Info="Episode",XLabel="Iteration");
[trainingPlot,lineReward,lineAveReward, ax] = hBuildFigure;
set(trainingPlot,Visible = "on");
episodeCumulativeRewardVector = [];
aveWindowSize = 10;

% training loop
for episode = 1:num_episodes
    observation = reset(env);
    episodeReward = zeros(max_steps,1);
    for stepCt = 1:max_steps
        totalSteps = totalSteps + 1;
        action = policy(QNetwork, observation', actInfo, epsilon);
        if totalSteps > batch_size
            epsilon = max(epsilon*(1-epsilonDecay), epsilonMin);
        end
        [nextObservation, reward, isDone] = step(env, action);
        myBuffer = storeData(myBuffer, observation, action, nextObservation, reward, isDone);
        episodeReward(stepCt) = reward;
        observation = nextObservation;
        for gradientCt = 1:numGradientSteps
            if myBuffer.currentBufferSize >= batch_size
                iteration = iteration + 1;
                [sampledObservation, sampledAction, sampledNextObservation, sampledReward, sampledIsDone] = ...
                    sampleBatch(myBuffer, batch_size);
                target_Q = zeros(1,batch_size);
                Y = zeros(1,batch_size);
                for i = 1:batch_size
                    Y(i) = target_predict(target_QNetwork, dlarray(sampledNextObservation(:,i), 'CB'), actInfo);
                    % The reward and done flag must belong to the SAMPLED transition i,
                    % not to entry i of the buffer (myBuffer.reward(i), myBuffer.isDone(i)).
                    if sampledIsDone(i)
                        target_Q(i) = sampledReward(i);
                    else
                        target_Q(i) = sampledReward(i) + discountFactor * Y(i);
                    end
                end
                lossData.batchSize = batch_size;
                lossData.actInfo = actInfo;
                lossData.actionBatch = sampledAction;
                lossData.targetValues = target_Q;
                % calculating gradient
                [loss, gradients] = dlfeval(@QNetworkLoss, QNetwork, sampledObservation, lossData.targetValues,...
                    lossData);
                % performing gradient descent (plain SGD) on every learnable,
                % instead of a hard-coded count of 6 parameters
                params = QNetwork.Learnables;
                for k = 1:size(params,1)
                    params.Value{k} = params.Value{k} - learnRate .* gradients.Value{k};
                end
                QNetwork.Learnables = params;
                recordMetrics(monitor,iteration,Loss=loss);
            end
        end
        if mod(totalSteps, targetUpdateFrequency) == 0
            target_QNetwork = QNetwork;
        end
        if isDone
            break
        end
    end
    episodeCumulativeReward = sum(episodeReward);
    episodeCumulativeRewardVector = cat(2, episodeCumulativeRewardVector,episodeCumulativeReward);
    movingAveReward = movmean(episodeCumulativeRewardVector, aveWindowSize,2);
    addpoints(lineReward,episode,episodeCumulativeReward);
    addpoints(lineAveReward,episode,movingAveReward(end));
    title(ax, "Training Progress - Episode: " + episode + ", Total Step: " + string(totalSteps) + ", epsilon:" + ...
        string(epsilon))
    drawnow;
    updateInfo(monitor,Episode=episode);
end
```
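For each sampled transition i, the target in the loop above is y_i = r_i + discountFactor * (1 - d_i) * max_a Q_target(s'_i, a), where d_i is 1 if the episode ended. The reward r_i and done flag d_i have to come from the same sampled transitions as s'_i (sampledReward, sampledIsDone), not from the first batch_size entries of myBuffer. The following is a minimal vectorized sketch of that computation in base MATLAB. All numbers are made up. It assumes nextQ is a numActions-by-batchSize numeric matrix of target-network Q-values. In the script it could presumably be obtained with something like extractdata(predict(target_QNetwork, dlarray(sampledNextObservation, "CB"))), but that call is my assumption and is not part of the original code.

```matlab
% Made-up toy minibatch: 3 transitions (columns), 2 actions (rows).
% nextQ(k, i) stands for Q_target(s'_i, a_k); in the training script it would
% come from the target network (assumption, see text above).
nextQ   = [1.0  2.0  0.5;
           3.0 -1.0  4.0];
rewards = [1 1 1];               % plays the role of sampledReward
isDone  = [false true false];    % plays the role of sampledIsDone
discountFactor = 0.99;

% Greedy value of each next state under the target network (max over rows)
maxNextQ = max(nextQ, [], 1);

% Bellman targets: bootstrap only from non-terminal transitions
targets = rewards + discountFactor .* (1 - isDone) .* maxNextQ;

fprintf('%.2f ', targets); fprintf('\n');   % prints 3.97 1.00 4.96
```

The factor (1 - isDone) zeros out the bootstrap term, so the terminal sample's target is just its reward. In the per-sample loop of the script, the if/else branch plays the same role.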
And here is the code for QNetworkLoss:
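```matlab
function [loss, gradients] = QNetworkLoss(net, X, T, lossData)
    batchSize = lossData.batchSize;
    % One column of action elements per sample. Elements(:) forces a column
    % vector, so Z is numActions-by-batchSize for a numeric action set
    % stored either as a row or as a column.
    Z = repmat(lossData.actInfo.Elements(:), 1, batchSize);
    % Row k of column i is true when sample i took action k
    actionIndicationMatrix = lossData.actionBatch(:,:) == Z;
    % Q-values for all actions: numActions-by-batchSize
    Y = forward(net, X);
    % Keep Q(s_i, a_i): one value per sample, in sample order (column vector)
    Y = Y(actionIndicationMatrix);
    T = reshape(T,size(Y));
    loss = mse(Y, T, 'DataFormat', 'CB');
    gradients = dlgradient(loss, net.Learnables);
end
```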
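The masking step in QNetworkLoss picks Q(s_i, a_i) out of the numActions-by-batchSize output. Logical indexing returns those values as a batchSize-by-1 column, and passing that column to mse with 'DataFormat','CB' places the batch along the C dimension. As I read the Deep Learning Toolbox documentation, mse returns a half squared error divided by the number of observations (the B dimension). With this layout that number is 1, so the loss becomes a sum over the batch rather than a per-sample mean. The sketch below reproduces the masking and both normalizations in base MATLAB on made-up numbers. No toolbox calls are involved, and the variable names are illustrative.

```matlab
% Made-up toy values: Q-values for 2 actions (rows) x 3 samples (columns)
Q = [0.5 1.0 2.0;
     1.5 0.0 3.0];
actionSet   = [-10 10];     % action elements, ordered like the network outputs
actionBatch = [10 -10 10];  % action taken in each sampled transition
targets     = [1.0 1.0 2.0];

% Logical mask (2x3) via implicit expansion:
% row k of column i is true when sample i took action k
mask = actionSet(:) == actionBatch;

% Column-major logical indexing returns Q(s_i, a_i) in sample order
qTaken = reshape(Q(mask), 1, []);

sqErr = (qTaken - targets).^2;
% Half squared error summed over samples (batchSize-by-1 "CB" layout)
lossSum  = sum(sqErr) / 2;
% Half squared error averaged over samples (1-by-batchSize "CB" layout)
lossMean = sum(sqErr) / (2 * numel(qTaken));

fprintf('lossSum = %.4f, lossMean = %.4f\n', lossSum, lossMean);
% prints lossSum = 0.6250, lossMean = 0.2083
```

If that reading of mse is right, the gradient with the original layout is batch_size times larger than with a per-sample mean. The effective step size of the fixed 1e-3 SGD update then scales with batch_size. Reshaping the selected Q-values and the targets to 1-by-batchSize before calling mse would give the per-sample mean instead, and the learning rate would then need to be retuned.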
I have reviewed my code thoroughly and tried various adjustments, but it still does not converge. Any guidance, tips, or insights from experienced members would be highly appreciated.