주요 콘텐츠

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

기본 그리드 월드에서 강화 학습 에이전트 훈련시키기

이 예제에서는 Q-러닝 에이전트와 SARSA 에이전트를 훈련시켜서 강화 학습으로 그리드 월드 환경을 푸는 방법을 보여줍니다. 해당 에이전트에 대한 자세한 내용은 Q-러닝 에이전트 항목과 SARSA 에이전트 항목을 참조하십시오.

재현이 가능하도록 난수 스트림 고정하기

예제 코드의 다양한 단계에서 난수 계산이 포함될 수 있습니다. 예제 코드에 있는 다양한 섹션의 시작 부분에서 난수 스트림을 고정하면 매 실행 시에 섹션의 난수열이 유지되며 결과를 재현할 가능성이 높아집니다. 자세한 내용은 결과 재현성 항목을 참조하십시오.

시드값 0과 난수 알고리즘인 메르센 트위스터를 사용하여 난수 스트림을 고정합니다. 난수 생성에 사용되는 시드값을 제어하는 방법에 대한 자세한 내용은 rng 항목을 참조하십시오.

previousRngState = rng(0,"twister");

출력값 previousRngState는 스트림의 이전 상태에 대한 정보를 포함하는 구조체입니다. 이 예제의 끝부분에서 그 상태를 복원할 것입니다.

이 그리드 월드 환경은 다음과 같은 구성과 규칙을 갖습니다.

  1. 그리드 월드는 5×5 크기로, 테두리로 경계가 지어져 있으며 4가지 가능한 행동(북쪽 = 1, 남쪽 = 2, 동쪽 = 3, 서쪽 = 4)을 가집니다.

  2. 에이전트는 셀 [2,1](두 번째 행, 첫 번째 열)부터 시작합니다.

  3. 에이전트가 셀 [5,5](파란색)의 종료 상태에 도달하면 보상으로 +10을 받습니다.

  4. 환경에는 셀 [2,4]에서 셀 [4,4]로 가는 특별한 점프가 포함되어 있으며, 보상으로 +5가 주어집니다.

  5. 에이전트는 장애물(검은색 셀)에 의해 막힐 수 있습니다.

  6. 다른 모든 행동에서는 보상으로 –1이 주어집니다.

그리드 월드 환경 만들기

기본 그리드 월드 환경을 만듭니다.

env = rlPredefinedEnv("BasicGridWorld");

에이전트의 초기 상태가 항상 [2,1]이 되도록 지정하기 위해 초기 에이전트 상태의 상태 번호를 반환하는 재설정 함수를 만듭니다. 이 함수는 매 훈련 에피소드 및 시뮬레이션이 시작될 때 호출됩니다. 상태에는 위치 [1,1]부터 시작해서 번호가 매겨집니다. 상태 번호는 먼저 첫 번째 열에서 아래로 이동하고, 그 다음 열에서 아래로 이동하는 방향으로 증가합니다. 따라서, 초기 상태를 2로 설정하는 익명 함수 핸들을 만듭니다.

env.ResetFcn = @() 2;

Q-러닝 에이전트 만들기

Q-러닝 에이전트를 만들기 위해, 먼저 그리드 월드 환경의 관측값과 행동 사양을 사용하여 Q 테이블을 만듭니다. 최적화 함수의 학습률을 0.01로 설정합니다.

qTable = rlTable(getObservationInfo(env), ...
    getActionInfo(env));

에이전트 내에서 Q-값 함수를 근사하기 위해, 테이블과 환경 정보를 사용하여 rlQValueFunction 근사기 객체를 만듭니다.

qFcnAppx = rlQValueFunction(qTable, ...
    getObservationInfo(env), ...
    getActionInfo(env));

그런 다음 Q-값 함수를 사용하여 Q-러닝 에이전트를 만듭니다.

qAgent = rlQAgent(qFcnAppx);

엡실론-그리디 탐색 및 함수 근사기의 학습률 같은 에이전트 옵션을 구성합니다.

