Main Content

1차원 컨벌루션을 사용한 시퀀스 분류

이 예제에서는 1차원 컨벌루션 신경망을 사용하여 시퀀스 데이터를 분류하는 방법을 보여줍니다.

시퀀스 데이터를 분류하도록 심층 신경망을 훈련시키기 위해 1차원 컨벌루션 신경망을 사용할 수 있습니다. 1차원 컨벌루션 계층은 1차원 입력값에 슬라이딩 컨벌루션 필터를 적용하여 특징을 학습합니다. 컨벌루션 계층은 한 번의 연산으로 입력을 처리할 수 있기 때문에 1차원 컨벌루션 계층을 사용하는 것이 순환 계층을 사용하는 것보다 더 빠를 수 있습니다. 반면에 순환 계층은 입력값의 시간 스텝마다 반복해야 합니다. 그러나 순환 계층은 시간 스텝 간의 장기적인 종속성을 학습할 수 있기 때문에 신경망 아키텍처와 필터 크기가 어떤지에 따라 1차원 컨벌루션 계층이 순환 계층보다 성능이 떨어질 수도 있습니다.

시퀀스 데이터 불러오기

WaveformData.mat에서 예제 데이터를 불러옵니다. 이 데이터는 시퀀스로 구성된 numObservations×1 셀형 배열이며, 여기서 numObservations는 시퀀스 개수입니다. 각 시퀀스는 numTimeSteps×numChannels 숫자형 배열이며, 여기서 numTimeSteps는 시퀀스의 시간 스텝 개수이고 numChannels는 시퀀스의 채널 개수입니다.

load WaveformData

시퀀스 중 일부를 플롯으로 시각화합니다.

numChannels = size(data{1},2);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end

검증과 테스트를 위해 데이터를 남겨 둡니다. 데이터의 80%가 포함된 훈련 세트, 데이터의 10%가 포함된 검증 세트, 데이터의 나머지 10%가 포함된 테스트 세트로 데이터를 분할합니다. 데이터를 분할하려면 이 예제에 지원 파일로 첨부된 trainingPartitions 함수를 사용합니다. 이 파일에 액세스하려면 이 예제를 라이브 스크립트로 여십시오.

numObservations = numel(data);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations, [0.8 0.1 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XValidation = data(idxValidation);
TValidation = labels(idxValidation);

XTest = data(idxTest);
TTest = labels(idxTest);

1차원 컨벌루션 신경망 아키텍처 정의하기

1차원 컨벌루션 신경망 아키텍처를 정의합니다.

  • 입력 크기를 입력 데이터의 채널 개수로 지정합니다.

  • 두 블록의 1차원 컨벌루션 계층, ReLU 계층, 계층 정규화 계층을 지정합니다. 여기서 컨벌루션 계층의 필터 크기는 5입니다. 첫 번째 컨벌루션 계층과 두 번째 컨벌루션 계층에 대해 각각 32개 필터와 64개 필터를 지정합니다. 이 두 컨벌루션 계층의 출력값의 길이가 동일하도록 입력값의 왼쪽을 채웁니다(인과적(causal) 채우기).

  • 컨벌루션 계층의 출력값을 단일 벡터로 줄이기 위해 1차원 전역 평균값 풀링 계층을 사용합니다.

  • 출력값을 확률로 구성된 벡터에 매핑하기 위해, 출력 크기가 클래스 개수와 일치하는 완전 연결 계층을 지정하고 그 뒤에 소프트맥스 계층을 지정합니다.

심층 신경망 디자이너 앱을 사용하여 이 신경망을 구축할 수도 있습니다. 심층 신경망 디자이너 시작 페이지의 Sequence-to-Label 분류 신경망(훈련되지 않음) 섹션에서 1차원 CNN을 클릭합니다.

filterSize = 5;
numFilters = 32;

classNames = categories(TTrain);
numClasses = numel(classNames);

layers = [ ...
    sequenceInputLayer(numChannels)
    convolution1dLayer(filterSize,numFilters,Padding="causal")
    reluLayer
    layerNormalizationLayer
    convolution1dLayer(filterSize,2*numFilters,Padding="causal")
    reluLayer
    layerNormalizationLayer
    globalAveragePooling1dLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

훈련 옵션 지정하기

훈련 옵션을 지정합니다. 옵션 중에서 선택하려면 경험적 분석이 필요합니다. 실험을 실행하여 다양한 훈련 옵션 구성을 살펴보려면 실험 관리자 앱을 사용합니다.

  • Adam 최적화 함수를 사용하여 학습률 0.01로 Epoch 60회만큼 훈련시킵니다.

  • 시퀀스의 왼쪽을 채웁니다.

  • 검증 데이터를 사용하여 신경망을 검증합니다.

  • 훈련 진행 상황을 플롯에서 모니터링하고 세부 정보가 출력되지 않도록 합니다.

options = trainingOptions("adam", ...
    MaxEpochs=60, ...
    InitialLearnRate=0.01, ...
    SequencePaddingDirection="left", ...
    ValidationData={XValidation,TValidation}, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

신경망 훈련시키기

trainnet 함수를 사용하여 신경망을 훈련시킵니다. 분류에는 교차 엔트로피 손실을 사용합니다. 기본적으로 trainnet 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU에서 훈련시키려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU를 사용할 수 없는 경우, trainnet 함수는 CPU를 사용합니다. 실행 환경을 지정하려면 ExecutionEnvironment 훈련 옵션을 사용하십시오.

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

신경망 테스트하기

testnet 함수를 사용하여 신경망을 테스트하고 훈련에 사용된 것과 동일한 인수를 사용합니다. 단일 레이블 분류의 경우 정확도를 평가합니다. 정확도는 올바른 예측의 비율입니다. 기본적으로 testnet 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. 실행 환경을 수동으로 선택하려면 testnet 함수의 ExecutionEnvironment 인수를 사용하십시오.

accuracy = testnet(net,XTest,TTest,"accuracy",SequencePaddingDirection="left")
accuracy = 
72

예측값을 혼동행렬로 시각화합니다. minibatchpredict 함수를 사용하여 예측을 수행하고 훈련에 사용된 것과 동일한 시퀀스 채우기 옵션을 사용합니다. 여러 개의 관측값을 사용하여 예측을 수행하려면 minibatchpredict 함수를 사용합니다. 예측 점수를 레이블로 변환하려면 scores2label 함수를 사용합니다. minibatchpredict 함수는 GPU를 사용할 수 있으면 자동으로 GPU를 사용합니다. 실행 환경을 수동으로 선택하려면 minibatchpredict 함수의 ExecutionEnvironment 인수를 사용하십시오.

scores = minibatchpredict(net,XTest,SequencePaddingDirection="left");
YTest = scores2label(scores, classNames);
figure
confusionchart(TTest,YTest)

참고 항목

| | | | | | | |

관련 항목