Main Content

분류 신경망을 회귀 신경망으로 변환하기

이 예제에서는 훈련된 분류 신경망을 회귀 신경망으로 변환하는 방법을 보여줍니다.

사전 훈련된 영상 분류 신경망은 1백만 개가 넘는 영상에 대해 훈련되었으며 영상을 키보드, 커피 머그잔, 연필, 각종 동물 등 1,000가지 사물 범주로 분류할 수 있습니다. 분류 신경망은 다양한 영상을 대표하는 다양한 특징을 학습했습니다. 이 신경망은 영상을 입력값으로 받아서 영상에 있는 사물에 대한 레이블과 각 사물 범주의 확률을 출력합니다.

전이 학습은 딥러닝 응용 분야에서 널리 사용됩니다. 사전 훈련된 신경망을 새로운 작업을 학습하기 위한 출발점으로 사용할 수 있습니다. 이 예제에서는 사전 훈련된 분류 신경망을 회귀 작업을 위해 다시 훈련시키는 방법을 보여줍니다.

이 예제에서는 분류를 위해 사전 훈련된 컨벌루션 신경망 아키텍처를 불러오고 분류를 위한 계층을 대체한 다음 회전된 손글씨 숫자의 각도를 예측하도록 신경망을 다시 훈련시킵니다.

사전 훈련된 신경망 불러오기

지원 파일 digitsClassificationConvolutionNet.mat에서 사전 훈련된 신경망을 불러옵니다. 이 파일에는 손으로 쓴 숫자를 분류하는 분류 신경망이 들어 있습니다.

load digitsClassificationConvolutionNet
layers = net.Layers
layers = 
  13x1 Layer array with layers:

     1   'imageinput'    Image Input                  28x28x1 images
     2   'conv_1'        2-D Convolution              10 3x3x1 convolutions with stride [2  2] and padding [0  0  0  0]
     3   'batchnorm_1'   Batch Normalization          Batch normalization with 10 channels
     4   'relu_1'        ReLU                         ReLU
     5   'conv_2'        2-D Convolution              20 3x3x10 convolutions with stride [2  2] and padding [0  0  0  0]
     6   'batchnorm_2'   Batch Normalization          Batch normalization with 20 channels
     7   'relu_2'        ReLU                         ReLU
     8   'conv_3'        2-D Convolution              40 3x3x20 convolutions with stride [2  2] and padding [0  0  0  0]
     9   'batchnorm_3'   Batch Normalization          Batch normalization with 40 channels
    10   'relu_3'        ReLU                         ReLU
    11   'gap'           2-D Global Average Pooling   2-D global average pooling
    12   'fc'            Fully Connected              10 fully connected layer
    13   'softmax'       Softmax                      softmax

데이터 불러오기

데이터 세트에는 손으로 쓴 숫자를 나타내는 합성 영상과 영상 각각의 회전 각도(단위: 도)가 포함되어 있습니다.

지원 파일 DigitsDataTrain.matDigitsDataTest.mat에서 훈련 영상과 테스트 영상을 4차원 배열로 불러옵니다. 변수 anglesTrainanglesTest는 회전 각도(단위: 도)입니다. 훈련 및 테스트 데이터 세트는 각각 5,000개의 영상을 포함합니다.

load DigitsDataTrain
load DigitsDataTest

imshow를 사용하여 임의의 훈련 영상 20개를 표시합니다.

numTrainImages = numel(anglesTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

마지막 계층 바꾸기

신경망의 컨벌루션 계층은 마지막 학습 가능한 계층이 입력 영상을 분류하는 데 사용하는 영상 특징을 추출합니다. 계층 'fc'는 신경망이 추출하는 특징을 클래스 확률로 조합하는 방법에 대한 정보를 포함합니다. 사전 훈련된 신경망이 회귀를 수행하도록 다시 훈련시키려면, 이 계층과 그 뒤에 오는 소프트맥스 계층을 이 작업에 적합한 새로운 계층으로 교체합니다.

마지막 완전 연결 계층을 크기가 1(응답 변수의 개수)인 완전 연결 계층으로 교체합니다.

numResponses = 1;
layer = fullyConnectedLayer(numResponses,Name="fc");

net = replaceLayer(net,"fc",layer)
net = 
  dlnetwork with properties:

         Layers: [13x1 nnet.cnn.layer.Layer]
    Connections: [12x2 table]
     Learnables: [14x3 table]
          State: [6x3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 0

  View summary with summary.

소프트맥스 계층을 제거합니다.

net = removeLayers(net,"softmax");

계층 학습률 인자 조정하기

새로운 데이터를 사용하여 신경망을 다시 훈련시킬 준비가 되었습니다. 원한다면 훈련 옵션을 지정할 때 새로 만든 완전 연결 계층의 학습률을 높이고 전역 학습률을 줄여서 신경망 앞쪽의 계층에서 가중치의 훈련 속도를 늦출 수 있습니다.

완전 연결 계층 파라미터의 학습률을 setLearnRateFactor 함수를 사용할 때의 비율만큼 높입니다.

net = setLearnRateFactor(net,"fc","Weights",10);
net = setLearnRateFactor(net,"fc","Bias",10);

훈련 옵션 지정하기

훈련 옵션을 지정합니다. 옵션 중에서 선택하려면 경험적 분석이 필요합니다. 실험을 실행하여 다양한 훈련 옵션 구성을 살펴보려면 Experiment Manager 앱을 사용합니다.

  • 학습률을 0.0001로 지정하여 낮춥니다.

  • 훈련 진행 상황을 플롯으로 표시합니다.

  • 상세 출력값을 비활성화합니다.

options = trainingOptions("sgdm",...
    InitialLearnRate=0.001, ...
    Plots="training-progress",...
    Verbose=false);

신경망 훈련시키기

trainnet 함수를 사용하여 신경망을 훈련시킵니다. 회귀의 경우 평균제곱오차 손실을 사용합니다. 기본적으로 trainnet 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU를 사용하려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU를 사용할 수 없는 경우, 함수는 CPU를 사용합니다. 실행 환경을 지정하려면 ExecutionEnvironment 훈련 옵션을 사용하십시오.

net = trainnet(XTrain,anglesTrain,net,"mse",options);

신경망 테스트하기

테스트 데이터에 대한 정확도를 평가함으로써 신경망의 성능을 테스트합니다.

predict를 사용하여 검증 영상의 회전 각도를 예측합니다.

YTest = predict(net,XTest);

산점도 플롯에서 예측을 시각화합니다. 실제 값에 대해 예측된 값을 플로팅합니다.

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"r--")

참고 항목

| |

관련 항목