주요 콘텐츠

이산 카트-폴 시스템의 균형을 유지하도록 DQN 에이전트 훈련시키기

이 예제에서는 MATLAB®에서 모델링된 이산 행동 공간 카트-폴 시스템의 균형을 유지하도록 DQN(심층 Q-러닝 신경망) 에이전트를 훈련시키는 방법을 보여줍니다.

DQN 에이전트에 대한 자세한 내용은 DQN(심층 Q-신경망) 에이전트 항목을 참조하십시오. Simulink®에서 DQN 에이전트를 훈련시키는 예제는 진자가 위쪽으로 똑바로 서서 균형을 유지하도록 DQN 에이전트 훈련시키기 항목을 참조하십시오.

재현이 가능하도록 난수 스트림 고정하기

예제 코드의 다양한 단계에서 난수 계산이 포함될 수 있습니다. 예제 코드에 있는 다양한 섹션의 시작 부분에서 난수 스트림을 고정하면 매 실행 시에 섹션의 난수열이 유지되며 결과를 재현할 가능성이 높아집니다. 자세한 내용은 결과 재현성 항목을 참조하십시오.

시드값 0과 난수 알고리즘인 메르센 트위스터를 사용하여 난수 스트림을 고정합니다. 난수 생성에 사용되는 시드값을 제어하는 방법에 대한 자세한 내용은 rng 항목을 참조하십시오.

previousRngState = rng(0,"twister");

출력값 previousRngState는 스트림의 이전 상태에 대한 정보를 포함하는 구조체입니다. 이 예제의 끝부분에서 그 상태를 복원할 것입니다.

이산 행동 공간 카트-폴 MATLAB 환경

이 예제의 강화 학습 환경은 카트의 비구동 관절에 붙어 있는 막대로, 카트는 마찰이 없는 트랙을 따라 움직입니다. 훈련 목표는 이 막대가 똑바로 서 있게 만드는 것입니다.

이 환경의 경우 다음이 적용됩니다.

  • 위쪽을 향해 있을 때의 막대 각도는 0라디안입니다. 처음에는 막대가 움직이지 않고 아래쪽으로 (pi라디안 각도로) 매달려 있습니다.

  • 막대의 초기 각이 –0.05라디안과 0.05라디안 사이이고 위쪽을 향해 있을 때 시작합니다.

  • 에이전트에서 환경으로 전달되는 힘 행동 신호는 –10N 또는 10N입니다.

  • 환경에서 관측하는 값은 카트의 위치와 속도, 막대 각, 막대 각 도함수입니다.

  • 막대가 수직에서 12도 이상 기울거나 카트가 원래 위치에서 2.4m 이상 이동하면 에피소드가 종료됩니다.

  • 막대가 위쪽을 향해 바로 서 있는 매 시간 스텝마다 보상 +1이 주어집니다. 막대가 넘어지면 벌점 –5가 적용됩니다.

이 모델에 대한 자세한 설명은 Load Predefined Control System Environments 항목을 참조하십시오.

환경 객체 만들기

미리 정의된 카트-폴 환경 객체를 만듭니다.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4×1 double]

객체에는 에이전트가 힘 값 –10N 또는 10N 중 하나를 카트에 적용할 수 있는 이산 행동 공간이 있습니다.

관측값 및 행동 사양 정보를 가져옵니다.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0×0 string]
      Dimension: [1 1]
       DataType: "double"

DQN 에이전트 만들기

재현이 가능하도록 난수 스트림을 고정합니다.

rng(0,"twister");

은닉 계층 크기가 20인 크리틱 신경망을 초기화하기 위해 agent initialization 객체를 만듭니다.

initOpts = rlAgentInitializationOptions(NumHiddenUnit=20);

rlDQNAgentOptions 객체와 rlOptimizerOptions 객체를 사용하여 훈련에 대한 에이전트 옵션을 지정합니다. 이 훈련의 경우 다음이 적용됩니다.

  • 256개 경험 미니 배치를 사용합니다. 미니 배치 크기가 이보다 작으면 계산이 효율적이지만 훈련 시 변동성이 생길 수 있습니다. 반대로 미니 배치 크기가 이보다 크면 훈련이 안정되지만 메모리가 더 많이 필요할 수 있습니다.

  • 평활화 인자를 1로 지정하여 매 4회 학습 반복 시 타깃 크리틱 신경망을 업데이트합니다.

  • 더블 DQN 알고리즘을 사용하지 마십시오.

