Main Content

카트-폴 시스템의 균형을 유지하도록 DQN 에이전트 훈련시키기

이 예제에서는 MATLAB®에서 모델링된 카트-폴 시스템의 균형을 유지하도록 DQN(심층 Q-러닝 신경망) 에이전트를 훈련시키는 방법을 보여줍니다.

DQN 에이전트에 대한 자세한 내용은 DQN(심층 Q-신경망) 에이전트 항목을 참조하십시오. Simulink®에서 DQN 에이전트를 훈련시키는 예제는 진자가 위쪽으로 똑바로 서서 균형을 유지하도록 DQN 에이전트 훈련시키기 항목을 참조하십시오.

카트-폴 MATLAB 환경

이 예제의 강화 학습 환경은 카트의 비구동 관절에 붙어 있는 막대로, 카트는 마찰이 없는 트랙을 따라 움직입니다. 훈련 목표는 이 막대가 넘어지지 않고 똑바로 서 있게 만드는 것입니다.

이 환경의 경우 다음이 적용됩니다.

  • 위쪽으로 똑바로 균형이 잡혀 있을 때의 막대 위치는 0라디안이고, 아래쪽으로 매달려 있을 때의 위치는 pi라디안입니다.

  • 막대의 초기 각이 –0.05라디안과 0.05라디안 사이이고 위쪽을 향해 있을 때 시작합니다.

  • 에이전트에서 환경으로 전달되는 힘 행동 신호는 –10N 또는 10N입니다.

  • 환경에서 관측하는 값은 카트의 위치와 속도, 막대 각, 막대 각 도함수입니다.

  • 막대가 수직에서 12도 이상 기울거나 카트가 원래 위치에서 2.4m 이상 이동하면 에피소드가 종료됩니다.

  • 막대가 위쪽을 향해 바로 서 있는 매 시간 스텝마다 보상 +1이 주어집니다. 막대가 넘어지면 벌점 –5가 적용됩니다.

이 모델에 대한 자세한 설명은 Load Predefined Control System Environments 항목을 참조하십시오.

환경 인터페이스 만들기

시스템에 대해 미리 정의된 환경 인터페이스를 만듭니다.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

인터페이스에는 에이전트가 힘 값 –10N 또는 10N 중 하나를 카트에 적용할 수 있는 이산 행동 공간이 있습니다.

관측값 및 행동 사양 정보를 가져옵니다.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

재현이 가능하도록 난수 생성기 시드값을 고정합니다.

rng(0)

DQN 에이전트 만들기

DQN 에이전트는 벡터 Q-값 함수 크리틱을 사용할 수 있는데, 일반적으로 이 방식이 비교 가능한 단일 출력 크리틱보다 더 효율적입니다. 벡터 Q-값 함수 크리틱은 입력값으로 관측값을 갖고 출력값으로 상태-행동 값을 갖습니다. 각 출력 요소는 관측값 입력값이 나타내는 상태로부터 상응하는 이산 행동을 취했을 때 기대되는 누적 장기 보상을 표현합니다. 가치 함수를 만드는 방법에 대한 자세한 내용은 Create Policies and Value Functions 항목을 참조하십시오.

크리틱 내에서 Q-값 함수를 근사하려면 1개의 입력 채널(4차원 관측 상태 벡터)과 2개의 요소(10N 행동 요소 및 –10N 행동 요소)로 구성된 1개의 출력 채널을 갖는 신경망을 사용하십시오. 신경망을 layer 객체로 구성된 배열로 정의하고, 환경 사양 객체에서 관측값 공간 차원과 가능한 행동 개수를 가져옵니다.

net = [
    featureInputLayer(obsInfo.Dimension(1))
    fullyConnectedLayer(20)
    reluLayer
    fullyConnectedLayer(length(actInfo.Elements))
    ];

dlnetwork로 변환하고 가중치 개수를 표시합니다.

net = dlnetwork(net);
summary(net)
   Initialized: true

   Number of learnables: 142

   Inputs:
      1   'input'   4 features

신경망 구성을 확인합니다.

plot(net)

net와 환경 사양을 사용하여 크리틱 근사기를 만듭니다. 자세한 내용은 rlVectorQValueFunction 항목을 참조하십시오.

critic = rlVectorQValueFunction(net,obsInfo,actInfo);

임의의 관측값 입력값을 사용하여 크리틱을 확인합니다.

getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector

   -0.2257
    0.4299

critic을 사용하여 DQN 에이전트를 만듭니다. 자세한 내용은 rlDQNAgent 항목을 참조하십시오.

agent = rlDQNAgent(critic);

임의의 관측값 입력값을 사용하여 에이전트를 확인합니다.

getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
    {[10]}

크리틱에 대한 훈련 옵션을 포함하여 DQN 에이전트 옵션을 지정합니다. 또는 rlDQNAgentOptions 객체와 rlOptimizerOptions 객체를 사용할 수 있습니다.

agent.AgentOptions.UseDoubleDQN = false;
agent.AgentOptions.TargetSmoothFactor = 1;
agent.AgentOptions.TargetUpdateFrequency = 4;
agent.AgentOptions.ExperienceBufferLength = 1e5;
agent.AgentOptions.MiniBatchSize = 256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;

에이전트 훈련시키기

에이전트를 훈련시키려면 먼저 훈련 옵션을 지정하십시오. 이 예제에서는 다음 옵션을 사용합니다.

  • 최대 1000개의 에피소드가 포함된 1개의 훈련 세션을 실행하며, 각 에피소드마다 최대 500개의 시간 스텝이 지속됩니다.

  • 강화 학습 훈련 모니터 대화 상자에 훈련 진행 상황을 표시하고(Plots 옵션 설정) 명령줄 표시를 비활성화합니다(Verbose 옵션을 false로 설정).

  • 에이전트가 받은 이동평균 누적 보상이 480보다 클 때 훈련을 중지합니다. 이 시점에서 에이전트는 똑바로 서 있는 위치에서 카트-폴 시스템의 균형을 유지할 수 있습니다.

자세한 내용은 rlTrainingOptions 항목을 참조하십시오.

trainOpts = rlTrainingOptions(...
    MaxEpisodes=1000, ...
    MaxStepsPerEpisode=500, ...
    Verbose=false, ...
    Plots="training-progress",...
    StopTrainingCriteria="AverageReward",...
    StopTrainingValue=480); 

훈련이나 시뮬레이션 중에 plot 함수를 사용하여 카트-폴 시스템을 시각화할 수 있습니다.

plot(env)

train 함수를 사용하여 에이전트를 훈련시킵니다. 이 에이전트를 훈련시키는 것은 완료하는 데 수 분이 소요되는 계산 집약적인 절차입니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load("MATLABCartpoleDQNMulti.mat","agent")
end

DQN 에이전트 시뮬레이션하기

훈련된 에이전트의 성능을 검증하려면 카트-폴 환경 내에서 에이전트를 시뮬레이션하십시오. 에이전트 시뮬레이션에 대한 자세한 내용은 rlSimulationOptions 항목과 sim 항목을 참조하십시오. 에이전트는 시뮬레이션 시간이 500개의 스텝으로 늘어날 경우에도 카트-폴의 균형을 유지할 수 있습니다.

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

totalReward = sum(experience.Reward)
totalReward = 500

참고 항목

함수

객체

관련 예제

세부 정보