Main Content

이 페이지의 내용은 이전 릴리스에 관한 것입니다. 해당 영문 페이지는 최신 릴리스에서 제거되었습니다.

딥러닝 신경망 훈련 중의 출력값 사용자 지정하기

이 예제에서는 딥러닝 신경망 훈련 중에 각 반복에서 실행되는 출력 함수를 정의하는 방법을 보여줍니다. trainingOptions'OutputFcn' 이름-값 쌍 인수를 사용하여 출력 함수를 지정하면 trainNetwork가 훈련을 시작하기 전에 한 번, 각 훈련 반복이 실행된 뒤에 한 번, 그리고 훈련이 완료된 뒤에 한 번 이러한 함수를 호출합니다. 출력 함수가 호출될 때마다 trainNetwork는 현재 반복 횟수, 손실, 정확도와 같은 정보를 포함하는 구조체를 전달합니다. 출력 함수를 사용하여 진행 상황 정보를 표시하거나 플로팅할 수도 있고 훈련을 중단할 수도 있습니다. 훈련을 조기에 중단하려면 출력 함수가 true를 반환하도록 하십시오. 출력 함수가 true를 반환하면 훈련이 완료되고 trainNetwork 가 마지막 신경망을 반환합니다.

검증 세트에서 손실이 더 이상 감소하지 않을 때 훈련을 중단하려면 각각 trainingOptions'ValidationData' 이름-값 쌍 인수와 'ValidationPatience' 이름-값 쌍 인수를 사용하여 검증 데이터와 검증 인내도를 지정하십시오. 검증 인내도(validation patience)는 신경망 훈련이 중단되기 전까지 검증 세트에서의 손실이 그전까지의 가장 작은 손실보다 크거나 같아도 되는 횟수입니다. 출력 함수를 사용하여 더 많은 중지 기준을 추가할 수 있습니다. 이 예제에서는 검증 데이터의 분류 정확도가 더 이상 향상되지 않을 때 훈련을 중단하는 출력 함수를 만드는 방법을 설명합니다. 출력 함수는 스크립트의 끝에서 정의됩니다.

숫자 영상 5,000개를 포함하는 훈련 데이터를 불러옵니다. 신경망 검증을 위해 영상 1,000개를 남겨 둡니다.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

숫자 영상 데이터를 분류할 신경망을 구성합니다.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

신경망 훈련 옵션을 지정합니다. 훈련 중에 규칙적인 간격으로 신경망을 검증하려면 검증 데이터를 지정하십시오. Epoch당 한 번씩 신경망 검증이 실시되도록 'ValidationFrequency' 값을 선택합니다.

검증 세트에 대한 분류 정확도가 더 이상 향상되지 않을 때 훈련을 중단하려면 stopIfAccuracyNotImproving을 출력 함수로 지정하십시오. stopIfAccuracyNotImproving의 두 번째 입력 인수는 신경망 훈련이 중단되기 전까지 검증 세트에 대한 정확도가 그전까지의 가장 높은 정확도보다 작거나 같아도 되는 횟수입니다. 훈련시킬 최대 Epoch 횟수로 임의의 큰 값을 선택합니다. 훈련이 자동으로 중단되므로 마지막 Epoch까지 훈련이 진행되지 않습니다.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

신경망을 훈련시킵니다. 검증의 정확도가 더 이상 증가하지 않으면 훈련이 중단됩니다.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:05 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:16 |       71.88% |       74.90% |       0.8807 |       0.8130 |          0.0100 |
|       2 |          62 |       00:00:26 |       86.72% |       88.00% |       0.3899 |       0.4436 |          0.0100 |
|       3 |          93 |       00:00:35 |       94.53% |       94.00% |       0.2224 |       0.2553 |          0.0100 |
|       4 |         124 |       00:00:45 |       95.31% |       96.80% |       0.1482 |       0.1762 |          0.0100 |
|       5 |         155 |       00:00:52 |       98.44% |       97.60% |       0.1007 |       0.1314 |          0.0100 |
|       6 |         186 |       00:01:00 |       99.22% |       97.80% |       0.0784 |       0.1136 |          0.0100 |
|       7 |         217 |       00:01:08 |      100.00% |       98.10% |       0.0559 |       0.0945 |          0.0100 |
|       8 |         248 |       00:01:14 |      100.00% |       98.00% |       0.0441 |       0.0859 |          0.0100 |
|       9 |         279 |       00:01:23 |      100.00% |       98.00% |       0.0344 |       0.0786 |          0.0100 |
|      10 |         310 |       00:01:29 |      100.00% |       98.50% |       0.0274 |       0.0678 |          0.0100 |
|      11 |         341 |       00:01:37 |      100.00% |       98.50% |       0.0240 |       0.0621 |          0.0100 |
|      12 |         372 |       00:01:43 |      100.00% |       98.70% |       0.0213 |       0.0569 |          0.0100 |
|      13 |         403 |       00:01:50 |      100.00% |       98.80% |       0.0187 |       0.0534 |          0.0100 |
|      14 |         434 |       00:01:57 |      100.00% |       98.80% |       0.0164 |       0.0508 |          0.0100 |
|      15 |         465 |       00:02:04 |      100.00% |       98.90% |       0.0144 |       0.0487 |          0.0100 |
|      16 |         496 |       00:02:12 |      100.00% |       99.00% |       0.0126 |       0.0462 |          0.0100 |
|      17 |         527 |       00:02:18 |      100.00% |       98.90% |       0.0112 |       0.0440 |          0.0100 |
|      18 |         558 |       00:02:26 |      100.00% |       98.90% |       0.0101 |       0.0420 |          0.0100 |
|      19 |         589 |       00:02:32 |      100.00% |       99.10% |       0.0092 |       0.0405 |          0.0100 |
|      20 |         620 |       00:02:39 |      100.00% |       99.00% |       0.0086 |       0.0391 |          0.0100 |
|      21 |         651 |       00:02:46 |      100.00% |       99.00% |       0.0080 |       0.0380 |          0.0100 |
|      22 |         682 |       00:02:53 |      100.00% |       99.00% |       0.0076 |       0.0369 |          0.0100 |
|======================================================================================================================|
Training finished: Stopped by OutputFcn.

Figure Training Progress (03-Aug-2023 23:43:44) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 9 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 9 objects of type patch, text, line.

출력 함수 정의하기

검증 데이터에 대한 가장 높은 분류 정확도가 신경망 검증에서 N번 연속으로 향상되지 않으면 신경망 훈련을 중지하는 출력 함수 stopIfAccuracyNotImproving(info,N)을 정의합니다. 이 조건은 손실이 아닌 분류 정확도에 적용된다는 점만 제외하면 검증 손실을 사용하는 내장된 중지 기준과 비슷합니다.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;

elseif ~isempty(info.ValidationAccuracy)

    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end

    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end

end

end

참고 항목

|

관련 항목