이 페이지의 최신 내용은 아직 번역되지 않았습니다. 최신 내용은 영문으로 볼 수 있습니다.

검사 지점 네트워크에서 훈련 재개하기

이 예제에서는 딥러닝 신경망을 훈련시킬 때 검사 지점 네트워크를 저장하고 이전에 저장한 네트워크에서 훈련을 재개하는 방법을 보여줍니다.

샘플 데이터 불러오기

샘플 데이터를 4차원 배열로 불러옵니다. digitTrain4DArrayData는 숫자 훈련 세트를 4차원 배열 데이터로 불러옵니다. XTrain은 28x28x1x5000 배열입니다. 여기서 28은 영상의 높이이고 28은 영상의 너비입니다. 1은 채널 개수이고 5,000은 손으로 쓴 숫자를 표시하는 합성 영상의 개수입니다. YTrain은 각 관측값에 대한 레이블을 포함하는 categorical형 벡터입니다.

[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)
ans = 1×4

          28          28           1        5000

XTrain의 영상 몇 개를 표시합니다.

figure;
perm = randperm(size(XTrain,4),20);
for i = 1:20
    subplot(4,5,i);
    imshow(XTrain(:,:,:,perm(i)));
end

네트워크 아키텍처 정의하기

신경망 아키텍처를 정의합니다.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2) 
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

훈련 옵션을 지정하고 네트워크 훈련시키기

SGDM(모멘텀을 사용한 확률적 경사하강법)의 훈련 옵션을 지정하고 검사 지점 네트워크를 저장할 경로를 지정합니다.

checkpointPath = pwd;
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',20, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

네트워크를 훈련시킵니다. trainNetwork는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU를 사용할 수 없으면 CPU를 사용합니다. trainNetwork는 매 Epoch마다 하나의 검사 지점 네트워크를 저장하고 검사 지점 파일에 자동으로 고유한 이름을 할당합니다.

net1 = trainNetwork(XTrain,YTrain,layers,options);

검사 지점 네트워크를 불러와서 훈련 재개하기

훈련이 중단되어서 완료되지 않았다고 가정하겠습니다. 훈련을 처음부터 시작하는 대신 마지막 검사 지점 네트워크를 불러와서 그 지점부터 훈련을 재개할 수 있습니다. trainNetwork는 검사 지점 파일을 net_checkpoint__195__2018_07_13__11_59_10.mat 형식의 파일 이름으로 저장합니다. 여기서 195는 반복 번호이고 2018_07_13은 날짜이고 11_59_10trainNetwork가 네트워크를 저장한 시간입니다. 검사 지점 네트워크는 net이라는 변수 이름을 갖습니다.

검사 지점 네트워크를 작업 공간으로 불러옵니다.

load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')

훈련 옵션을 지정하고 최대 Epoch 횟수를 줄입니다. 초기 학습률과 같은 다른 훈련 옵션을 조정할 수도 있습니다.

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',15, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

불러온 검사 지점 네트워크 계층의 훈련을 새로운 훈련 옵션을 사용하여 재개합니다. 검사 지점 네트워크가 DAG 네트워크인 경우, net.Layers 대신 layerGraph(net)을 인수로 사용하십시오.

net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

참고 항목

|

관련 예제

세부 정보