# Train DQN Agent to Balance Cart-Pole System

This example shows how to train a deep Q-learning network (DQN) agent to balance a cart-pole system modeled in MATLAB®.

For more information on DQN agents, see Deep Q-Network Agents. For an example that trains a DQN agent in Simulink®, see Train DQN Agent to Swing Up and Balance Pendulum.

### Cart-Pole MATLAB Environment

The reinforcement learning environment for this example is a pole attached to an unactuated joint on a cart, which moves along a frictionless track. The training goal is to make the pole stand upright without falling over.

For this environment:

• The upward balanced pole position is `0` radians, and the downward hanging position is `pi` radians.

• The pole starts upright with an initial angle between –0.05 and 0.05 radians.

• The force action signal from the agent to the environment is from –10 to 10 N.

• The observations from the environment are the position and velocity of the cart, the pole angle, and the pole angle derivative.

• The episode terminates if the pole is more than 12 degrees from vertical or if the cart moves more than 2.4 m from the original position.

• A reward of +1 is provided for every time step that the pole remains upright. A penalty of –5 is applied when the pole falls.
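The termination and reward rules above can be sketched in plain MATLAB. This is a minimal illustration, not the environment's actual implementation: the variable names and the two sample states are hypothetical, and the sketch assumes that the penalty applies on either termination condition.

```matlab
% Thresholds and rewards taken from the list above.
thetaThreshold      = 12*pi/180;  % 12 degrees, in radians
xThreshold          = 2.4;        % cart displacement limit, in meters
rewardForNotFalling = 1;
penaltyForFalling   = -5;

% Two hypothetical states, one per column: [x; dx; theta; dtheta].
states = [0.10  0.50;
          0.00  0.20;
          0.05  0.25;
          0.00  0.30];

numStates = size(states,2);
isDone = false(1,numStates);
reward = zeros(1,numStates);
for k = 1:numStates
    x     = states(1,k);
    theta = states(3,k);
    % The episode ends if the cart or the pole leaves its allowed range.
    isDone(k) = abs(x) > xThreshold || abs(theta) > thetaThreshold;
    if isDone(k)
        reward(k) = penaltyForFalling;    % assumed for both conditions
    else
        reward(k) = rewardForNotFalling;
    end
end
```

In this sketch, the first state is still inside both limits, while the second exceeds the pole angle limit and ends the episode.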

### Create Environment Interface

Create a predefined environment interface for the system.

`env = rlPredefinedEnv("CartPole-Discrete")`
```
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]
```

The interface has a discrete action space where the agent can apply one of two possible force values to the cart, –10 or 10 N.

Get the observation and action specification information.

`obsInfo = getObservationInfo(env)`
```
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"
```
`actInfo = getActionInfo(env)`
```
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"
```
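As a quick sanity check of the interface, you can reset the environment and apply one action by hand. The following is a minimal sketch that assumes Reinforcement Learning Toolbox is installed. It repeats the environment creation call so that the snippet runs on its own, and the variable names `initialObs` and `nextObs` are chosen here for illustration.

```matlab
% Recreate the environment so this snippet is self-contained.
env = rlPredefinedEnv("CartPole-Discrete");
actInfo = getActionInfo(env);

% Reset returns the initial observation [x; dx; theta; dtheta].
initialObs = reset(env);

% Apply the first allowed force value (-10 N) for one time step.
[nextObs,reward,isDone] = step(env,actInfo.Elements(1));
```

Both observations are 4-by-1 vectors, matching the `Dimension` property of `obsInfo`.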

Fix the random generator seed for reproducibility.

`rng(0)`

### Create DQN Agent

A DQN agent approximates the long-term reward, given observations and actions, using a value-function critic.

DQN agents can use multi-output Q-value critic approximators, which are generally more efficient. A multi-output approximator has observations as inputs and state-action values as outputs. Each output element represents the expected cumulative long-term reward for taking the corresponding discrete action from the state indicated by the observation inputs.
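To illustrate how a multi-output critic is used, the sketch below picks the greedy action from a vector of Q-values. The Q-values are hypothetical numbers chosen for illustration, not outputs of the critic in this example.

```matlab
% Allowed actions, in the same order as actInfo.Elements.
actions = [-10 10];

% Hypothetical critic output for one observation: one Q-value per action.
qValues = [1.7; 2.3];

% The greedy action is the one with the largest Q-value.
[~,idx] = max(qValues);
greedyAction = actions(idx);
```

Because all Q-values come from a single forward pass, selecting the best action does not require evaluating the network once per action, which is why this form is generally more efficient.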

To create the critic, first create a deep neural network with one input (the 4-dimensional observed state) and one output vector with two elements (one for the –10 N action, another for the 10 N action, in the order given by `actInfo.Elements`). For more information on creating value-function representations based on a neural network, see Create Policies and Value Functions.

```
dnn = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];
dnn = dlnetwork(dnn);
```
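To make the layer sizes concrete, the following sketch reproduces the same 4-24-24-2 structure as a forward pass in plain MATLAB. The weights are small random values chosen for illustration only. They are not the initial weights that `dlnetwork` would produce, and the sample observation is hypothetical.

```matlab
rng(0)  % fixed seed so the sketch is repeatable

% Hypothetical weights and biases for the three fully connected layers.
W1 = 0.1*randn(24,4);   b1 = zeros(24,1);
W2 = 0.1*randn(24,24);  b2 = zeros(24,1);
W3 = 0.1*randn(2,24);   b3 = zeros(2,1);

relu = @(z) max(z,0);

% Sample observation: [x; dx; theta; dtheta].
obs = [0; 0; 0.02; 0];

% Two hidden ReLU layers followed by a linear output layer.
h1 = relu(W1*obs + b1);
h2 = relu(W2*h1 + b2);
q  = W3*h2 + b3;        % 2-by-1 vector: one Q-value per action
```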

View the network configuration.

```
figure
plot(layerGraph(dnn))
```

Specify some training options for the critic optimizer using `rlOptimizerOptions`.

`criticOpts = rlOptimizerOptions('LearnRate',0.001,'GradientThreshold',1);`

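Create the critic using the specified neural network and the observation and action specifications. The optimizer options in `criticOpts` are applied later, through the agent options. For more information, see `rlVectorQValueFunction`.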

`critic = rlVectorQValueFunction(dnn,obsInfo,actInfo);`

To create the DQN agent, first specify the DQN agent options using `rlDQNAgentOptions`.

```
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',100000, ...
    'CriticOptimizerOptions',criticOpts, ...
    'MiniBatchSize',256);
```
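The effect of the `UseDoubleDQN`, `TargetSmoothFactor`, and `TargetUpdateFrequency` settings can be sketched in plain MATLAB. This is a conceptual illustration under stated assumptions, not the toolbox implementation: the discount factor of 0.99, the Q-values, and the parameter vectors are hypothetical, and the sketch assumes that the update frequency counts learning steps.

```matlab
% All numbers below are hypothetical and chosen only for illustration.
gamma  = 0.99;    % assumed discount factor
reward = 1;       % reward for one non-terminal step
isDone = false;

% Q-values of the next state from the target critic, one per action.
qTargetNext = [20.5; 21.0];

% Standard DQN target (UseDoubleDQN = false): bootstrap from the
% largest target Q-value, and do not bootstrap from terminal states.
y = reward + (~isDone)*gamma*max(qTargetNext);

% Target update with smoothing factor tau, applied every few steps.
tau = 1;
updateFrequency = 4;
criticParams = [0.3; -0.2; 0.8];
targetParams = [0.1;  0.1; 0.1];
for learnStep = 1:8
    if mod(learnStep,updateFrequency) == 0
        % With tau = 1 this is a full copy of the critic parameters.
        targetParams = tau*criticParams + (1 - tau)*targetParams;
    end
end
```

With a smoothing factor of 1, the target critic is replaced outright at each update instead of being blended, so the two settings together describe a periodic hard update.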

Then, create the DQN agent using the specified critic representation and agent options. For more information, see `rlDQNAgent`.

`agent = rlDQNAgent(critic,agentOpts);`

### Train Agent

To train the agent, first specify the training options. For this example, use the following options:

• Run one training session containing at most 1000 episodes, with each episode lasting at most 500 time steps.

• Display the training progress in the Episode Manager dialog box (set the `Plots` option) and disable the command line display (set the `Verbose` option to `false`).

• Stop training when the agent receives a moving average cumulative reward greater than 480. At this point, the agent can balance the cart-pole system in the upright position.

For more information, see `rlTrainingOptions`.

```
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480);
```
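The stopping rule can be sketched as a moving average over recent episode rewards. The episode rewards below are made up, and the window length of 5 episodes is an assumption for illustration (the actual window is set by the training options). The comparison uses `>=`, on the assumption that training stops once the average reaches the threshold.

```matlab
% Hypothetical cumulative rewards for consecutive training episodes.
episodeRewards = [120 300 470 500 500 500 500 500];

windowLength = 5;    % assumed averaging window, in episodes
stopValue    = 480;  % same threshold as 'StopTrainingValue'

stopEpisode = NaN;
for k = 1:numel(episodeRewards)
    % Average over the most recent episodes (fewer at the start).
    window = episodeRewards(max(1,k-windowLength+1):k);
    if mean(window) >= stopValue
        stopEpisode = k;
        break
    end
end
```

For these sample rewards, the average first reaches the threshold at episode 7, where the window holds 470 and four episodes of 500.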

You can visualize the cart-pole system by using the `plot` function during training or simulation.

`plot(env)`

Train the agent using the `train` function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting `doTraining` to `false`. To train the agent yourself, set `doTraining` to `true`.

```
doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('MATLABCartpoleDQNMulti.mat','agent')
end
```

### Simulate DQN Agent

To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see `rlSimulationOptions` and `sim`. The agent can balance the cart-pole even when the simulation time increases to 500 steps.

```
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
```

`totalReward = sum(experience.Reward)`
```totalReward = 500 ```