Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

기본 그리드 월드에서 강화 학습 에이전트 훈련시키기

이 예제에서는 Q-러닝 에이전트와 SARSA 에이전트를 훈련시켜서 강화 학습으로 그리드 월드 환경을 푸는 방법을 보여줍니다. 해당 에이전트에 대한 자세한 내용은 Q-러닝 에이전트 항목과 SARSA 에이전트 항목을 참조하십시오.

이 그리드 월드 환경은 다음과 같은 구성과 규칙을 갖습니다.

  1. 그리드 월드는 5×5 크기로, 테두리로 경계가 지어져 있으며 4가지 가능한 행동(북쪽 = 1, 남쪽 = 2, 동쪽 = 3, 서쪽 = 4)을 가집니다.

  2. 에이전트는 셀 [2,1](두 번째 행, 첫 번째 열)부터 시작합니다.

  3. 에이전트가 셀 [5,5](파란색)의 종료 상태에 도달하면 보상으로 +10을 받습니다.

  4. 환경에는 셀 [2,4]에서 셀 [4,4]로 가는 특별한 점프가 포함되어 있으며, 보상으로 +5가 주어집니다.

  5. 에이전트는 장애물(검은색 셀)에 의해 막힐 수 있습니다.

  6. 다른 모든 행동에서는 보상으로 –1이 주어집니다.

그리드 월드 환경 만들기

기본 그리드 월드 환경을 만듭니다.

env = rlPredefinedEnv("BasicGridWorld");

에이전트의 초기 상태가 항상 [2,1]이 되도록 지정하기 위해 초기 에이전트 상태의 상태 번호를 반환하는 재설정 함수를 만듭니다. 이 함수는 매 훈련 에피소드 및 시뮬레이션이 시작될 때 호출됩니다. 상태에는 위치 [1,1]부터 시작해서 번호가 매겨집니다. 상태 번호는 먼저 첫 번째 열에서 아래로 이동하고, 그 다음 열에서 아래로 이동하는 방향으로 증가합니다. 따라서, 초기 상태를 2로 설정하는 익명 함수 핸들을 만듭니다.

env.ResetFcn = @() 2;

재현이 가능하도록 난수 생성기 시드값을 고정합니다.

rng(0)

Q-러닝 에이전트 만들기

Q-러닝 에이전트를 만들기 위해, 먼저 그리드 월드 환경의 관측값과 행동 사양을 사용하여 Q 테이블을 만듭니다. 최적화 함수의 학습률을 0.01로 설정합니다.

qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qFunction = rlQValueFunction(qTable,getObservationInfo(env),getActionInfo(env));
qOptions = rlOptimizerOptions("LearnRate",0.01);

그 다음, Q 값 함수를 사용하여 Q-러닝 에이전트를 만들고 엡실론-그리디 탐색을 구성합니다. Q-러닝 에이전트를 만드는 방법에 대한 자세한 내용은 rlQAgent 항목 및 rlQAgentOptions 항목을 참조하십시오.

agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
agentOpts.CriticOptimizerOptions = qOptions;
qAgent = rlQAgent(qFunction,agentOpts);

Q-러닝 에이전트 훈련시키기

에이전트를 훈련시키려면 먼저 훈련 옵션을 지정하십시오. 이 예제에서는 다음 옵션을 사용합니다.

  • 최대 200개의 에피소드에 대해 훈련시킵니다. 각 에피소드가 최대 50개의 시간 스텝만큼 지속되도록 지정합니다.

  • 연속 30개의 에피소드 동안 에이전트가 받은 평균 누적 보상이 10보다 크면 훈련을 중지합니다.

자세한 내용은 rlTrainingOptions 항목을 참조하십시오.

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;

train 함수를 사용하여 Q-러닝 에이전트를 훈련시킵니다. 훈련을 완료하는 데 몇 분 정도 걸릴 수 있습니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(qAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWQAgent.mat','qAgent')
end

에피소드 관리자 창이 열리고 훈련 진행 상황이 표시됩니다.

Q-러닝 결과 검증하기

훈련 결과를 검증하기 위해 훈련 환경에서 에이전트를 시뮬레이션합니다.

시뮬레이션을 실행하기 전에 환경을 시각화하고 에이전트 상태 추적을 유지하도록 시각화를 구성합니다.

plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;

sim 함수를 사용하여 환경에서 에이전트를 시뮬레이션합니다.

sim(qAgent,env)

에이전트 추적에서 에이전트가 셀 [2,4]에서 셀 [4,4]로 성공적으로 점프하는 것으로 나타납니다.

SARSA 에이전트를 만들고 훈련시키기

SARSA 에이전트를 만들기 위해, Q-러닝 에이전트를 만들었을 때와 동일하게 Q 값 함수와 엡실론-그리디 구성을 사용합니다. SARSA 에이전트를 만드는 방법에 대한 자세한 내용은 rlSARSAAgent 항목과 rlSARSAAgentOptions 항목을 참조하십시오.

agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
agentOpts.CriticOptimizerOptions = qOptions;
sarsaAgent = rlSARSAAgent(qFunction,agentOpts);

train 함수를 사용하여 SARSA 에이전트를 훈련시킵니다. 훈련을 완료하는 데 몇 분 정도 걸릴 수 있습니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(sarsaAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWSarsaAgent.mat','sarsaAgent')
end

SARSA 훈련 검증하기

훈련 결과를 검증하기 위해 훈련 환경에서 에이전트를 시뮬레이션합니다.

plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;

환경에서 에이전트를 시뮬레이션합니다.

sim(sarsaAgent,env)

SARSA 에이전트는 Q-러닝 에이전트와 동일한 그리드 월드 해를 구합니다.

참고 항목

|

관련 항목