
How do I save a pretrained DQN agent and extract the weights inside its network?

Kuan Yi Li, about 20 hours ago
Answered: praguna manvi, about 15 hours ago
The following is part of the program. I want to know how to extract the weight values from the trained DQN network.
% Critic network: one Q-value output per discrete action (14 actions).
% The network ends in a fully connected layer: Q-values are unbounded
% estimates, so no softmax or classification layer is used.
DQNnet = [
    imageInputLayer([1 520 1],"Name","ImageFeatureInput","Normalization","none")
    fullyConnectedLayer(1024,"Name","fc1")
    reluLayer("Name","relu1")
    % fullyConnectedLayer(512,"Name","fc2")
    % reluLayer("Name","relu2")
    fullyConnectedLayer(14,"Name","fc3")];

ObsInfo = getObservationInfo(env);
ActInfo = getActionInfo(env);

DQNOpts = rlRepresentationOptions('LearnRate',0.0001,'GradientThreshold',1,'UseDevice','gpu');
% Multi-output Q-value critic: only the observation input is named, because
% the network has no separate action input layer.
critic = rlQValueRepresentation(DQNnet,ObsInfo,ActInfo,'Observation',{'ImageFeatureInput'},DQNOpts);

agentOpts = rlDQNAgentOptions( ...
    'UseDoubleDQN',true, ...
    'MiniBatchSize',256);
agentOpts.EpsilonGreedyExploration.Epsilon = 1;
agent = rlDQNAgent(critic,agentOpts);

%% Agent Training
% Training options
% Note: agents are written to SaveAgentDirectory only when 'SaveAgentCriteria'
% is also set (its default is "none").
trainOpts = rlTrainingOptions( ...
    'MaxEpisodes',100, ...
    'MaxStepsPerEpisode',100, ...
    'Verbose',true, ...
    'Plots','training-progress', ...
    'ScoreAveragingWindowLength',400, ...
    'StopTrainingCriteria','AverageSteps', ...
    'StopTrainingValue',1000000000, ...
    'SaveAgentDirectory',fullfile(pwd,"agents"));
% Agent training
trainingStats = train(agent,env,trainOpts);

Answers (1)

praguna manvi, about 15 hours ago
To save and load a pretrained DQN agent, you can use the “save” and “load” functions; see: https://www.mathworks.com/matlabcentral/answers/712518-how-to-save-and-use-the-pre-trained-dqn-agent-in-the-reinforcement-learning-tool-box?s_tid=prof_contriblnk
To extract the weights from the saved agent, you can call the “getLearnableParameters” function on the agent's critic (returned by “getCritic”); see: https://www.mathworks.com/matlabcentral/answers/513136-how-can-i-extract-a-trained-rl-agent-s-network-s-weights-and-biases
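A minimal sketch of the save/load round trip. So that it runs on its own, it builds a stand-in DQN agent with default networks (an assumption: 520 observation features and 14 discrete actions, shaped like the question's problem); in the question's script you would skip that part and save the `agent` variable after `train` returns, since `train` updates the agent in place. The file name `trainedDQNAgent.mat` is hypothetical. If you would rather have checkpoints written during training, as far as I know `SaveAgentDirectory` is only used once `SaveAgentCriteria` in `rlTrainingOptions` is set to something other than its default "none", and each saved MAT-file then holds the agent in a variable named `saved_agent`.

```matlab
% Stand-in setup (assumption): default-network DQN agent with 520 observation
% features and 14 discrete actions. In the question's script, skip these
% three lines and use the trained `agent` returned after train(agent,env,trainOpts).
obsInfo = rlNumericSpec([520 1]);
actInfo = rlFiniteSetSpec(1:14);
agent = rlDQNAgent(obsInfo,actInfo);

% Save the whole agent object (critic network, options, ...) to a MAT-file.
% The file name is illustrative.
save("trainedDQNAgent.mat","agent");

% Later, possibly in a new MATLAB session: load it back into a variable.
S = load("trainedDQNAgent.mat","agent");
restoredAgent = S.agent;
```

Loading into a struct (`S = load(...)`) rather than straight into the workspace makes it explicit which variable comes back, which helps when the MAT-file was written by someone else's script.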
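A minimal sketch of extracting the weights, assuming the Reinforcement Learning Toolbox functions `getCritic` and `getLearnableParameters`. As before, a stand-in agent with default networks (520 observation features, 14 discrete actions) keeps the example self-contained; with the question's code you would load the saved agent instead. Each cell entry is converted to a plain CPU array, in case it comes back as a `dlarray` or `gpuArray` (the question trains with `'UseDevice','gpu'`). The output file name `dqnWeights.mat` is hypothetical.

```matlab
% Stand-in setup (assumption): default-network DQN agent. With the question's
% script you would instead load the saved agent, e.g.
%   S = load("trainedDQNAgent.mat","agent"); agent = S.agent;
obsInfo = rlNumericSpec([520 1]);
actInfo = rlFiniteSetSpec(1:14);
agent = rlDQNAgent(obsInfo,actInfo);

% The weights live in the agent's critic (the Q-value network).
critic = getCritic(agent);

% Cell array with one entry per learnable parameter (weights, biases, ...).
params = getLearnableParameters(critic);

% Convert every entry to a plain numeric array on the CPU and print its size.
weights = cell(size(params));
for k = 1:numel(params)
    p = params{k};
    if isa(p,"dlarray")
        p = extractdata(p);   % strip the dlarray wrapper
    end
    weights{k} = gather(p);   % move off the GPU (no-op for CPU arrays)
    fprintf("Parameter %d: %s\n", k, mat2str(size(weights{k})));
end

% Store the raw weights separately from the agent object.
save("dqnWeights.mat","weights");
```

For the network in the question, I would expect four entries in layer order: the `fc1` weights (1024-by-520) and bias (1024-by-1), then the `fc3` weights (14-by-1024) and bias (14-by-1); printing the sizes as above is a quick way to confirm the ordering in your release. If you need layer names alongside the values, `getModel(critic)` returns the underlying network.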
