Main Content

클래스 템플릿에서 사용자 지정 환경 만들기

템플릿 환경 클래스를 만들고 수정하여 사용자 지정 강화 학습 환경을 정의할 수 있습니다. 사용자 지정 템플릿 환경을 사용하여 다음을 수행할 수 있습니다.

  • 보다 복잡한 환경 동특성을 구현합니다.

  • 환경에 사용자 지정 시각화를 추가합니다.

  • C++, Java® 또는 Python® 같은 언어로 정의된 타사 라이브러리에 대한 인터페이스를 생성합니다. 자세한 내용은 외부 언어 인터페이스 항목을 참조하십시오.

MATLAB® 클래스 생성에 대한 자세한 내용은 사용자 정의 클래스 항목을 참조하십시오.

덜 복잡한 사용자 지정 강화 학습 환경을 만들려면 Create Custom Environment Using Step and Reset Functions에 설명된 것처럼 사용자 지정 함수를 사용하면 됩니다.

템플릿 클래스 만들기

사용자 지정 환경을 정의하려면 먼저 클래스 이름을 지정하여 템플릿 클래스 파일을 만듭니다. 이 예제에서는 클래스 이름을 MyEnvironment로 지정합니다.

rlCreateEnvTemplate("MyEnvironment")

함수 rlCreateEnvTemplate은 템플릿 클래스 파일을 생성하고 엽니다. 템플릿 파일의 시작 부분에 명시된 클래스 정의에서 볼 수 있듯이, 템플릿 클래스는 rl.env.MATLABEnvironment 추상 클래스의 서브클래스입니다. 이 추상 클래스는 다른 MATLAB 강화 학습 환경 객체에 사용되는 것과 동일합니다.

classdef MyEnvironment < rl.env.MATLABEnvironment

기본적으로 템플릿 클래스는 Load Predefined Control System Environments에서 설명하는 사전 정의된 카트-폴 환경과 유사한 단순한 카트-폴 균형 유지 모델을 구현합니다.

환경 동특성을 정의하려면 파일을 MyEnvironment.m으로 저장하십시오. 그런 후 다음을 지정하여 템플릿 클래스를 수정합니다.

  • 환경 속성

  • 필수 환경 메서드

  • 선택적 환경 메서드

환경 속성

템플릿의 properties 섹션에서, 환경을 생성하고 시뮬레이션하는 데 필요한 파라미터들을 지정하십시오. 파라미터에는 다음이 포함될 수 있습니다.

  • 물리 상수 — 샘플 환경에서 중력(Gravity)으로 인한 가속을 정의합니다.

  • 환경 기하 — 샘플 환경에서 카트와 막대의 질량(CartMassPoleMass)과 막대의 절반 길이(HalfPoleLength)를 정의합니다.

  • 환경 제약 조건 — 샘플 환경에서 막대 각과 카트 거리의 임계값(AngleThresholdDisplacementThreshold)을 정의합니다. 환경은 이러한 값을 사용하여 훈련 에피소드가 완료되는 시점을 감지합니다.

  • 환경 실행에 필요한 변수 — 샘플 환경에서 상태 벡터(State) 그리고 에피소드가 완료되는 시점을 나타내는 플래그(IsDone)를 정의합니다.

  • 행동 공간 또는 관측값 공간 정의를 위한 상수 — 샘플 환경에서 행동 공간의 최대 힘(MaxForce)을 정의합니다.

  • 보상 신호 계산을 위한 상수 — 샘플 환경에서 상수 RewardForNotFallingPenaltyForFalling을 정의합니다.

properties
    % Specify and initialize the necessary properties of the environment  
    % Acceleration due to gravity in m/s^2
    Gravity = 9.8
    
    % Mass of the cart
    CartMass = 1.0
    
    % Mass of the pole
    PoleMass = 0.1
    
    % Half the length of the pole
    HalfPoleLength = 0.5
    
    % Max force the input can apply
    MaxForce = 10
           
    % Sample time
    Ts = 0.02
    
    % Angle at which to fail the episode (radians)
    AngleThreshold = 12 * pi/180
        
    % Distance at which to fail the episode
    DisplacementThreshold = 2.4
        
    % Reward each time step the cart-pole is balanced
    RewardForNotFalling = 1
    
    % Penalty when the cart-pole fails to balance
    PenaltyForFalling = -10 
end
    
properties
    % Initialize system state [x,dx,theta,dtheta]'
    State = zeros(4,1)
end

properties(Access = protected)
    % Initialize internal flag to indicate episode termination
    IsDone = false        
end

필수 함수

강화 학습 환경을 정의하려면 다음 함수가 필요합니다. getObservationInfo, getActionInfo, simvalidateEnvironment 함수는 기본 추상 클래스에 이미 정의되어 있습니다. 환경을 생성하려면 생성자 함수와 reset 함수 및 step 함수를 정의해야 합니다.

