How do RL algorithms work with RNNs?

Tech Logg Ding on 10 Feb 2021
Commented: Takeshi Takahashi on 2 Mar 2021
Hi,
I noticed that MATLAB R2021a allows users to use RL algorithms, such as DDPG, with RNNs in the deep neural network structure. This is great, as it could benefit continuous control problems with time delays and time-dependent parameters.
However, I am wondering which algorithm MATLAB uses for the RL RNN learning process. RNNs learn through backpropagation through time (BPTT), so the states sampled for BPTT must be consecutive. On the other hand, RL algorithms such as DDPG learn from transitions sampled at random from the experience buffer, so the two do not integrate as naturally as they do with a conventional MLP network structure. How does MATLAB handle this? Is there a paper I can reference?
Next, I am also curious about how BPTT for RNNs is executed in MATLAB. In RL, an episode can have hundreds to thousands of time steps, and the RNN is usually expected to keep a memory of the states at each time step (referring to the unrolled structure) in order to learn the weights and biases for its internal state. Does the sequence run to the end of each episode before the RNN is updated? Will this consume significantly more memory?
Thank you very much.
Tech Logg Ding on 23 Feb 2021
Bumping this question. After looking through the documentation, I have not found any information on how updates work when the deep neural network contains an RNN. This paper (https://academic.oup.com/jigpal/article/18/5/620/751594?login=true) also states that random episodes should be sampled as short sequences for its LSTM network to train effectively. Does the RL Toolbox do this?

Iniciar sesión para comentar.

Accepted Answer

Takeshi Takahashi on 24 Feb 2021
Hi,
rlDDPGAgent with an RNN first randomly samples B sequences (trajectories) from the experience buffer, where B is MiniBatchSize. Then, if a sampled sequence is longer than L, where L is the SequenceLength you specified, it randomly selects a starting point within that sequence. The end point is determined by the starting point and L, so that the length of the sequence becomes L.
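The sampling step described above can be sketched in Python. This is a minimal illustration under stated assumptions, not the RL Toolbox implementation: the buffer is modeled as a plain list of stored trajectories, trajectories are drawn with replacement, and `sample_sequence_batch` is a hypothetical helper name whose `mini_batch_size` and `sequence_length` arguments stand in for MiniBatchSize and SequenceLength.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_sequence_batch(
    buffer: Sequence[Sequence[T]],
    mini_batch_size: int,
    sequence_length: int,
    rng: random.Random,
) -> List[List[T]]:
    """Hypothetical helper (not a MATLAB API): sample windows of at most
    sequence_length consecutive transitions for BPTT.

    buffer is assumed to be a list of stored trajectories, each a list of
    transitions in time order.
    """
    batch: List[List[T]] = []
    for _ in range(mini_batch_size):
        # Draw one stored trajectory uniformly at random (with replacement, an assumption).
        traj = buffer[rng.randrange(len(buffer))]
        if len(traj) > sequence_length:
            # Random start chosen so that the window fits inside the trajectory.
            start = rng.randrange(len(traj) - sequence_length + 1)
            # End point follows from start + L, so the window has exactly L steps.
            batch.append(list(traj[start:start + sequence_length]))
        else:
            # Trajectories no longer than L are kept whole; padding happens later.
            batch.append(list(traj))
    return batch
```

For example, `sample_sequence_batch(buffer, 16, 4, random.Random(0))` returns 16 windows of consecutive transitions, each at most 4 steps long. Because each window keeps its transitions in time order, the RNN can be unrolled over it for BPTT, while the choice of trajectory and start point is still random, as in ordinary experience replay.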
If some sequences sampled from the experience buffer are shorter than L, they are padded with fake samples so that all sequences in the batch have the same length, L. We apply masking to the padded samples, so they do not affect BPTT.
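Padding and masking can be illustrated in the same way. In this minimal sketch (an assumption, not the toolbox's code), short sequences are padded at the end with a placeholder value, and a boolean mask marks the real steps. The loss averages only over real steps, so padded steps contribute no gradient. Padding at the end matters for a causal RNN: the padded inputs come after every real step, so they cannot change the hidden states produced at the real steps. `pad_and_mask` and `masked_mse` are hypothetical names, and mean squared error stands in for whatever actor or critic loss is actually used.

```python
from typing import List, Tuple


def pad_and_mask(
    batch: List[List[float]], sequence_length: int, pad_value: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Hypothetical helper: post-pad each sequence to sequence_length and build a mask."""
    padded: List[List[float]] = []
    mask: List[List[bool]] = []
    for seq in batch:
        n_pad = sequence_length - len(seq)
        # Fake samples are appended after the real steps (post-padding, an assumption).
        padded.append(list(seq) + [pad_value] * n_pad)
        # True marks a real step, False marks a padded one.
        mask.append([True] * len(seq) + [False] * n_pad)
    return padded, mask


def masked_mse(
    pred: List[List[float]], target: List[List[float]], mask: List[List[bool]]
) -> float:
    """Mean squared error over real steps only; padded steps are skipped entirely."""
    total, count = 0.0, 0
    for p_row, t_row, m_row in zip(pred, target, mask):
        for p, t, m in zip(p_row, t_row, m_row):
            if m:
                total += (p - t) ** 2
                count += 1
    return total / count if count else 0.0
```

With this scheme, whatever the network predicts at a padded step (even a wildly wrong value) leaves the loss, and therefore the gradient, unchanged.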
We use these short sequences as a batch for BPTT. MiniBatchSize and SequenceLength control the size of the batch, and larger values of either require more memory during BPTT.
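As a rough, hypothetical estimate of why memory grows with both settings: BPTT must keep at least the hidden state of every recurrent layer at every step of every sequence in the batch, so that storage alone scales with MiniBatchSize × SequenceLength × the number of hidden units. The sketch below counts only that term and assumes 4-byte single-precision values. It ignores gate activations, other layers, and optimizer state, so real usage will be higher. `hidden_state_bytes` is a hypothetical name.

```python
def hidden_state_bytes(
    mini_batch_size: int,
    sequence_length: int,
    hidden_units: int,
    num_layers: int = 1,
    bytes_per_value: int = 4,  # assumes single precision
) -> int:
    """Lower-bound estimate: one hidden-state vector per layer, per step, per sequence."""
    return mini_batch_size * sequence_length * hidden_units * num_layers * bytes_per_value


print(hidden_state_bytes(64, 20, 128))
```

With hypothetical settings of 64 sequences of 20 steps and 128 hidden units, this term is 64 × 20 × 128 × 4 = 655,360 bytes per layer. Doubling either MiniBatchSize or SequenceLength doubles it, which is why both settings trade memory for more training signal per update.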
I hope this clarifies your question.
Thank you.
Tech Logg Ding on 28 Feb 2021
Got it! Thank you very much! Do the other algorithms, such as TD3 and SAC, use the same sampling method?
Takeshi Takahashi on 2 Mar 2021
Yes. They use the same method.