qAgent.AgentOptions.EpsilonGreedyExploration.Epsilon = .04;
qAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 0.01;

Q-러닝 에이전트를 만드는 방법에 대한 자세한 내용은 rlQAgent 항목 및 rlQAgentOptions 항목을 참조하십시오.

Q-러닝 에이전트 훈련시키기

에이전트를 훈련시키려면 먼저 훈련 옵션을 지정하십시오. 이 예제에서는 다음 옵션을 사용합니다.

  • 최대 200개의 에피소드에 대해 훈련시킵니다. 각 에피소드가 최대 50개의 시간 스텝만큼 지속되도록 지정합니다.

  • 연속 30개의 에피소드 동안 에이전트가 받은 평균 누적 보상이 11이면 훈련을 중지합니다.

훈련 옵션에 대한 자세한 내용은 rlTrainingOptions 항목을 참조하십시오.

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;

재현이 가능하도록 난수 스트림을 고정합니다.

rng(0,"twister");

train 함수를 사용하여 Q-러닝 에이전트를 훈련시킵니다. 훈련을 완료하는 데 몇 분 정도 걸릴 수 있습니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;
if doTraining
    % Train the agent.
    qTrainingStats = train(qAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load("basicGWQAgent.mat","qAgent")
end

강화 학습 훈련 모니터 창이 열리고 훈련 진행 상황이 표시됩니다.

Q-러닝 결과 검증하기

재현이 가능하도록 난수 스트림을 고정합니다.

rng(0,"twister");

훈련 결과를 검증하기 위해 훈련 환경에서 에이전트를 시뮬레이션합니다.

시뮬레이션을 실행하기 전에 환경을 시각화하고 에이전트 상태 추적을 유지하도록 시각화를 구성합니다.

plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;

sim 함수를 사용하여 환경에서 에이전트를 시뮬레이션합니다.

sim(qAgent,env)

Figure contains an axes object. The hidden axes object contains 14 objects of type line, patch.

에이전트 추적에서 에이전트가 셀 [2,4]에서 셀 [4,4]로 성공적으로 점프하는 것으로 나타납니다.

SARSA 에이전트를 만들고 훈련시키기

SARSA 에이전트를 만들기 위해, Q-러닝 에이전트를 만들었을 때와 동일하게 Q 값 함수와 엡실론-그리디 구성을 사용합니다. SARSA 에이전트를 만드는 방법에 대한 자세한 내용은 rlSARSAAgent 항목과 rlSARSAAgentOptions 항목을 참조하십시오.

sarsaAgent = rlSARSAAgent(qFcnAppx);
sarsaAgent.AgentOptions.EpsilonGreedyExploration.Epsilon = .04;
sarsaAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 0.01;

train 함수를 사용하여 SARSA 에이전트를 훈련시킵니다. 훈련을 완료하는 데 몇 분 정도 걸릴 수 있습니다. 이 예제를 실행하는 동안 시간을 절약하려면 doTrainingfalse로 설정하여 사전 훈련된 에이전트를 불러오십시오. 에이전트를 직접 훈련시키려면 doTrainingtrue로 설정하십시오.

doTraining = false;
if doTraining
    % Train the agent.
    sarsaTrainingStats = train(sarsaAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load("basicGWSarsaAgent.mat","sarsaAgent")
end

SARSA 훈련 검증하기

재현이 가능하도록 난수 스트림을 고정합니다.

rng(0,"twister");

훈련 결과를 검증하기 위해 훈련 환경에서 에이전트를 시뮬레이션합니다.

plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;

환경에서 에이전트를 시뮬레이션합니다.

sim(sarsaAgent,env)

Figure contains an axes object. The hidden axes object contains 21 objects of type line, patch.

SARSA 에이전트는 Q-러닝 에이전트와 동일한 그리드 월드 해를 구합니다.

previousRngState에 저장된 정보를 사용하여 난수 스트림을 복원합니다.

rng(previousRngState);

참고 항목

함수

객체

도움말 항목