Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

MDP 환경에서 강화 학습 에이전트 훈련시키기

이 예제에서는 일반 마르코프 결정 과정(MDP) 환경을 풀도록 Q-러닝 에이전트를 훈련시키는 방법을 보여줍니다. 이러한 에이전트에 대한 자세한 내용은 Q-러닝 에이전트 항목을 참조하십시오.

이 MDP 환경은 다음과 같은 그래프를 가집니다.

여기서 각 요소는 다음과 같습니다.

  1. 각 원은 상태를 나타냅니다.

  2. 각 상태에는 올라가는(up) 결정이나 내려가는(down) 결정이 있습니다.

  3. 에이전트는 상태 1부터 시작합니다.

  4. 에이전트는 그래프의 각 천이에 해당하는 값과 동일한 보상을 받습니다.

  5. 훈련 목표는 누적 보상을 최대한으로 모으는 것입니다.

MDP 환경 만들기

8개의 상태와 2개의 행동("up" 및 "down")을 갖는 MDP 모델을 만듭니다.

MDP = createMDP(8,["up";"down"]);

위 그래프의 천이를 모델링하기 위해, MDP의 상태 천이 행렬과 보상 행렬을 수정합니다. 기본적으로 이러한 행렬은 0을 포함합니다. MDP 모델 생성과 MDP 객체의 속성에 대한 자세한 내용은 createMDP 항목을 참조하십시오.

MDP의 상태 천이 행렬과 보상 행렬을 지정합니다. 예를 들어, 아래 명령에서 다음과 같이 지정합니다.

  • 첫 두 줄은 행동 1("up")을 취하여 상태 1에서 상태 2로 천이하도록 지정하고 이 천이에 대해 보상 +3점을 지정합니다.

  • 다음 두 줄은 행동 2("down")를 취하여 상태 1에서 상태 3으로 천이하도록 지정하고 이 천이에 대해 보상 +1점을 지정합니다.

MDP.T(1,2,1) = 1;
MDP.R(1,2,1) = 3;
MDP.T(1,3,2) = 1;
MDP.R(1,3,2) = 1;

이와 비슷하게 이 그래프의 나머지 규칙에 대해서도 상태 천이와 보상을 지정합니다.

% State 2 transition and reward
MDP.T(2,4,1) = 1;
MDP.R(2,4,1) = 2;
MDP.T(2,5,2) = 1;
MDP.R(2,5,2) = 1;
% State 3 transition and reward
MDP.T(3,5,1) = 1;
MDP.R(3,5,1) = 2;
MDP.T(3,6,2) = 1;
MDP.R(3,6,2) = 4;
% State 4 transition and reward
MDP.T(4,7,1) = 1;
MDP.R(4,7,1) = 3;
MDP.T(4,8,2) = 1;
MDP.R(4,8,2) = 2;
% State 5 transition and reward
MDP.T(5,7,1) = 1;
MDP.R(5,7,1) = 1;
MDP.T(5,8,2) = 1;
MDP.R(5,8,2) = 9;
% State 6 transition and reward
MDP.T(6,7,1) = 1;
MDP.R(6,7,1) = 5;
MDP.T(6,8,2) = 1;
MDP.R(6,8,2) = 1;
% State 7 transition and reward
MDP.T(7,7,1) = 1;
MDP.R(7,7,1) = 0;
MDP.T(7,7,2) = 1;
MDP.R(7,7,2) = 0;
% State 8 transition and reward
MDP.T(8,8,1) = 1;
MDP.R(8,8,1) = 0;
MDP.T(8,8,2) = 1;
MDP.R(8,8,2) = 0;

상태 "s7""s8"을 MDP의 종료 상태로 지정합니다.

MDP.TerminalStates = ["s7";"s8"];

이 과정 모델에 대해 강화 학습 MDP 환경을 만듭니다.

env = rlMDPEnv(MDP);

에이전트의 초기 상태가 항상 상태 1로 지정되도록 초기 에이전트 상태를 반환하는 재설정 함수를 지정합니다. 이 함수는 매 훈련 에피소드 및 시뮬레이션이 시작될 때 호출됩니다. 초기 상태를 1로 설정하는 익명 함수 핸들을 만듭니다.

env.ResetFcn = @() 1;

재현이 가능하도록 난수 생성기 시드값을 고정합니다.

rng(0)

Q-러닝 에이전트 만들기

Q-러닝 에이전트를 만들기 위해, 먼저 MDP 환경의 관측값과 행동 사양을 사용하여 Q 테이블을 만듭니다. 표현의 학습률을 1로 설정합니다.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
qTable = rlTable(obsInfo, actInfo);
qFunction = rlQValueFunction(qTable, obsInfo, actInfo);
qOptions = rlOptimizerOptions("LearnRate",1);

그 다음, 이 테이블 표현을 사용하여 Q-러닝 에이전트를 만들고 엡실론-그리디 탐색을 구성합니다. Q-러닝 에이전트를 만드는 방법에 대한 자세한 내용은 rlQAgent 항목 및 rlQAgentOptions 항목을 참조하십시오.

agentOpts = rlQAgentOptions;
agentOpts.DiscountFactor = 1;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.9;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.01;
agentOpts.CriticOptimizerOptions = qOptions;
qAgent = rlQAgent(qFunction,agentOpts); %#ok<NASGU> 

Q-러닝 에이전트 훈련시키기

에이전트를 훈련시키려면 먼저 훈련 옵션을 지정하십시오. 이 예제에서는 다음 옵션을 사용합니다.

  • 최대 500개의 에피소드에 대해 훈련시키며, 각 에피소드마다 최대 50개의 시간 스텝이 지속됩니다.

  • 연속 30개의 에피소드 동안 에이전트가 받은 평균 누적 보상이 10보다 크면 훈련을 중지합니다.

자세한 내용은 rlTrainingOptions 항목을 참조하십시오.

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 500;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 13;
trainOpts.ScoreAveragingWindowLength = 30;

train 함수를 사용하여 에이전트를 훈련시킵니다. 완료하는 데 몇 분 정도 걸릴 수 있습니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(qAgent,env,trainOpts); %#ok<UNRCH> 
else
    % Load pretrained agent for the example.
    load('genericMDPQAgent.mat','qAgent'); 
end

Q-러닝 결과 검증하기

훈련 결과를 검증하기 위해 sim 함수를 사용하여 훈련 환경에서 에이전트를 시뮬레이션합니다. 에이전트는 누적 보상이 13인 최적 경로를 성공적으로 찾습니다.

Data = sim(qAgent,env);
cumulativeReward = sum(Data.Reward)
cumulativeReward = 13

감가 인자가 1로 설정되어 있으므로 훈련된 에이전트의 Q 테이블 값은 환경의 감가되지 않은 리턴값과 일치합니다.

QTable = getLearnableParameters(getCritic(qAgent));
QTable{1}
ans = 8×2

   12.9874    7.0759
   -7.6425    9.9990
   10.7193    0.9090
    5.9128   -2.2466
    6.7830    8.9988
    7.5928   -5.5053
         0         0
         0         0

TrueTableValues = [13,12;5,10;11,9;3,2;1,9;5,1;0,0;0,0]
TrueTableValues = 8×2

    13    12
     5    10
    11     9
     3     2
     1     9
     5     1
     0     0
     0     0

참고 항목

|

관련 항목