How do RL algorithms work with RNNs?

Tech Logg Ding, 10 Feb 2021
Commented: Takeshi Takahashi, 2 Mar 2021
Hi,
I noticed that MATLAB R2021a allows users to use RL algorithms, such as DDPG, with an RNN in the deep neural network structure. This is great, as it could benefit continuous control problems with time delays and time-dependent parameters.
However, I am wondering about the algorithm MATLAB uses for the RL RNN learning process. RNNs learn through backpropagation through time (BPTT), so the sampled states for BPTT must be in sequence. On the other hand, RL algorithms such as DDPG learn by drawing random samples from the experience buffer, so the two do not combine as naturally as they do with a conventional MLP network. How does MATLAB handle this? Is there a paper that I can reference?
Next, I am also curious about how BPTT for the RNN is executed in MATLAB. In RL, an episode can have hundreds to thousands of time steps, and the RNN is usually expected to keep a memory of the states at each time step (referring to the unrolled structure) in order to learn the weights and biases for its internal state. Does the sequence terminate at the end of every episode to update the RNN? Will this consume significantly more memory?
Thank you very much.
  1 Comment
Tech Logg Ding, 23 Feb 2021
Bumping this question. After looking into the documentation, I have not found any information on how updates with an RNN in the DNN work. This paper (https://academic.oup.com/jigpal/article/18/5/620/751594?login=true) also describes that random episodes should be sampled in short sequences so that its LSTM network trains effectively. Does the RL toolbox include this?


Accepted Answer

Takeshi Takahashi, 24 Feb 2021
Hi,
rlDDPGAgent with an RNN first randomly samples B sequences (trajectories) from the experience buffer, where B is MiniBatchSize. If a sampled sequence is longer than L, where L is the SequenceLength you specified, it then randomly selects a starting point within that sequence. The end point is determined by the starting point and L so that the length becomes L.
Suppose some sampled sequences from the experience buffer are shorter than L. In that case, those sequences are padded with fake samples so that all sequences in the batch have the same length (L). We apply masking to the padded samples so that they don't affect BPTT.
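The sampling and padding steps described above can be sketched roughly as follows. This is an illustrative sketch in plain Python, not the toolbox's actual code: the function name `sample_sequences`, the use of plain floats as stand-in experiences, sampling trajectories with replacement, and padding at the end of a sequence with zeros are all assumptions made for the example.

```python
import random

def sample_sequences(episodes, batch_size, seq_len, pad_value=0.0, rng=random):
    """Sample batch_size windows of length seq_len from stored episodes.

    episodes: list of trajectories, each a list of per-step experiences
    (plain floats here, purely for illustration).
    Returns (batch, mask), both batch_size x seq_len nested lists.
    mask is 1 for real steps and 0 for padded steps.
    """
    batch, mask = [], []
    for _ in range(batch_size):
        # Assumption: trajectories are drawn uniformly, with replacement.
        episode = rng.choice(episodes)
        if len(episode) > seq_len:
            # Longer than L: random start, end fixed by start + L.
            start = rng.randrange(len(episode) - seq_len + 1)
            window = episode[start:start + seq_len]
        else:
            window = list(episode)
        n_real = len(window)
        n_pad = seq_len - n_real
        # Assumption: padding is appended at the end of short sequences.
        batch.append(window + [pad_value] * n_pad)
        mask.append([1] * n_real + [0] * n_pad)
    return batch, mask

# Three stored trajectories of lengths 3, 10, and 25; B = 4, L = 5.
rng = random.Random(0)
episodes = [[float(t) for t in range(n)] for n in (3, 10, 25)]
batch, mask = sample_sequences(episodes, batch_size=4, seq_len=5, rng=rng)
```

Every row of `batch` has length L, and `mask` records which entries are real, so a later loss computation can ignore the padded steps.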
We use these short sequences as a batch for BPTT. MiniBatchSize and SequenceLength control the size of the batch. Larger values of MiniBatchSize and SequenceLength require more memory during BPTT.
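To illustrate how a masked batch of size B x L might enter a loss, here is a small sketch in plain Python. The helper `masked_mean_loss` and the numbers are hypothetical and are not taken from the toolbox. It averages per-step losses over the real steps only, so padded steps contribute nothing, and it shows that the number of time steps held for one update grows as B times L.

```python
def masked_mean_loss(step_losses, mask):
    """Average per-step losses over real (mask == 1) steps only."""
    total = 0.0
    count = 0
    for loss_row, mask_row in zip(step_losses, mask):
        for loss, m in zip(loss_row, mask_row):
            total += loss * m   # padded steps are multiplied by 0
            count += m
    return total / count if count else 0.0

# B = 2 sequences, L = 3 steps; the second sequence has two padded steps.
step_losses = [[1.0, 2.0, 3.0],
               [4.0, 9.0, 9.0]]
mask = [[1, 1, 1],
        [1, 0, 0]]

loss = masked_mean_loss(step_losses, mask)              # (1 + 2 + 3 + 4) / 4 = 2.5
steps_in_batch = len(step_losses) * len(step_losses[0])  # B * L = 6
```

The values 9.0 in the padded positions have no effect on `loss`, which is the point of masking. Doubling either B or L doubles `steps_in_batch`.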
I hope this clarifies your question.
Thank you.
  2 Comments
Tech Logg Ding, 28 Feb 2021
Got it! Thank you very much! Do the other algorithms, such as TD3 and SAC, use the same sampling method?
Takeshi Takahashi, 2 Mar 2021
Yes. They use the same method.
