mdl = "rlQubeServoModel";
obsInfo = rlNumericSpec([7 1]);
actInfo = rlNumericSpec([1 1],UpperLimit=1,LowerLimit=-1);
agentBlk = mdl + "/RL Agent";
simEnv = rlSimulinkEnv(mdl,agentBlk,obsInfo,actInfo);
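% Optional checks (not in the original listing): fix the random seed so the
% rand-based weight initialization below is reproducible, and confirm the
% specs the agent will see from the environment.
rng(0)
getObservationInfo(simEnv)
getActionInfo(simEnv)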
numObs = prod(obsInfo.Dimension);
criticLayerSizes = [400 300];
actorLayerSizes = [400 300];
% Critic: a value-function network with two hidden layers. Weights use a
% scaled random initialization; biases start near zero. The reluLayer
% activations are an assumption, consistent with the sqrt(2/fanIn) scaling.
criticNetwork = [
    featureInputLayer(numObs)
    fullyConnectedLayer(criticLayerSizes(1), ...
        Weights=sqrt(2/numObs)* ...
            (rand(criticLayerSizes(1),numObs)-0.5), ...
        Bias=1e-3*ones(criticLayerSizes(1),1))
    reluLayer
    fullyConnectedLayer(criticLayerSizes(2), ...
        Weights=sqrt(2/criticLayerSizes(1))* ...
            (rand(criticLayerSizes(2),criticLayerSizes(1))-0.5), ...
        Bias=1e-3*ones(criticLayerSizes(2),1))
    reluLayer
    fullyConnectedLayer(1, ...
        Weights=sqrt(2/criticLayerSizes(2))* ...
            (rand(1,criticLayerSizes(2))-0.5), ...
        Bias=1e-3)
    ];
criticNetwork = dlnetwork(criticNetwork);
critic = rlValueFunction(criticNetwork,obsInfo);
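% Quick sanity check (an assumption, not in the original): the critic
% should return a scalar value estimate for a random observation.
v = getValue(critic,{rand(obsInfo.Dimension)})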
% Actor network: a common input path branching into mean and
% standard-deviation paths.
inPath = [
    featureInputLayer(prod(obsInfo.Dimension),Name="netOin")
    fullyConnectedLayer(prod(actInfo.Dimension),Name="infc")
    ];
meanPath = [
    tanhLayer(Name="tanhMean")
    fullyConnectedLayer(prod(actInfo.Dimension))
    scalingLayer(Name="scale",Scale=actInfo.UpperLimit)
    ];
sdevPath = [
    tanhLayer(Name="tanhStdv")
    fullyConnectedLayer(prod(actInfo.Dimension))
    softplusLayer(Name="splus")
    ];
% Assemble the network and connect the common path to both output paths.
net = dlnetwork;   % empty dlnetwork (R2024a+ syntax)
net = addLayers(net,inPath);
net = addLayers(net,meanPath);
net = addLayers(net,sdevPath);
net = connectLayers(net,"infc","tanhMean/in");
net = connectLayers(net,"infc","tanhStdv/in");
net = initialize(net);
% Create the Gaussian actor, mapping the named network outputs to the
% action mean and standard deviation.
actor = rlContinuousGaussianActor(net,obsInfo,actInfo, ...
    ActionMeanOutputNames="scale", ...
    ActionStandardDeviationOutputNames="splus", ...
    ObservationInputNames="netOin");
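% Quick sanity check (an assumption, not in the original): sample an action
% for a random observation. Note that a Gaussian sample can fall outside
% the [-1,1] spec limits even though the mean path is scaled into range.
a = getAction(actor,{rand(obsInfo.Dimension)})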
% Optimizer and agent options.
actorOpts = rlOptimizerOptions(LearnRate=1e-4);
criticOpts = rlOptimizerOptions(LearnRate=1e-4);
agentOpts = rlPPOAgentOptions( ...
    ExperienceHorizon=600, ...
    EntropyLossWeight=0.01, ...
    ActorOptimizerOptions=actorOpts, ...
    CriticOptimizerOptions=criticOpts, ...
    AdvantageEstimateMethod="gae");
agent = rlPPOAgent(actor,critic,agentOpts);
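% Quick sanity check (an assumption, not in the original): the untrained
% agent should already produce actions of the right size.
a = getAction(agent,{rand(obsInfo.Dimension)})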
trainOpts = rlTrainingOptions( ...
    MaxStepsPerEpisode=600, ...
    Plots="training-progress", ...
    StopTrainingCriteria="AverageReward", ...
    StopTrainingValue=430, ...
    ScoreAveragingWindowLength=100);
trainingStats = train(agent, simEnv, trainOpts);
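% Post-training validation (an assumption, not in the original): simulate
% the trained agent in the Simulink environment for one episode.
simOpts = rlSimulationOptions(MaxSteps=600);
experience = sim(simEnv,agent,simOpts);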