How do I get DQN to output the policy I want?
I'm solving a problem with DQN. The environment currently has 10 possible moves (actions), 8 states, and 20 rounds (steps) per run. I want to keep the number of problem variables to a minimum. The optima
Answers (1)
praguna manvi
on 17 Jul 2024
Hi,
Here is sample code showing how you could train a DQN agent for the setup above. To keep the example simple, I assume a random step function and a random reset function:
% Define your environment
numStates = 8;            % number of states
numActions = 10;          % number of available moves
numStepsPerEpisode = 20;  % 20 rounds per run
% Define the observation and action spaces
obsInfo = rlNumericSpec([numStates 1]);
actInfo = rlFiniteSetSpec(1:numActions);
% Create the custom environment from the step and reset functions below
env = rlFunctionEnv(obsInfo, actInfo, @myStepFunction, @myResetFunction);
% Define the critic network: it outputs one Q-value per action
statePath = [
    featureInputLayer(numStates, 'Normalization', 'none', 'Name', 'state')
    fullyConnectedLayer(24,'Name','fc1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(24,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(numActions,'Name','fc3')];
criticNetwork = dlnetwork(statePath);
% Vector Q-value critic (rlQValueRepresentation and rlRepresentationOptions
% are not recommended since R2022a)
critic = rlVectorQValueFunction(criticNetwork,obsInfo,actInfo);
% Define the DQN agent; the critic's optimizer settings go in the agent options
agentOpts = rlDQNAgentOptions(...
    'SampleTime',1,...
    'DiscountFactor',0.99,...
    'ExperienceBufferLength',10000,...
    'MiniBatchSize',256,...
    'CriticOptimizerOptions',rlOptimizerOptions('LearnRate',1e-3,'GradientThreshold',1));
agent = rlDQNAgent(critic,agentOpts);
% Train the agent. Each episode is one 20-step run. DQN learns from
% 256-experience mini-batches drawn from the buffer, so allow far more
% than 20 episodes for the agent to gather enough experience.
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',500,...
    'MaxStepsPerEpisode',numStepsPerEpisode,...
    'Verbose',false,...
    'Plots','training-progress');
trainingStats = train(agent,env,trainOpts);
% Define the step function
function [nextObs, reward, isDone, loggedSignals] = myStepFunction(action, loggedSignals)
    % Placeholder logic: replace with your own state transition and reward.
    % With a random reward, the agent has no meaningful policy to learn.
    nextObs = randi([1, 8], [8, 1]);
    reward = randi([-1, 1]);
    % Episodes end after MaxStepsPerEpisode steps
    isDone = false;
end
% Define the reset function
function [initialObs, loggedSignals] = myResetFunction()
    % Placeholder logic: a random initial state
    initialObs = randi([1, 8], [8, 1]);
    loggedSignals = [];
end
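The random step function makes the reward pure noise, so there is no particular policy for DQN to converge to. To get the policy you want, the reward has to favor the moves you want in each state. Below is a minimal sketch under stated assumptions: each of the 8 states is presented to the network as a one-hot vector, the desired move for each state is known and stored in a hypothetical `targetAction` table, and transitions simply cycle through the states. All function names here (`desiredPolicyStep`, `desiredPolicyReset`, `oneHotState`, `greedyPolicyTable`) are hypothetical. `greedyPolicyTable` assumes R2022a or later, where `getAction` takes and returns cell arrays and a DQN agent acts greedily unless its `UseExplorationPolicy` property is true.

```matlab
% HYPOTHETICAL step function: rewards the move you want in each state.
function [nextObs, reward, isDone, loggedSignals] = desiredPolicyStep(action, loggedSignals)
    % HYPOTHETICAL table: targetAction(k) is the move wanted in state k
    targetAction = [3 7 1 10 5 2 8 4];
    numStates = numel(targetAction);
    state = loggedSignals.State;
    % +1 for the desired move, -1 for any other move
    if action == targetAction(state)
        reward = 1;
    else
        reward = -1;
    end
    % HYPOTHETICAL transition: cycle to the next state regardless of action
    nextState = mod(state, numStates) + 1;
    loggedSignals.State = nextState;
    nextObs = oneHotState(nextState, numStates);
    % Episodes are ended by MaxStepsPerEpisode in rlTrainingOptions
    isDone = false;
end

% HYPOTHETICAL reset function: start each episode in a random state
function [initialObs, loggedSignals] = desiredPolicyReset()
    numStates = 8;
    loggedSignals.State = randi(numStates);
    initialObs = oneHotState(loggedSignals.State, numStates);
end

% Encode discrete state k as the k-th column of the identity matrix
function obs = oneHotState(state, numStates)
    obs = zeros(numStates, 1);
    obs(state) = 1;
end

% Query the trained agent once per state and collect its chosen move
function policy = greedyPolicyTable(agent, numStates)
    policy = zeros(numStates, 1);
    for k = 1:numStates
        act = getAction(agent, {oneHotState(k, numStates)});
        policy(k) = act{1};
    end
end
```

To try it, pass `@desiredPolicyStep` and `@desiredPolicyReset` to `rlFunctionEnv` in place of the random functions and, after `train` returns, compare `greedyPolicyTable(agent, 8)` with `targetAction`. Because the next state here does not depend on the action, the optimal policy is simply the rewarded move in each state. In a real problem where moves change future states, the discounted return, not just the immediate reward, decides which policy is optimal.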
For a detailed example, please refer to this documentation on training a custom PG agent:

