Main Content

사용자 지정 함수를 사용하여 MATLAB 환경 만들기

이 예제에서는 MATLAB®에서 사용자 지정 동적 함수를 제공하여 카트-폴 환경을 만드는 방법을 보여줍니다.

rlFunctionEnv 함수를 사용하여 관측값 사양, 행동 사양, 사용자 정의 step 함수 및 reset 함수로부터 MATLAB 강화 학습 환경을 만들 수 있습니다. 그런 다음 이 환경에서 강화 학습 에이전트를 훈련시킬 수 있습니다. 이 예제에 필요한 step 함수와 reset 함수는 이미 정의되어 있습니다.

사용자 지정 함수를 사용하여 환경을 만드는 것은 동특성이 덜 복잡한 환경, 특수한 시각화 요건이 필요 없는 환경 또는 타사 라이브러리와의 인터페이스가 갖춰진 환경에 유용합니다. 보다 복잡한 환경의 경우 템플릿 클래스를 사용하여 환경 객체를 만들 수 있습니다. 자세한 내용은 Create Custom MATLAB Environment from Template 항목을 참조하십시오.

강화 학습 환경을 만드는 방법에 대한 자세한 내용은 MATLAB 강화 학습 환경 만들기 항목과 Simulink 강화 학습 환경 만들기 항목을 참조하십시오.

카트-폴 MATLAB 환경

카트-폴 환경은 카트의 비구동 관절에 붙어 있는 막대로, 카트는 마찰이 없는 트랙을 따라 움직입니다. 훈련 목표는 이 진자가 넘어지지 않고 똑바로 서 있게 만드는 것입니다.

이 환경의 경우 다음이 적용됩니다.

  • 위쪽으로 똑바로 균형이 잡혀 있을 때의 진자 위치는 0라디안이고, 아래쪽으로 매달려 있을 때의 위치는 pi라디안입니다.

  • 진자의 초기 각이 –0.05도와 0.05도 사이이고 위쪽을 향해 있을 때 시작합니다.

  • 에이전트에서 환경으로 전달되는 힘 행동 신호는 –10N에서 10N까지입니다.

  • 환경에서 관측하는 값은 카트 위치, 카트 속도, 진자 각, 진자 각 도함수입니다.

  • 막대가 수직에서 12도 이상 기울거나 카트가 원래 위치에서 2.4m 이상 이동하면 에피소드가 종료됩니다.

  • 막대가 위쪽을 향해 바로 서 있는 매 시간 스텝마다 보상 +1점이 주어집니다. 진자가 넘어지면 벌점 –10점이 적용됩니다.

이 모델에 대한 자세한 설명은 Load Predefined Control System Environments 항목을 참조하십시오.

관측값 및 행동 사양

환경에서 관측하는 값은 카트 위치, 카트 속도, 진자 각, 진자 각 도함수입니다.

ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ObservationInfo.Description = 'x, dx, theta, dtheta';

환경에는 에이전트가 힘 값 -10N 또는 10N 중 하나를 카트에 적용할 수 있는 이산 행동 공간이 있습니다.

ActionInfo = rlFiniteSetSpec([-10 10]);
ActionInfo.Name = 'CartPole Action';

환경 행동과 관측값을 지정하는 방법에 대한 자세한 내용은 rlNumericSpec 항목과 rlFiniteSetSpec 항목을 참조하십시오.

함수 이름을 사용하여 환경 만들기

사용자 지정 환경을 정의하려면 먼저 사용자 지정 step 함수와 reset 함수를 지정하십시오. 이러한 함수는 현재 작업 폴더 또는 MATLAB 경로에 위치해야 합니다.

사용자 지정 reset 함수는 환경의 디폴트 상태를 설정합니다. 이 함수의 시그니처는 다음과 같아야 합니다.

[InitialObservation,LoggedSignals] = myResetFunction()

한 스텝에서 다음 스텝으로 환경 상태와 같은 정보를 전달하려면 LoggedSignals를 사용하십시오. 이 예제에서 LoggedSignals는 카트의 위치 및 속도, 진자 각, 진자 각 도함수 등 카트-폴 환경의 상태를 포함합니다. reset 함수는 환경이 재설정될 때까지 카트 각을 난수 값으로 설정합니다.

이 예제에서는 myResetFunction.m에서 정의한 사용자 지정 재설정 함수를 사용합니다.

type myResetFunction.m
function [InitialObservation, LoggedSignal] = myResetFunction()
% Reset function to place custom cart-pole environment into a random
% initial state.

% Theta (randomize)
T0 = 2 * 0.05 * rand() - 0.05;
% Thetadot
Td0 = 0;
% X
X0 = 0;
% Xdot
Xd0 = 0;

% Return initial environment state variables as logged signals.
LoggedSignal.State = [X0;Xd0;T0;Td0];
InitialObservation = LoggedSignal.State;

end

