
Model-Based Reinforcement Learning Using Custom Training Loop

This example shows how to define a custom training loop for a model-based reinforcement learning (MBRL) algorithm. You can use this workflow to train an MBRL policy with your custom training algorithm using policy and value function representations from Reinforcement Learning Toolbox™ software.

In this example, you use transition models to generate additional experiences while training a custom DQN [2] agent in a cart-pole environment. The algorithm used in this example is based on the model-based policy optimization (MBPO) algorithm [1]. The original MBPO algorithm trains an ensemble of stochastic models and a soft actor-critic (SAC) agent for tasks with continuous actions. In contrast, this example trains an ensemble of deterministic models and a DQN agent for a task with discrete actions. The following figure summarizes the algorithm used in this example.

The agent generates real experiences by interacting with the environment. These experiences are used to train a set of transition models, which are used to generate additional experiences. The training algorithm then uses both the real and generated experiences to update the agent policy.
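This interaction can be sketched as pseudocode. The helper names below are illustrative placeholders for the steps just described, not Reinforcement Learning Toolbox functions.

```matlab
% Schematic of the MBPO-style loop used in this example.
% All helper names here are placeholders for illustration.
for episode = 1:numEpisodes
    models = trainModels(models,realBuffer);        % fit the model ensemble to real data
    modelBuffer = rolloutModels(models,realBuffer); % generate short model rollouts
    for step = 1:maxStepsPerEpisode
        realBuffer = interact(env,policy,realBuffer); % collect a real experience
        batch = sampleMixed(realBuffer,modelBuffer);  % mix real and generated samples
        policy = updateDQN(policy,batch);             % update the agent on the batch
    end
end
```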

Create Environment

For this example, a reinforcement learning policy is trained in a discrete cart-pole environment. The objective in this environment is to balance the pole by applying forces (actions) on the cart. Create the environment using the rlPredefinedEnv function. Fix the random generator seed for reproducibility. For more information on this environment, see Load Predefined Control System Environments.

rngSeed = 1;
rng(rngSeed);
env = rlPredefinedEnv('CartPole-Discrete');

Extract the observation and action specifications from the environment.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Obtain the number of observations (numObservations) and actions (numActions).

numObservations = obsInfo.Dimension(1);
numActions = numel(actInfo.Elements); % number of discrete actions, -10 or 10
numContinuousActions = 1; % force

Critic Construction

DQN is a value-based reinforcement learning algorithm that estimates the discounted cumulative reward using a critic. In this example, the critic network consists of fullyConnectedLayer and reluLayer layers.

qNetwork = [
    featureInputLayer(numObservations,'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(numActions,'Name','output')];
qNetwork = dlnetwork(qNetwork);

Create the critic representation using the specified neural network and options. For more information, see rlQValueRepresentation.

criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
critic = rlQValueRepresentation(qNetwork,obsInfo,actInfo,'Observation',{'state'},criticOpts);

Create Transition Models

Model-based reinforcement learning uses transition models of the environment. The model usually consists of a transition function, a reward function, and a terminal state function.

  • The transition function predicts the next observation given the current observation and the action.

  • The reward function predicts the reward given the current observation, the action, and the next observation.

  • The terminal state function predicts the terminal state given the observation.

As shown in the following figure, this example uses three transition functions as an ensemble of transition models to generate samples without interacting with the environment. The true reward function and the true terminal state function are given in this example.
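For this cart-pole task, the true reward and terminal state functions can be sketched as follows. The function names, reward values (+1 per step, -5 on failure), and termination thresholds are assumptions based on the predefined environment's documented behavior, not the exact helper code used by this example.

```matlab
% Illustrative true reward and terminal-state functions for the
% cart-pole environment (assumed values, for illustration only).
function reward = cartPoleReward(~,~,nextObservation)
    % +1 while the pole stays up; -5 when the next state is terminal.
    if cartPoleIsDone(nextObservation)
        reward = -5;
    else
        reward = 1;
    end
end

function isDone = cartPoleIsDone(observation)
    % Terminate when the cart leaves the track or the pole falls over.
    x = observation(1);      % cart position (m)
    theta = observation(3);  % pole angle (rad)
    isDone = abs(x) > 2.4 || abs(theta) > 12*pi/180;
end
```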

Define three neural networks for the transition models. Each neural network predicts the difference between the next observation and the current observation.

numModels = 3;
transitionNetwork1 = createTransitionNetwork(numObservations, numContinuousActions);
transitionNetwork2 = createTransitionNetwork(numObservations, numContinuousActions);
transitionNetwork3 = createTransitionNetwork(numObservations, numContinuousActions);
transitionNetworkVector = [transitionNetwork1, transitionNetwork2, transitionNetwork3];
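The createTransitionNetwork helper builds a network that takes the current observation and action and outputs the predicted change in observation. A minimal sketch of such a network is shown below; the layer sizes and names are assumptions, not the helper's actual code.

```matlab
function net = createTransitionNetwork(numObservations,numActions)
% Sketch of a deterministic transition model: two input branches
% (observation and action) concatenated into a common trunk that
% predicts the observation delta. Assumed implementation.
    obsPath = featureInputLayer(numObservations,'Normalization','none','Name','obsIn');
    actPath = featureInputLayer(numActions,'Normalization','none','Name','actIn');
    commonPath = [
        concatenationLayer(1,2,'Name','concat')
        fullyConnectedLayer(64,'Name','fc1')
        reluLayer('Name','relu1')
        fullyConnectedLayer(64,'Name','fc2')
        reluLayer('Name','relu2')
        fullyConnectedLayer(numObservations,'Name','deltaObs')];
    lgraph = layerGraph();
    lgraph = addLayers(lgraph,obsPath);
    lgraph = addLayers(lgraph,actPath);
    lgraph = addLayers(lgraph,commonPath);
    lgraph = connectLayers(lgraph,'obsIn','concat/in1');
    lgraph = connectLayers(lgraph,'actIn','concat/in2');
    net = dlnetwork(lgraph);
end
```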

Create Experience Buffers

Create an experience buffer for storing agent experiences (observation, action, next observation, reward, and isDone).

myBuffer.bufferSize = 1e5;
myBuffer.bufferIndex = 0;
myBuffer.currentBufferLength = 0;
myBuffer.observation = zeros(numObservations,myBuffer.bufferSize);
myBuffer.nextObservation = zeros(numObservations,myBuffer.bufferSize);
myBuffer.action = zeros(numContinuousActions,1,myBuffer.bufferSize);
myBuffer.reward = zeros(1,myBuffer.bufferSize);
myBuffer.isDone = zeros(1,myBuffer.bufferSize);
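The buffer structure above is written to by a storeExperience helper later in the example. A minimal sketch of a circular-buffer write for this structure, assuming the oldest samples are overwritten once the buffer is full (the actual helper may differ):

```matlab
function buffer = storeExperience(buffer,observation,action,nextObservation,reward,isDone)
% Sketch of a circular-buffer write for the real experience buffer
% (assumed implementation, for illustration only).
    % Advance the write index, wrapping around at the buffer size.
    buffer.bufferIndex = mod(buffer.bufferIndex,buffer.bufferSize) + 1;
    idx = buffer.bufferIndex;
    buffer.observation(:,idx) = observation;
    buffer.action(:,:,idx) = action;
    buffer.nextObservation(:,idx) = nextObservation;
    buffer.reward(idx) = reward;
    buffer.isDone(idx) = isDone;
    % Track how many valid samples the buffer currently holds.
    buffer.currentBufferLength = min(buffer.currentBufferLength+1,buffer.bufferSize);
end
```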

Create a model experience buffer for storing the experiences generated by the models.

myModelBuffer.bufferSize = 1e5;
myModelBuffer.bufferIndex = 0;
myModelBuffer.currentBufferLength = 0;
myModelBuffer.observation = zeros(numObservations,myModelBuffer.bufferSize);
myModelBuffer.nextObservation = zeros(numObservations,myModelBuffer.bufferSize);
myModelBuffer.action = zeros(numContinuousActions,myModelBuffer.bufferSize);
myModelBuffer.reward = zeros(1,myModelBuffer.bufferSize);
myModelBuffer.isDone = zeros(1,myModelBuffer.bufferSize);

Configure Training

Configure the training to use the following options.

  • Maximum number of training episodes — 250

  • Maximum steps per training episode — 500

  • Discount factor — 0.99

  • Training termination condition — Average reward across 10 episodes reaches the value of 480

numEpisodes = 250;
maxStepsPerEpisode = 500;
discountFactor = 0.99;
aveWindowSize = 10;
trainingTerminationValue = 480;

Configure the model options.

  • Train transition models only after 2000 samples are collected.

  • Train the models using all experiences in the real experience buffer in each episode. Use a mini-batch size of 256.

  • The models generate trajectories with a length of 4 at the beginning of each episode.

  • The number of generated trajectories is numGenerateSampleIteration x numModels x miniBatchSize = 20 x 3 x 256 = 15360.

  • Use the same epsilon-greedy parameters as the DQN agent, except for the minimum epsilon value.

  • Use a minimum epsilon value of 0.1, which is higher than the value used for interacting with the environment. Doing so allows the model to generate more diverse data.

warmStartSamples = 2000;
numEpochs = 1;
miniBatchSize = 256;
horizonLength = 4;
epsilonMinModel = 0.1;
numGenerateSampleIteration = 20;
sampleGenerationOptions.horizonLength = horizonLength;
sampleGenerationOptions.numGenerateSampleIteration = numGenerateSampleIteration;
sampleGenerationOptions.miniBatchSize = miniBatchSize;
sampleGenerationOptions.numObservations = numObservations;
sampleGenerationOptions.epsilonMinModel = epsilonMinModel;

% optimizer options
velocity1 = [];
velocity2 = [];
velocity3 = [];
decay = 0.01;
momentum = 0.9;
learnRate = 0.0005;

Configure the DQN training options.

  • Use the epsilon-greedy algorithm with an initial epsilon value of 1, a minimum value of 0.01, and a decay rate of 0.005.

  • Update the target network every 4 steps.

  • Set the ratio of real experiences to generated experiences to 0.2:0.8 by setting RealRatio to 0.2. Setting RealRatio to 1.0 is equivalent to model-free DQN.

  • Take 5 gradient steps at each environment step.

epsilon = 1;
epsilonMin = 0.01;
epsilonDecay = 0.005;
targetUpdateFrequency = 4;
realRatio = 0.2; % Set to 1 to run a standard DQN
numGradientSteps = 5;
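The realRatio setting controls how each minibatch is assembled from the two buffers. A minimal sketch of this mixing step, assuming uniform random sampling (the actual minibatch-sampling helper may differ):

```matlab
% Sketch of mixing real and model-generated samples in one minibatch.
% Uniform random sampling is assumed; only the observation field is
% shown, the other fields are indexed the same way.
numReal  = floor(realRatio*miniBatchSize);   % e.g., floor(0.2*256) = 51 real samples
numModel = miniBatchSize - numReal;          % remaining samples come from the model buffer
realIdx  = randi(myBuffer.currentBufferLength,1,numReal);
modelIdx = randi(myModelBuffer.currentBufferLength,1,numModel);
sampledObservation = [myBuffer.observation(:,realIdx), ...
                      myModelBuffer.observation(:,modelIdx)];
```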

Create a vector for storing the cumulative reward for each training episode.

episodeCumulativeRewardVector = [];

Create a figure for model training visualization using the hBuildFigureModel helper function.

[trainingPlotModel, lineLossTrain1, lineLossTrain2, lineLossTrain3, axModel] = hBuildFigureModel();

Create a figure for model validation visualization using the hBuildFigureModelTest helper function.

[testPlotModel, lineLossTest1, axModelTest] = hBuildFigureModelTest();

Create a figure for DQN agent training visualization using the hBuildFigure helper function.

[trainingPlot,lineReward,lineAveReward, ax] = hBuildFigure;

Train Agent

Train the agent using a custom training loop. The training loop uses the following algorithm. For each episode:

  1. Train the transition models.

  2. Generate experiences using the transition models and store the samples in the model experience buffer.

  3. Generate a real experience. To do so, generate an action using the policy, apply the action to the environment, and obtain the resulting observation, reward, and is-done values.

  4. Create a mini-batch by sampling experiences from both the experience buffer and the model experience buffer.

  5. Compute the target Q value.

  6. Compute the gradient of the loss function with respect to the critic representation parameters.

  7. Update the critic representation using the computed gradients.

  8. Update the training visualization.

  9. Terminate training if the critic is sufficiently trained.

targetCritic = critic;
modelTrainedAtleastOnce = false;
totalStepCt = 0;
start = tic;


for episodeCt = 1:numEpisodes
    if myBuffer.currentBufferLength > miniBatchSize && ...
            totalStepCt > warmStartSamples
        if realRatio < 1.0
            % 1. Train transition models.
            % Training three transition models
            [transitionNetworkVector(1),loss1,velocity1] = trainTransitionModel(...
                transitionNetworkVector(1),myBuffer,velocity1,...
                miniBatchSize,numEpochs,momentum,learnRate);
            [transitionNetworkVector(2),loss2,velocity2] = trainTransitionModel(...
                transitionNetworkVector(2),myBuffer,velocity2,...
                miniBatchSize,numEpochs,momentum,learnRate);
            [transitionNetworkVector(3),loss3,velocity3] = trainTransitionModel(...
                transitionNetworkVector(3),myBuffer,velocity3,...
                miniBatchSize,numEpochs,momentum,learnRate);
            modelTrainedAtleastOnce = true;

            % Plot the model training losses.
            addpoints(lineLossTrain1,episodeCt,loss1);
            addpoints(lineLossTrain2,episodeCt,loss2);
            addpoints(lineLossTrain3,episodeCt,loss3);

            % Display the training progress
            d = duration(0,0,toc(start),'Format','hh:mm:ss');
            title(axModel,"Model Training Progress - Episode: "...
                + episodeCt + ", Elapsed: " + string(d))

            % 2. Generate experience using models.
            % Create numGenerateSampleIteration x horizonLength x numModels x miniBatchSize
            % samples, for example, 20 x 4 x 3 x 256 = 61440 samples.
            myModelBuffer = generateSamples(myBuffer,myModelBuffer,...
                transitionNetworkVector,critic,actInfo,sampleGenerationOptions);
        end
    end

    % Interact with environment and train agent.
    % Reset the environment at the start of the episode
    observation = reset(env);
    episodeReward = zeros(maxStepsPerEpisode,1);
    errorPreddiction = zeros(maxStepsPerEpisode,1);

    for stepCt = 1:maxStepsPerEpisode
        % 3. Generate an experience.
        totalStepCt = totalStepCt + 1;

        % Compute an action using the policy based on the current observation.
        if rand() < epsilon
            % Explore: sample a random action from the action space.
            action = usample(actInfo);
            action = action{1};
        else
            % Exploit: take the greedy action from the critic.
            action = getAction(critic,{observation});
        end

        % Update epsilon
        if totalStepCt > warmStartSamples
            epsilon = max(epsilon*(1-epsilonDecay),epsilonMin);
        end

        % Apply the action to the environment and obtain the resulting
        % observation and reward.
        [nextObservation,reward,isDone] = step(env,action);

        % Check prediction
        dx = predict(transitionNetworkVector(1),...
            dlarray(observation,'CB'),dlarray(action,'CB'));
        predictedNextObservation = observation + dx;
        errorPreddiction(stepCt) = ...
            sqrt(sum((nextObservation - extractdata(predictedNextObservation)).^2));

        % Store the action, observation, reward and is-done experience  
        myBuffer = storeExperience(myBuffer,observation,action,...
            nextObservation,reward,isDone);

        episodeReward(stepCt) = reward;
        observation = nextObservation;

        % Train DQN agent
        for gradientCt = 1:numGradientSteps
            if myBuffer.currentBufferLength >= miniBatchSize && ...
                    totalStepCt > warmStartSamples
                % 4. Sample a minibatch from the experience buffers.
                [sampledObservation,sampledAction,sampledNextObservation,...
                    sampledReward,sampledIsdone] = sampleMinibatch(...
                    modelTrainedAtleastOnce,realRatio,miniBatchSize,...
                    myBuffer,myModelBuffer);

                % 5. Compute the target Q value.
                [targetQValues, MaxActionIndices] = getMaxQValue(targetCritic, ...
                    {reshape(sampledNextObservation,[numObservations,1,miniBatchSize])});

                % Compute target for nonterminal states
                targetQValues(~logical(sampledIsdone)) = sampledReward(~logical(sampledIsdone)) + ...
                    discountFactor.*targetQValues(~logical(sampledIsdone));
                % Compute target for terminal states
                targetQValues(logical(sampledIsdone)) = sampledReward(logical(sampledIsdone));

                lossData.batchSize = miniBatchSize;
                lossData.actInfo = actInfo;
                lossData.actionBatch = sampledAction;
                lossData.targetQValues = targetQValues;

               % 6. Compute gradients.
                criticGradient = gradient(critic,@criticLossFunction, ...
                    {reshape(sampledObservation,[numObservations,1,miniBatchSize])},lossData);

                % 7. Update the critic network using gradients.
                critic = optimize(critic,criticGradient);
            end
        end

        % Update the target critic periodically.
        if mod(totalStepCt,targetUpdateFrequency) == 0
            targetCritic = critic;
        end

        % Stop if a terminal condition is reached.
        if isDone
            break;
        end
    end % End of episode steps

    % 8. Update the training visualization.
    episodeCumulativeReward = sum(episodeReward);
    episodeCumulativeRewardVector = cat(2,...
        episodeCumulativeRewardVector,episodeCumulativeReward);
    movingAveReward = movmean(episodeCumulativeRewardVector,...
        aveWindowSize,2);
    addpoints(lineReward,episodeCt,episodeCumulativeReward);
    addpoints(lineAveReward,episodeCt,movingAveReward(end));
    title(ax, "Training Progress - Episode: " + episodeCt + ...
        ", Total Step: " + string(totalStepCt) + ...
        ", epsilon:" + string(epsilon))

    errorPreddiction = errorPreddiction(1:stepCt);

    % Display one step prediction error.
    title(axModelTest, ...
        "Model one-step prediction error - Episode: " + episodeCt + ...
        ", Error: " + string(mean(errorPreddiction)))

    % Display training progress every 10th episode
    if (mod(episodeCt,10) == 0)    
        fprintf("EP:%d, Reward:%.4f, AveReward:%.4f, Steps:%d, TotalSteps:%d, epsilon:%f, error model:%f\n",...
            episodeCt,episodeCumulativeReward,movingAveReward(end),...
            stepCt,totalStepCt,epsilon,mean(errorPreddiction));
    end

    % 9. Terminate training if the network is sufficiently trained.
    if max(movingAveReward) > trainingTerminationValue
        break
    end
end
EP:10, Reward:12.0000, AveReward:13.3333, Steps:18, TotalSteps:261, epsilon:1.000000, error model:3.786379
EP:20, Reward:11.0000, AveReward:20.1667, Steps:17, TotalSteps:493, epsilon:1.000000, error model:3.768267
EP:30, Reward:34.0000, AveReward:19.3333, Steps:40, TotalSteps:769, epsilon:1.000000, error model:3.763075
EP:40, Reward:20.0000, AveReward:13.8333, Steps:26, TotalSteps:960, epsilon:1.000000, error model:3.797021
EP:50, Reward:13.0000, AveReward:22.5000, Steps:19, TotalSteps:1192, epsilon:1.000000, error model:3.813097
EP:60, Reward:32.0000, AveReward:14.8333, Steps:38, TotalSteps:1399, epsilon:1.000000, error model:3.821042
EP:70, Reward:12.0000, AveReward:17.6667, Steps:18, TotalSteps:1630, epsilon:1.000000, error model:3.741603
EP:80, Reward:17.0000, AveReward:16.5000, Steps:23, TotalSteps:1873, epsilon:1.000000, error model:3.780144
EP:90, Reward:13.0000, AveReward:11.8333, Steps:19, TotalSteps:2113, epsilon:0.567555, error model:0.222689
EP:100, Reward:198.0000, AveReward:223.5000, Steps:204, TotalSteps:3631, epsilon:0.010000, error model:0.283726
EP:110, Reward:381.0000, AveReward:262.5000, Steps:387, TotalSteps:6600, epsilon:0.010000, error model:0.117766
EP:120, Reward:79.0000, AveReward:229.5000, Steps:85, TotalSteps:8887, epsilon:0.010000, error model:0.081134
EP:130, Reward:234.0000, AveReward:300.5000, Steps:240, TotalSteps:11798, epsilon:0.010000, error model:0.063376
EP:140, Reward:500.0000, AveReward:403.5000, Steps:500, TotalSteps:15562, epsilon:0.010000, error model:0.036053
EP:150, Reward:500.0000, AveReward:443.6667, Steps:500, TotalSteps:19598, epsilon:0.010000, error model:0.032433
EP:160, Reward:349.0000, AveReward:264.0000, Steps:355, TotalSteps:21980, epsilon:0.010000, error model:0.037416
EP:170, Reward:231.0000, AveReward:231.8333, Steps:237, TotalSteps:24324, epsilon:0.010000, error model:0.029361
EP:180, Reward:311.0000, AveReward:416.6667, Steps:317, TotalSteps:28417, epsilon:0.010000, error model:0.026569
EP:190, Reward:500.0000, AveReward:468.6667, Steps:500, TotalSteps:33092, epsilon:0.010000, error model:0.014980
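The gradient computation at step 6 calls a criticLossFunction helper. A minimal sketch of such a loss, assuming the standard DQN mean-squared error between the Q-values of the taken actions and the bootstrapped targets (the actual helper may differ):

```matlab
function loss = criticLossFunction(qPrediction,lossData)
% Sketch of the DQN critic loss (assumed implementation, for
% illustration only). qPrediction holds Q-values for every discrete
% action, for each sample in the minibatch.
    qPrediction = reshape(qPrediction,[],lossData.batchSize);
    % Select the Q-value of the action actually taken in each sample.
    actionIdx = zeros(1,lossData.batchSize);
    for k = 1:lossData.batchSize
        actionIdx(k) = find(lossData.actInfo.Elements == lossData.actionBatch(k),1);
    end
    linIdx = sub2ind(size(qPrediction),actionIdx,1:lossData.batchSize);
    % Mean-squared error against the bootstrapped targets.
    loss = mse(qPrediction(linIdx),reshape(lossData.targetQValues,1,[]));
end
```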

Simulate Agent

To simulate the trained agent, first reset the environment.

obs0 = reset(env);
obs = obs0;

Enable the environment visualization, which is updated each time the environment step function is called.

plot(env)
For each simulation step, perform the following actions.

  1. Get the action by sampling from the policy using the getAction function.

  2. Step the environment using the obtained action value.

  3. Terminate if a terminal condition is reached.

actionVector = zeros(1,maxStepsPerEpisode);
obsVector = zeros(numObservations,maxStepsPerEpisode+1);
obsVector(:,1) = obs0;
for stepCt = 1:maxStepsPerEpisode
    % Select action according to trained policy.
    action = getAction(critic,{obs});
    % Step the environment.
    [nextObs,reward,isDone] = step(env,action);    

    obsVector(:,stepCt+1) = nextObs;
    actionVector(1,stepCt) = action;

    % Check for terminal condition.
    if isDone
        break;
    end
    obs = nextObs;
end

lastStepCt = stepCt;

Test Model

Test one of the models by predicting a next observation given a current observation and an action.

modelID = 3;
predictedObsVector = zeros(numObservations,lastStepCt+1);
predictedObsVector(:,1) = obsVector(:,1);
for stepCt = 1:lastStepCt
    obs = dlarray(obsVector(:,stepCt),'CB');
    action = dlarray(actionVector(1,stepCt),'CB');
    dx = predict(transitionNetworkVector(modelID),obs, action);
    predictedObs = obs + dx;
    predictedObsVector(:,stepCt+1) = extractdata(predictedObs);
end
predictedObsVector = predictedObsVector(:, 1:lastStepCt);
layOut = tiledlayout(4,1,'TileSpacing','compact');
for i = 1:4
    nexttile;
    errorPrediction = abs(predictedObsVector(i,1:lastStepCt) - obsVector(i,1:lastStepCt));
    plot(errorPrediction,'DisplayName','Absolute Error');
    title("observation " + num2str(i));
end
title(layOut,"Prediction Absolute Error")

The small absolute prediction error shows that the model is successfully trained to predict the next observation.


References

[1] Janner, Michael, Justin Fu, Marvin Zhang, and Sergey Levine. "When to Trust Your Model: Model-Based Policy Optimization." ArXiv:1906.08253 [Cs, Stat], November 5, 2019.

[2] Mnih, Volodymyr, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. "Playing Atari with Deep Reinforcement Learning." ArXiv:1312.5602 [Cs], December 19, 2013.
