이 페이지의 내용은 이전 릴리스에 관한 것입니다. 해당 영문 페이지는 최신 릴리스에서 제거되었습니다.

딥러닝 신경망 훈련 중의 출력값 사용자 지정하기

이 예제에서는 딥러닝 신경망 훈련 중에 각 반복에서 실행되는 출력 함수를 정의하는 방법을 보여줍니다. trainingOptions'OutputFcn' 이름-값 쌍 인수를 사용하여 출력 함수를 지정하면 trainNetwork가 훈련을 시작하기 전에 한 번, 각 훈련 반복이 실행된 뒤에 한 번, 그리고 훈련이 완료된 뒤에 한 번 이러한 함수를 호출합니다. 출력 함수가 호출될 때마다 trainNetwork는 현재 반복 횟수, 손실, 정확도와 같은 정보를 포함하는 구조체를 전달합니다. 출력 함수를 사용하여 진행 상황 정보를 표시하거나 플로팅할 수도 있고 훈련을 중단할 수도 있습니다. 훈련을 조기에 중단하려면 출력 함수가 true를 반환하도록 하십시오. 출력 함수가 true를 반환하면 훈련이 완료되고 trainNetwork 가 마지막 네트워크를 반환합니다.

검증 세트에서 손실이 더 이상 감소하지 않을 때 훈련을 중단하려면 각각 trainingOptions'ValidationData' 이름-값 쌍 인수와 'ValidationPatience' 이름-값 쌍 인수를 사용하여 검증 데이터와 검증 인내도를 지정하십시오. 검증 인내도(validation patience)는 네트워크 훈련이 중단되기 전까지 검증 세트에서의 손실이 그전까지의 가장 작은 손실보다 크거나 같아도 되는 횟수입니다. 출력 함수를 사용하여 더 많은 중지 기준을 추가할 수 있습니다. 이 예제에서는 검증 데이터의 분류 정확도가 더 이상 향상되지 않을 때 훈련을 중단하는 출력 함수를 만드는 방법을 설명합니다. 출력 함수는 스크립트의 끝에서 정의됩니다.

숫자 영상 5,000개를 포함하는 훈련 데이터를 불러옵니다. 네트워크 검증을 위해 영상 1,000개를 따로 빼 둡니다.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

숫자 영상 데이터를 분류할 네트워크를 구성합니다.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

네트워크 훈련 옵션을 지정합니다. 훈련 중에 규칙적인 간격으로 네트워크를 검증하려면 검증 데이터를 지정하십시오. Epoch당 한 번씩 네트워크 검증이 실시되도록 'ValidationFrequency' 값을 선택합니다.

검증 세트에 대한 분류 정확도가 더 이상 향상되지 않을 때 훈련을 중단하려면 stopIfAccuracyNotImproving을 출력 함수로 지정하십시오. stopIfAccuracyNotImproving의 두 번째 입력 인수는 네트워크 훈련이 중단되기 전까지 검증 세트에 대한 정확도가 그전까지의 가장 높은 정확도보다 작거나 같아도 되는 횟수입니다. 훈련시킬 최대 Epoch 횟수로 임의의 큰 값을 선택합니다. 훈련이 자동으로 중단되므로 마지막 Epoch까지 훈련이 진행되지 않습니다.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

네트워크를 훈련시킵니다. 검증의 정확도가 더 이상 증가하지 않으면 훈련이 중단됩니다.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:01 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:06 |       71.88% |       74.80% |       0.8816 |       0.8143 |          0.0100 |
|       2 |          62 |       00:00:12 |       87.50% |       87.70% |       0.3898 |       0.4476 |          0.0100 |
|       3 |          93 |       00:00:18 |       95.31% |       94.10% |       0.2203 |       0.2567 |          0.0100 |
|       4 |         124 |       00:00:24 |       96.09% |       96.60% |       0.1475 |       0.1754 |          0.0100 |
|       5 |         155 |       00:00:31 |       98.44% |       97.60% |       0.0992 |       0.1308 |          0.0100 |
|       6 |         186 |       00:00:36 |       99.22% |       97.70% |       0.0775 |       0.1124 |          0.0100 |
|       7 |         217 |       00:00:42 |      100.00% |       98.10% |       0.0554 |       0.0938 |          0.0100 |
|       8 |         248 |       00:00:47 |      100.00% |       98.00% |       0.0438 |       0.0858 |          0.0100 |
|       9 |         279 |       00:00:52 |      100.00% |       98.00% |       0.0346 |       0.0784 |          0.0100 |
|      10 |         310 |       00:00:58 |      100.00% |       98.40% |       0.0275 |       0.0679 |          0.0100 |
|      11 |         341 |       00:01:03 |      100.00% |       98.50% |       0.0236 |       0.0617 |          0.0100 |
|      12 |         372 |       00:01:08 |      100.00% |       98.70% |       0.0212 |       0.0565 |          0.0100 |
|      13 |         403 |       00:01:13 |      100.00% |       98.60% |       0.0186 |       0.0531 |          0.0100 |
|      14 |         434 |       00:01:18 |      100.00% |       98.70% |       0.0163 |       0.0505 |          0.0100 |
|      15 |         465 |       00:01:23 |      100.00% |       98.80% |       0.0143 |       0.0480 |          0.0100 |
|      16 |         496 |       00:01:27 |      100.00% |       99.00% |       0.0126 |       0.0457 |          0.0100 |
|      17 |         527 |       00:01:32 |      100.00% |       99.00% |       0.0112 |       0.0433 |          0.0100 |
|      18 |         558 |       00:01:37 |      100.00% |       98.90% |       0.0100 |       0.0415 |          0.0100 |
|      19 |         589 |       00:01:41 |      100.00% |       99.10% |       0.0092 |       0.0398 |          0.0100 |
|      20 |         620 |       00:01:46 |      100.00% |       99.30% |       0.0086 |       0.0385 |          0.0100 |
|      21 |         651 |       00:01:50 |      100.00% |       99.20% |       0.0081 |       0.0373 |          0.0100 |
|      22 |         682 |       00:01:55 |      100.00% |       99.20% |       0.0076 |       0.0362 |          0.0100 |
|      23 |         713 |       00:02:00 |      100.00% |       99.20% |       0.0072 |       0.0353 |          0.0100 |
|======================================================================================================================|

출력 함수 정의하기

검증 데이터에 대한 가장 높은 분류 정확도가 네트워크 검증에서 N번 연속으로 향상되지 않으면 네트워크 훈련을 중지하는 출력 함수 stopIfAccuracyNotImproving(info,N) 을 정의합니다. 이 조건은 손실이 아닌 분류 정확도에 적용된다는 점만 제외하면 검증 손실을 사용하는 내장된 중지 기준과 비슷합니다.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end

end

참고 항목

|

관련 항목