함수설명
getObservationInfo환경 관측값에 대한 정보를 반환
getActionInfo환경 행동에 대한 정보를 반환
sim에이전트를 사용하여 환경 시뮬레이션
validateEnvironmentreset 함수를 호출하고 step을 사용하여 하나의 시간 스텝을 시뮬레이션하여 환경을 검증
reset환경 상태를 초기화하고 모든 시각화를 정리
step행동을 적용하고, 하나의 스텝에 대해 환경을 시뮬레이션하고, 관측값과 보상을 출력. 또한 에피소드의 완료 여부를 나타내는 플래그를 설정
생성자 함수클래스의 인스턴스를 생성하는 클래스와 동일한 이름을 갖는 함수

샘플 생성자 함수

샘플 카트-폴 생성자 함수는 다음 단계를 수행하여 환경을 생성합니다.

  • 행동 사양과 관측값 사양 정의. 이러한 사양 생성에 대한 자세한 내용은 rlNumericSpecrlFiniteSetSpec 항목을 참조하십시오.

  • 기본 추상 클래스의 생성자 호출.

function this = MyEnvironment()
    % Initialize observation settings
    ObservationInfo = rlNumericSpec([4 1]);
    ObservationInfo.Name = 'CartPole States';
    ObservationInfo.Description = 'x, dx, theta, dtheta';

    % Initialize action settings   
    ActionInfo = rlFiniteSetSpec([-1 1]);
    ActionInfo.Name = 'CartPole Action';

    % The following line implements built-in functions of the RL environment
    this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

    % Initialize property values and precompute necessary values
    updateActionInfo(this);
end

이 샘플 생성자 함수는 입력 인수를 포함하지 않습니다. 하지만 사용자 지정 생성자에 대해서는 입력 인수를 추가할 수 있습니다.

샘플 reset 함수

샘플 카트-폴 reset 함수는 모델의 초기 조건을 설정하고 관측값의 초기값을 반환합니다. 또한 envUpdatedCallback 함수를 호출함으로써 환경이 업데이트되었다는 알림을 생성합니다. 이는 환경 시각화 업데이트에 유용합니다.

% Reset environment to initial state and return initial observation
function InitialObservation = reset(this)
    % Theta (+- .05 rad)
    T0 = 2 * 0.05 * rand - 0.05;  
    % Thetadot
    Td0 = 0;
    % X 
    X0 = 0;
    % Xdot
    Xd0 = 0;

    InitialObservation = [X0;Xd0;T0;Td0];
    this.State = InitialObservation;

    % (Optional) Use notifyEnvUpdated to signal that the 
    % environment is updated (for example, to update the visualization)
    notifyEnvUpdated(this);
end

샘플 step 함수

샘플 카트-폴 step 함수는 다음 동작을 수행합니다.

  • 입력 행동을 처리합니다.

  • 하나의 시간 스텝에 대해 환경 동특성 방정식을 실행합니다.

  • 업데이트된 관측값을 계산하고 반환합니다.

  • 보상 신호를 계산하고 반환합니다.

  • 에피소드가 완료되었는지 확인하고 그에 맞게 적절한 IsDone 신호를 반환합니다.

  • 환경이 업데이트되었다는 알림을 생성합니다.

function [Observation,Reward,IsDone,Info] = step(this,Action)
    Info = [];

    % Get action
    Force = getForce(this,Action);            

    % Unpack state vector
    XDot = this.State(2);
    Theta = this.State(3);
    ThetaDot = this.State(4);

    % Cache to avoid recomputation
    CosTheta = cos(Theta);
    SinTheta = sin(Theta);            
    SystemMass = this.CartMass + this.PoleMass;
    temp = (Force + this.PoleMass*this.HalfPoleLength*ThetaDot^2*SinTheta)...
        /SystemMass;

    % Apply motion equations            
    ThetaDotDot = (this.Gravity*SinTheta - CosTheta*temp)...
        / (this.HalfPoleLength*(4.0/3.0 - this.PoleMass*CosTheta*CosTheta/SystemMass));
    XDotDot  = temp - this.PoleMass*this.HalfPoleLength*ThetaDotDot*CosTheta/SystemMass;

    % Euler integration
    Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

    % Update system states
    this.State = Observation;

    % Check terminal condition
    X = Observation(1);
    Theta = Observation(3);
    IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
    this.IsDone = IsDone;

    % Get reward
    Reward = getReward(this);

    % (Optional) Use notifyEnvUpdated to signal that the 
    % environment has been updated (for example, to update the visualization)
    notifyEnvUpdated(this);
end

선택적 함수

