주요 콘텐츠

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

여러 개의 출력값을 갖는 신경망 훈련시키기

이 예제에서는 손으로 쓴 숫자의 레이블과 회전 각도를 모두 예측하는, 여러 개의 출력값을 갖는 딥러닝 신경망을 훈련시키는 방법을 보여줍니다.

훈련 데이터 불러오기

숫자 데이터를 불러옵니다. 이 데이터에는 숫자의 영상과 숫자 레이블, 그리고 세로 방향을 기준으로 한 회전 각도가 들어 있습니다.

load DigitsDataTrain

영상, 레이블 및 각도에 대한 arrayDatastore 객체를 만든 다음, combine 함수를 사용하여 모든 훈련 데이터를 포함하는 단일 데이터저장소를 만듭니다.

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsT1Train = arrayDatastore(labelsTrain);
dsT2Train = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);

classNames = categories(labelsTrain);
numClasses = numel(classNames);
numObservations = numel(labelsTrain);

훈련 데이터에서 일부 영상을 표시합니다.

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

딥러닝 모델 정의하기

레이블과 회전 각도를 모두 예측하는 신경망을 다음과 같이 정의합니다.

  • 5×5 필터 16개를 갖는 convolution-batchnorm-ReLU 블록 한 개.

  • 각각 3×3 필터 32개를 갖는 convolution-batchnorm-ReLU 블록 두 개.

  • 1×1 컨벌루션 32개를 갖는 convolution-batchnorm-ReLU 블록이 포함된,이전의 두 블록을 둘러싸는 건너뛰기 연결.

  • 추가를 사용하여 건너뛰기 연결 병합.

  • 분류 출력의 경우, 크기 10(클래스 개수)의 완전 연결 연산과 소프트맥스 연산을 사용하는 분기.

  • 회귀 출력의 경우, 크기 1(응답 변수 개수)의 완전 연결 연산을 사용하는 분기.

계층의 메인 블록을 정의합니다.

net = dlnetwork;

layers = [
    imageInputLayer([28 28 1],Normalization="none")

    convolution2dLayer(5,16,Padding="same")
    batchNormalizationLayer
    reluLayer(Name="relu_1")

    convolution2dLayer(3,32,Padding="same",Stride=2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    additionLayer(2,Name="add")

    fullyConnectedLayer(numClasses)
    softmaxLayer(Name="softmax")];

net = addLayers(net,layers);

건너뛰기 연결을 추가합니다.

layers = [
    convolution2dLayer(1,32,Stride=2,Name="conv_skip")
    batchNormalizationLayer
    reluLayer(Name="relu_skip")];

net = addLayers(net,layers);
net = connectLayers(net,"relu_1","conv_skip");
net = connectLayers(net,"relu_skip","add/in2");

회귀를 위해 완전 연결 계층을 추가합니다.

layers = fullyConnectedLayer(1,Name="fc_2");
net = addLayers(net,layers);
net = connectLayers(net,"add","fc_2");

계층 그래프를 플롯에 표시합니다.

figure
plot(net)

훈련 옵션 지정하기

훈련 옵션을 지정합니다. 옵션 중에서 선택하려면 경험적 분석이 필요합니다. 실험을 실행하여 다양한 훈련 옵션 구성을 살펴보려면 실험 관리자 앱을 사용합니다.

options = trainingOptions("adam", ...
    Plots="training-progress", ...
    Verbose=false);

신경망 훈련시키기

trainnet 함수를 사용하여 신경망을 훈련시킵니다. 분류를 위해, 예측 레이블과 목표 레이블의 교차 엔트로피 손실과, 예측 각도와 목표 각도의 평균제곱오차 손실에 0.1을 곱한 값의 합을 사용자 지정 손실 함수로 사용하십시오. 기본적으로 trainnet 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU를 사용하려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU를 사용할 수 없는 경우, 함수는 CPU를 사용합니다. 실행 환경을 지정하려면 ExecutionEnvironment 훈련 옵션을 사용하십시오.

사용자 지정 손실 함수를 함수 핸들로 정의합니다. 예측 레이블과 목표 레이블의 교차 엔트로피 손실과, 예측 각도와 목표 각도의 평균제곱오차(0.1배 스케일링됨)의 합을 손실로 정의하십시오.

lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);

신경망을 훈련시킵니다.

net = trainnet(dsTrain,net,lossFcn,options);

모델 테스트하기

숫자 데이터를 불러옵니다. 이 데이터에는 숫자의 영상과 숫자 레이블, 그리고 세로 방향을 기준으로 한 회전 각도가 들어 있습니다.

load DigitsDataTest

minibatchpredict 함수를 사용하여 예측을 수행하고 scores2label 함수를 사용하여 분류 점수를 레이블로 변환합니다. 기본적으로 minibatchpredict 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU를 사용하려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU를 사용할 수 없는 경우, 함수는 CPU를 사용합니다. 실행 환경을 수동으로 선택하려면 minibatchpredict 함수의 ExecutionEnvironment 인수를 사용하십시오.

[scores,Y2] = minibatchpredict(net,XTest);
Y1 = scores2label(scores,classNames);

레이블의 분류 정확도를 계산합니다.

accuracy = mean(Y1 == labelsTest)
accuracy = 0.9732

예측 각도와 목표 각도 사이의 RMS 오차를 계산합니다.

err = rmse(Y2,anglesTest)
err = single
    6.9265

영상 일부를 예측값과 함께 표시합니다. 예측 각도를 빨간색으로, 정확한 레이블을 녹색으로 표시합니다.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    theta = Y2(idx(i));
    plot(offset*[1-tand(theta) 1+tand(theta)],[sz 0],"r--")

    thetaTest = anglesTest(idx(i));
    plot(offset*[1-tand(thetaTest) 1+tand(thetaTest)],[sz 0],"g--")

    hold off
    label = Y1(idx(i));
    title("Label: " + string(label))
end

참고 항목

| | | | | | | | | | |

도움말 항목