사용자 지정 step 함수는 환경이 주어진 행동에 기반하여 다음 상태로 진행하는 방법을 지정합니다. 이 함수의 시그니처는 다음과 같아야 합니다.

[Observation,Reward,IsDone,LoggedSignals] = myStepFunction(Action,LoggedSignals)

새 상태를 얻기 위해 환경은 LoggedSignals에 저장된 현재 상태에 동역학 방정식을 적용하며, 이는 미분방정식에 초기 조건을 제공하는 것과 유사합니다. 새 상태는 LoggedSignals에 저장되고 출력값으로 반환됩니다.

이 예제에서는 myStepFunction.m에 정의된 사용자 지정 스텝 함수를 사용합니다. 간단한 구현을 위해 이 함수는 step이 실행될 때마다 카트 질량과 같은 물리 상수를 재정의합니다.

type myStepFunction.m
function [NextObs,Reward,IsDone,LoggedSignals] = myStepFunction(Action,LoggedSignals)
% Custom step function to construct cart-pole environment for the function
% name case.
%
% This function applies the given action to the environment and evaluates
% the system dynamics for one simulation step.

% Define the environment constants.

% Acceleration due to gravity in m/s^2
Gravity = 9.8;
% Mass of the cart
CartMass = 1.0;
% Mass of the pole
PoleMass = 0.1;
% Half the length of the pole
HalfPoleLength = 0.5;
% Max force the input can apply
MaxForce = 10;
% Sample time
Ts = 0.02;
% Pole angle at which to fail the episode
AngleThreshold = 12 * pi/180;
% Cart distance at which to fail the episode
DisplacementThreshold = 2.4;
% Reward each time step the cart-pole is balanced
RewardForNotFalling = 1;
% Penalty when the cart-pole fails to balance
PenaltyForFalling = -10;

% Check if the given action is valid.
if ~ismember(Action,[-MaxForce MaxForce])
    error('Action must be %g for going left and %g for going right.',...
        -MaxForce,MaxForce);
end
Force = Action;

% Unpack the state vector from the logged signals.
State = LoggedSignals.State;
XDot = State(2);
Theta = State(3);
ThetaDot = State(4);

% Cache to avoid recomputation.
CosTheta = cos(Theta);
SinTheta = sin(Theta);
SystemMass = CartMass + PoleMass;
temp = (Force + PoleMass*HalfPoleLength*ThetaDot*ThetaDot*SinTheta)/SystemMass;

% Apply motion equations.
ThetaDotDot = (Gravity*SinTheta - CosTheta*temp) / ...
    (HalfPoleLength*(4.0/3.0 - PoleMass*CosTheta*CosTheta/SystemMass));
XDotDot  = temp - PoleMass*HalfPoleLength*ThetaDotDot*CosTheta/SystemMass;

% Perform Euler integration.
LoggedSignals.State = State + Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

% Transform state to observation.
NextObs = LoggedSignals.State;

% Check terminal condition.
X = NextObs(1);
Theta = NextObs(3);
IsDone = abs(X) > DisplacementThreshold || abs(Theta) > AngleThreshold;

% Get reward.
if ~IsDone
    Reward = RewardForNotFalling;
else
    Reward = PenaltyForFalling;
end

end

정의된 관측값 사양, 행동 사양, 함수 이름을 사용하여 사용자 지정 환경을 구성합니다.

env = rlFunctionEnv(ObservationInfo,ActionInfo,'myStepFunction','myResetFunction');

사용자 환경의 동작을 검증하기 위해 rlFunctionEnv는 환경을 만든 후 자동으로 validateEnvironment를 호출합니다.

함수 핸들을 사용하여 환경 만들기

최소한으로 필요한 세트 외에 추가적인 입력 인수를 가지는 사용자 함수를 정의할 수도 있습니다. 예를 들어, 추가 인수 arg1arg2를 스텝 함수와 재설정 함수에 모두 전달하려면 다음 코드를 사용하십시오.

[InitialObservation,LoggedSignals] = myResetFunction(arg1,arg2)
[Observation,Reward,IsDone,LoggedSignals] = myStepFunction(Action,LoggedSignals,arg1,arg2)

rlFunctionEnv에서 이 함수를 사용하려면 익명 함수 핸들을 사용해야 합니다.

ResetHandle = @()myResetFunction(arg1,arg2);
StepHandle = @(Action,LoggedSignals) myStepFunction(Action,LoggedSignals,arg1,arg2);

자세한 내용은 익명 함수 항목을 참조하십시오.

추가 입력 인수를 사용하면 보다 효율적인 환경을 구현할 수 있습니다. 예를 들어, myStepFunction2.m은 환경 상수를 입력 인수(envConstants)로 가져오는 사용자 지정 step 함수를 포함합니다. 그렇게 함으로써 이 함수는 매 스텝마다 환경 상수를 재정의하지 않아도 됩니다.