필요한 경우 템플릿 클래스에서 그 외 다른 함수를 정의할 수 있습니다. 예를 들어 step 또는 reset으로 호출되는 헬퍼 함수를 생성할 수 있습니다. 카트-폴 템플릿 모델은 각 시간 스텝에서의 보상을 계산하는 getReward 함수를 구현합니다.

function Reward = getReward(this)
    if ~this.IsDone
        Reward = this.RewardForNotFalling;
    else
        Reward = this.PenaltyForFalling;
    end          
end

환경 시각화

plot 함수를 구현하여 사용자 지정 환경에 시각화를 추가할 수 있습니다. plot 함수에서 다음을 수행합니다.

  • 사용자가 직접 구현한 시각화 클래스에 대한 Figure 또는 인스턴스를 생성합니다. 이 예제에서는 Figure를 만들고 환경 객체 내에 Figure에 대한 핸들을 저장합니다.

  • envUpdatedCallback 함수를 호출합니다.

function plot(this)
    % Initiate the visualization
    this.Figure = figure('Visible','on','HandleVisibility','off');
    ha = gca(this.Figure);
    ha.XLimMode = 'manual';
    ha.YLimMode = 'manual';
    ha.XLim = [-3 3];
    ha.YLim = [-1 2];
    hold(ha,'on');
    % Update the visualization
    envUpdatedCallback(this)
end

이 예제에서는 Figure에 대한 핸들을 환경 객체의 보호 속성으로 저장합니다.

properties(Access = protected)
    % Initialize internal flag to indicate episode termination
    IsDone = false 

    % Handle to figure
    Figure
end

envUpdatedCallback에서 시각화를 Figure에 플로팅하거나 사용자 지정 시각화 객체를 사용합니다. 예를 들어 Figure 핸들이 설정되었는지 확인합니다. 설정된 경우 시각화를 플로팅합니다.

function envUpdatedCallback(this)
    if ~isempty(this.Figure) && isvalid(this.Figure)
        % Set visualization figure as the current figure
        ha = gca(this.Figure);

        % Extract the cart position and pole angle
        x = this.State(1);
        theta = this.State(3);

        cartplot = findobj(ha,'Tag','cartplot');
        poleplot = findobj(ha,'Tag','poleplot');
        if isempty(cartplot) || ~isvalid(cartplot) ...
                || isempty(poleplot) || ~isvalid(poleplot)
            % Initialize the cart plot
            cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
            cartpoly = translate(cartpoly,[x 0]);
            cartplot = plot(ha,cartpoly,'FaceColor',[0.8500 0.3250 0.0980]);
            cartplot.Tag = 'cartplot';

            % Initialize the pole plot
            L = this.HalfPoleLength*2;
            polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
            polepoly = translate(polepoly,[x,0]);
            polepoly = rotate(polepoly,rad2deg(theta),[x,0]);
            poleplot = plot(ha,polepoly,'FaceColor',[0 0.4470 0.7410]);
            poleplot.Tag = 'poleplot';
        else
            cartpoly = cartplot.Shape;
            polepoly = poleplot.Shape;
        end

        % Compute the new cart and pole position
        [cartposx,~] = centroid(cartpoly);
        [poleposx,poleposy] = centroid(polepoly);
        dx = x - cartposx;
        dtheta = theta - atan2(cartposx-poleposx,poleposy-0.25/2);
        cartpoly = translate(cartpoly,[dx,0]);
        polepoly = translate(polepoly,[dx,0]);
        polepoly = rotate(polepoly,rad2deg(dtheta),[x,0.25/2]);

        % Update the cart and pole positions on the plot
        cartplot.Shape = cartpoly;
        poleplot.Shape = polepoly;

        % Refresh rendering in the figure window
        drawnow();
    end
end

환경에서 envUpdatedCallback 함수를 호출하기 때문에 환경이 업데이트될 때마다 시각화를 업데이트합니다.

사용자 지정 환경 인스턴스화하기

사용자 지정 환경 클래스를 정의한 후 MATLAB 작업 공간에서 이 클래스의 인스턴스를 만듭니다. 명령줄에 다음을 입력합니다.

env = MyEnvironment;

생성자가 입력 인수를 가지면 클래스 이름 뒤에 입력 인수를 지정합니다. 예를 들면 MyEnvironment(arg1,arg2)처럼 지정합니다.

환경을 만든 후에는 환경 동특성을 검증하는 것이 좋습니다. 그렇게 하려면 validateEnvironment 함수를 사용하십시오. 그러면 환경 구현에 문제가 있는 경우 명령 창에 오류가 표시됩니다.

validateEnvironment(env)

환경 객체를 검증한 다음에는 환경 객체를 사용하여 강화 학습 에이전트를 훈련시킬 수 있습니다. 에이전트 훈련에 대한 자세한 내용은 강화 학습 에이전트 훈련시키기 항목을 참조하십시오.

참고 항목

함수

객체

관련 항목