agentOpts = rlDQNAgentOptions( ...
    MiniBatchSize            = 256,...
    TargetSmoothFactor       = 1, ...
    TargetUpdateFrequency    = 4,...
    UseDoubleDQN             = false);

훈련 중 DQN 에이전트는 엡실론-그리디 알고리즘을 사용하여 행동 공간을 탐색합니다. 훈련 중 엡실론 값이 점진적으로 감쇠되도록 감쇠율을 1e-3으로 지정합니다. 그러면 에이전트가 적합한 정책이 없는 경우에는 시작 부분 쪽으로 탐색하도록 하고, 에이전트가 최적의 정책을 학습한 경우에는 끝부분 쪽으로 탐색하도록 합니다.

agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-3;

관측값 사양과 행동 입력 사양, 초기화 옵션과 에이전트 옵션을 사용하여 DQN 에이전트를 만듭니다. 에이전트를 생성하면 크리틱 신경망의 초기 파라미터가 난수 값으로 초기화됩니다. 에이전트가 항상 동일한 파라미터 값으로 초기화되도록 난수 스트림을 고정합니다.

agent = rlDQNAgent(obsInfo,actInfo,initOpts,agentOpts);

자세한 내용은 rlDQNAgent 항목을 참조하십시오.

임의의 관측값 입력값을 사용하여 에이전트의 행동을 확인합니다.

getAction(agent,{rand(obsInfo.Dimension)})
ans = 1×1 cell array
    {[10]}

에이전트 훈련시키기

재현이 가능하도록 난수 스트림을 고정합니다.

rng(0,"twister");

에이전트를 훈련시키려면 먼저 훈련 옵션을 지정하십시오. 이 예제에서는 다음 옵션을 사용합니다.

  • 최대 1000개의 에피소드에 대해 훈련을 실행하며, 각 에피소드마다 최대 500개의 시간 스텝이 지속됩니다.

  • 강화 학습 훈련 모니터 창에 훈련 진행 상황을 표시하고(Plots 옵션 설정) 명령줄 표시를 비활성화합니다(Verbose 옵션을 false로 설정).

  • 훈련 에피소드 20개마다 그리디 정책 성능을 평가하고, 5회 시뮬레이션의 누적 보상에 대한 평균값을 계산합니다.

  • 평가 점수가 500에 도달하면 훈련을 중지합니다. 이 시점에서 에이전트는 똑바로 서 있는 위치에서 카트-폴 시스템의 균형을 유지할 수 있습니다.

% training options
trainOpts = rlTrainingOptions(...
    MaxEpisodes=1000, ...
    MaxStepsPerEpisode=500, ...
    Verbose=false, ...
    Plots="training-progress",...
    StopTrainingCriteria="EvaluationStatistic",...
    StopTrainingValue=500);

% agent evaluator
evl = rlEvaluator(EvaluationFrequency=20, NumEpisodes=5);

훈련 옵션에 대한 자세한 내용은 rlTrainingOptions 항목과 rlEvaluator 항목을 참조하십시오.

train 함수를 사용하여 에이전트를 훈련시킵니다. 이 에이전트를 훈련시키는 것은 완료하는 데 수 분이 소요되는 계산 집약적인 절차입니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts,Evaluator=evl);
else
    % Load the pretrained agent for the example.
    load("MATLABCartpoleDQNMulti.mat","agent")
end

에이전트 시뮬레이션하기

재현이 가능하도록 난수 스트림을 고정합니다.

rng(0,"twister");

plot 함수를 사용하여 카트-폴 시스템을 시각화할 수 있습니다.

plot(env)

훈련된 에이전트의 성능을 검증하려면 카트-폴 환경 내에서 에이전트를 시뮬레이션하십시오. 에이전트 시뮬레이션에 대한 자세한 내용은 rlSimulationOptions 항목과 sim 항목을 참조하십시오.

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

totalReward = sum(experience.Reward)
totalReward = 
500

에이전트가 카트-폴 시스템의 균형을 유지할 수 있습니다.

previousRngState에 저장된 정보를 사용하여 난수 스트림을 복원합니다.

rng(previousRngState);

참고 항목

함수

객체

도움말 항목