type myStepFunction2.m
function [NextObs,Reward,IsDone,LoggedSignals] = myStepFunction2(Action,LoggedSignals,EnvConstants)
% Custom step function to construct cart-pole environment for the function
% handle case.
%
% This function applies the given action to the environment and evaluates
% the system dynamics for one simulation step.

% Check if the given action is valid.
if ~ismember(Action,[-EnvConstants.MaxForce EnvConstants.MaxForce])
    error('Action must be %g for going left and %g for going right.',...
        -EnvConstants.MaxForce,EnvConstants.MaxForce);
end
Force = Action;

% Unpack the state vector from the logged signals.
State = LoggedSignals.State;
XDot = State(2);
Theta = State(3);
ThetaDot = State(4);

% Cache to avoid recomputation.
CosTheta = cos(Theta);
SinTheta = sin(Theta);
SystemMass = EnvConstants.MassCart + EnvConstants.MassPole;
temp = (Force + EnvConstants.MassPole*EnvConstants.Length*ThetaDot*ThetaDot*SinTheta)/SystemMass;

% Apply motion equations.
ThetaDotDot = (EnvConstants.Gravity*SinTheta - CosTheta*temp)...
    / (EnvConstants.Length*(4.0/3.0 - EnvConstants.MassPole*CosTheta*CosTheta/SystemMass));
XDotDot  = temp - EnvConstants.MassPole*EnvConstants.Length*ThetaDotDot*CosTheta/SystemMass;

% Perform Euler integration.
LoggedSignals.State = State + EnvConstants.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

% Transform state to observation.
NextObs = LoggedSignals.State;

% Check terminal condition.
X = NextObs(1);
Theta = NextObs(3);
IsDone = abs(X) > EnvConstants.XThreshold || abs(Theta) > EnvConstants.ThetaThresholdRadians;

% Get reward.
if ~IsDone
    Reward = EnvConstants.RewardForNotFalling;
else
    Reward = EnvConstants.PenaltyForFalling;
end

end

환경 상수가 포함된 구조체를 만듭니다.

% Acceleration due to gravity in m/s^2
envConstants.Gravity = 9.8;
% Mass of the cart
envConstants.MassCart = 1.0;
% Mass of the pole
envConstants.MassPole = 0.1;
% Half the length of the pole
envConstants.Length = 0.5;
% Max force the input can apply
envConstants.MaxForce = 10;
% Sample time
envConstants.Ts = 0.02;
% Angle at which to fail the episode
envConstants.ThetaThresholdRadians = 12 * pi/180;
% Distance at which to fail the episode
envConstants.XThreshold = 2.4;
% Reward each time step the cart-pole is balanced
envConstants.RewardForNotFalling = 1;
% Penalty when the cart-pole fails to balance
envConstants.PenaltyForFalling = -5;

사용자 지정 step 함수에 대한 익명 함수 핸들을 만들어 envConstants를 추가 입력 인수로 전달합니다. envConstantsStepHandle이 만들어질 때마다 사용할 수 있으므로 함수 핸들에는 이러한 값이 포함됩니다. 이러한 값은 변수를 지워도 함수 핸들 내에 유지됩니다.

StepHandle = @(Action,LoggedSignals) myStepFunction2(Action,LoggedSignals,envConstants);

동일한 reset 함수를 사용하여, 이름을 사용하는 대신 함수 핸들로 지정합니다.

ResetHandle = @() myResetFunction;

사용자 지정 함수 핸들을 사용하여 환경을 만듭니다.

env2 = rlFunctionEnv(ObservationInfo,ActionInfo,StepHandle,ResetHandle);

사용자 지정 함수 검증하기

사용자 환경에서 에이전트를 훈련시키기 전에 사용자 지정 함수의 동작을 검증하는 것이 좋습니다. 그렇게 하면 reset 함수를 사용하여 환경을 초기화하고 step 함수를 사용하여 하나의 시뮬레이션 스텝을 실행할 수 있습니다. 재현이 가능하도록, 검증하기 전에 난수 생성기 시드값을 설정합니다.

함수 이름을 사용하여 만든 환경을 검증합니다.

rng(0);
InitialObs = reset(env)
InitialObs = 4×1

         0
         0
    0.0315
         0

[NextObs,Reward,IsDone,LoggedSignals] = step(env,10);
NextObs
NextObs = 4×1

         0
    0.1947
    0.0315
   -0.2826

함수 핸들을 사용하여 만든 환경을 검증합니다.

rng(0);
InitialObs2 = reset(env2)
InitialObs2 = 4×1

         0
         0
    0.0315
         0

[NextObs2,Reward2,IsDone2,LoggedSignals2] = step(env2,10);
NextObs2
NextObs2 = 4×1

         0
    0.1947
    0.0315
   -0.2826

두 환경 모두 성공적으로 초기화 및 시뮬레이션되어 NextObs에서 동일한 상태 값을 생성합니다.

참고 항목

관련 항목