Main Content

카트-폴 시스템의 균형을 유지하도록 DQN 에이전트 훈련시키기

이 예제에서는 MATLAB®에서 모델링된 카트-폴 시스템의 균형을 유지하도록 DQN(심층 Q-러닝 신경망) 에이전트를 훈련시키는 방법을 보여줍니다.

DQN 에이전트에 대한 자세한 내용은 심층 Q-신경망 에이전트 항목을 참조하십시오. Simulink®에서 DQN 에이전트를 훈련시키는 예제는 Train DQN Agent to Swing Up and Balance Pendulum 항목을 참조하십시오.

카트-폴 MATLAB 환경

이 예제의 강화 학습 환경은 카트의 비구동 관절에 붙어 있는 막대로, 카트는 마찰이 없는 트랙을 따라 움직입니다. 훈련 목표는 이 막대가 넘어지지 않고 똑바로 서 있게 만드는 것입니다.

이 환경의 경우 다음이 적용됩니다.

  • 위쪽으로 똑바로 균형이 잡혀 있을 때의 막대 위치는 0라디안이고, 아래쪽으로 매달려 있을 때의 위치는 pi라디안입니다.

  • 막대의 초기 각이 –0.05라디안과 0.05라디안 사이이고 위쪽을 향해 있을 때 시작합니다.

  • 에이전트에서 환경으로 전달되는 힘 행동 신호는 –10N에서 10N까지입니다.

  • 환경에서 관측하는 값은 카트의 위치와 속도, 막대 각, 막대 각 도함수입니다.

  • 막대가 수직에서 12도 이상 기울거나 카트가 원래 위치에서 2.4m 이상 이동하면 에피소드가 종료됩니다.

  • 막대가 위쪽을 향해 바로 서 있는 매 시간 스텝마다 보상 +1점이 주어집니다. 막대가 넘어지면 벌점 –5점이 적용됩니다.

이 모델에 대한 자세한 설명은 Load Predefined Control System Environments 항목을 참조하십시오.

환경 인터페이스 만들기

시스템에 대해 미리 정의된 환경 인터페이스를 만듭니다.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

인터페이스에는 에이전트가 힘 값 –10N 또는 10N 중 하나를 카트에 적용할 수 있는 이산 행동 공간이 있습니다.

관측값 및 행동 사양 정보를 가져옵니다.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

재현이 가능하도록 난수 생성기 시드값을 고정합니다.

rng(0)

DQN 에이전트 만들기

DQN 에이전트는 관측값과 행동이 주어지면 가치 함수 크리틱을 사용하여 장기 보상을 근사합니다.

DQN 에이전트는 다중 출력 Q-값 크리틱 근사기를 사용할 수 있는데, 일반적으로 이 방식이 더 효율적입니다. 다중 출력 근사기는 입력값으로 관측값을 갖고 출력값으로 상태-행동 값을 갖습니다. 각 출력 요소는 관측값 입력값이 나타내는 상태로부터 상응하는 이산 행동을 취했을 때 예상되는 누적 장기 보상을 표현합니다.

크리틱을 만들려면 먼저 1개의 입력값(4차원 관측 상태)과 2개의 요소(10N 행동 요소 및 –10N 행동 요소)로 구성된 1개의 출력 벡터를 갖는 심층 신경망을 만드십시오. 신경망을 기반으로 가치 함수 표현을 만드는 방법에 대한 자세한 내용은 Create Policies and Value Functions 항목을 참조하십시오.

dnn = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24, 'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];
dnn = dlnetwork(dnn);

신경망 구성을 확인합니다.

figure
plot(layerGraph(dnn))

Figure contains an axes object. The axes object contains an object of type graphplot.

rlOptimizerOptions를 사용하여 크리틱 최적화 함수에 대한 훈련 옵션을 지정합니다.

criticOpts = rlOptimizerOptions('LearnRate',0.001,'GradientThreshold',1);

지정된 신경망과 옵션을 사용하여 크리틱 표현을 만듭니다. 자세한 내용은 rlVectorQValueFunction 항목을 참조하십시오.

critic = rlVectorQValueFunction(dnn,obsInfo,actInfo);

DQN 에이전트를 만들려면 먼저 rlDQNAgentOptions 객체를 사용하여 DQN 에이전트 옵션을 지정하십시오.

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',100000, ...
    'CriticOptimizerOptions',criticOpts, ...
    'MiniBatchSize',256);

그런 다음, 지정된 크리틱 표현과 에이전트 옵션을 사용하여 DQN 에이전트를 만듭니다. 자세한 내용은 rlDQNAgent 항목을 참조하십시오.

agent = rlDQNAgent(critic,agentOpts);

에이전트 훈련시키기

에이전트를 훈련시키려면 먼저 훈련 옵션을 지정하십시오. 이 예제에서는 다음 옵션을 사용합니다.

  • 최대 1000개의 에피소드가 포함된 1개의 훈련 세션을 실행하며, 각 에피소드마다 최대 500개의 시간 스텝이 지속됩니다.

  • 에피소드 관리자 대화 상자에 훈련 진행 상황을 표시하고(Plots 옵션 설정) 명령줄 표시를 비활성화합니다(Verbose 옵션을 false로 설정).

  • 에이전트가 받은 이동평균 누적 보상이 480점 이상일 때 훈련을 중지합니다. 이 시점에서 에이전트는 똑바로 서 있는 위치에서 카트-폴 시스템의 균형을 유지할 수 있습니다.

자세한 내용은 rlTrainingOptions 항목을 참조하십시오.

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480); 

훈련이나 시뮬레이션 중에 plot 함수를 사용하여 카트-폴 시스템을 시각화할 수 있습니다.

plot(env)

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

train 함수를 사용하여 에이전트를 훈련시킵니다. 이 에이전트를 훈련시키는 것은 완료하는 데 수 분이 소요되는 계산 집약적인 절차입니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('MATLABCartpoleDQNMulti.mat','agent')
end

DQN 에이전트 시뮬레이션하기

훈련된 에이전트의 성능을 검증하려면 카트-폴 환경 내에서 에이전트를 시뮬레이션하십시오. 에이전트 시뮬레이션에 대한 자세한 내용은 rlSimulationOptions 항목과 sim 항목을 참조하십시오. 에이전트는 시뮬레이션 시간이 500개의 스텝으로 늘어날 경우에도 카트-폴의 균형을 유지할 수 있습니다.

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

totalReward = sum(experience.Reward)
totalReward = 500

참고 항목

관련 항목