Main Content

장단기 기억 신경망을 사용하여 심전도 신호 분류하기

이 예제에서는 딥러닝과 신호 처리를 사용하여 PhysioNet 2017 Challenge의 심전도(ECG) 데이터를 분류하는 방법을 보여줍니다. 특히, 이 예제에서는 장단기 기억 신경망과 시간-주파수 분석을 사용합니다.

GPU 및 Parallel Computing Toolbox™를 사용하여 워크플로를 재현하고 가속화하는 예제는 Classify ECG Signals Using Long Short-Term Memory Networks with GPU Acceleration (Signal Processing Toolbox) 항목을 참조하십시오.

소개

심전도는 일정 시간 동안 사람 심장의 전기적 활성을 기록합니다. 의사들은 심전도를 사용하여 환자의 심박이 정상인지 불규칙적인지를 시각적으로 판단합니다.

심방세동(AFib)은 심장의 심방이 심실과 다른 주기로 뛸 때 발생하는 불규칙적인 심박입니다.

이 예제에서는 PhysioNet 2017 Challenge [1], [2], [3]의 심전도 데이터를 사용합니다. 이 데이터는 https://physionet.org/challenge/2017/에서 확인할 수 있습니다. 데이터는 300Hz로 샘플링되고 전문가 그룹이 네 가지 클래스로 분류한 심전도 신호로 구성되어 있습니다. 이들 네 가지 클래스는 정상(N), AFib(A), 기타 리듬(O), 잡음이 있는 기록(~)입니다. 이 예제에서는 딥러닝을 사용하여 분류 과정을 자동화하는 방법을 보여줍니다. 이 절차에서는 정상 심전도 신호를 AFib 증상을 보이는 신호와 구분할 수 있는 이진 분류기를 살펴봅니다.

장단기 기억(LSTM) 신경망은 시퀀스 데이터와 시계열 데이터를 연구하는 데 적합한 일종의 순환 신경망(RNN)입니다. LSTM 신경망은 시퀀스의 시간 스텝 간의 장기적인 종속성을 학습할 수 있습니다. LSTM 계층(lstmLayer)은 순방향의 시간 시퀀스를 살펴볼 수 있고, 양방향 LSTM 계층(bilstmLayer)은 순방향과 역방향의 시간 시퀀스를 살펴볼 수 있습니다. 이 예제에서는 양방향 LSTM 계층을 사용합니다.

이 예제는 인공 지능(AI) 문제 해결에 데이터 중심의 방식을 사용할 경우의 이점을 보여줍니다. 원시 데이터를 사용하여 LSTM 신경망을 훈련시키는 초기 시도에서는 표준 이하의 결과가 생성됩니다. 추출된 특징을 사용하여 동일한 모델 아키텍처를 훈련시키면 분류 성능이 크게 향상됩니다.

훈련 과정을 가속화하려면 GPU가 있는 시스템에서 이 예제를 실행하십시오. 시스템에 GPU와 Parallel Computing Toolbox™가 있는 경우 MATLAB®은 자동으로 훈련에 GPU를 사용하고, 그렇지 않은 경우 CPU를 사용합니다.

데이터를 불러오고 검토하기

ReadPhysionetData 스크립트를 실행하여 PhysioNet 웹 사이트에서 데이터를 다운로드하고 적절한 형식의 심전도 신호를 포함하는 MAT 파일(PhysionetData.mat)을 생성합니다. 데이터를 다운로드하는 데 몇 분 정도 걸릴 수 있습니다. PhysionetData.mat가 아직 현재 폴더에 없는 경우에만 스크립트를 실행하는 조건문을 사용하십시오.

if ~isfile('PhysionetData.mat')
    ReadPhysionetData         
end
load PhysionetData

불러오기 작업은 작업 공간에 두 개의 변수 SignalsLabels를 추가합니다. Signals는 심전도 신호를 포함하는 셀형 배열입니다. Labels는 신호의 대응되는 ground-truth 레이블을 포함하는 categorical형 배열입니다.

Signals(1:5)'
ans=1×5 cell array
    {1×9000 double}    {1×9000 double}    {1×18000 double}    {1×9000 double}    {1×18000 double}

Labels(1:5)
ans = 5×1 categorical
     N 
     N 
     N 
     A 
     A 

summary 함수를 사용하여 데이터에 몇 개의 AFib 신호와 정상 신호가 있는지 확인합니다.

summary(Labels)
     A       738 
     N      5050 

신호 길이에 대한 히스토그램을 생성합니다. 대부분의 신호의 길이가 9000개 샘플입니다.

L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')

각 클래스별로 신호 하나의 세그먼트를 시각화합니다. AFib 심박은 간격이 불규칙한 반면 정상 심박은 규칙적으로 뜁니다. AFib 심박 신호에는 P파가 없는 경우가 잦은데 정상 심박 신호에서는 QRS 복합파 전에 P파가 뜁니다. 정상 신호의 플롯에서는 P파와 QRS 복합파를 볼 수 있습니다.

normal = Signals{1};
aFib = Signals{4};

subplot(2,1,1)
plot(normal)
title('Normal Rhythm')
xlim([4000,5200])
ylabel('Amplitude (mV)')
text(4330,150,'P','HorizontalAlignment','center')
text(4370,850,'QRS','HorizontalAlignment','center')

subplot(2,1,2)
plot(aFib)
title('Atrial Fibrillation')
xlim([4000,5200])
xlabel('Samples')
ylabel('Amplitude (mV)')

훈련을 위해 데이터 준비하기

훈련 중에 trainnet 함수는 데이터를 미니 배치로 분할합니다. 이 함수는 그런 다음 동일한 미니 배치의 신호가 모두 동일한 길이를 갖도록 채우거나 자릅니다. 이처럼 추가된 정보나 제거된 정보를 바탕으로 신경망이 신호를 잘못 해석할 가능성이 있기 때문에 채우기나 자르기를 너무 많이 하면 신경망 성능이 저하될 수 있습니다.

과도한 채우기나 자르기를 방지하려면 심전도 신호의 길이가 모두 9000개 샘플이 되도록 심전도 신호에 segmentSignals 함수를 적용하십시오. 이 함수는 샘플이 9000개 미만인 신호를 무시합니다. 신호의 샘플이 9000개보다 많은 경우 segmentSignals는 신호를 분할해 9000개 샘플 크기의 세그먼트가 가능한 한 가장 많이 생성되게 하고 나머지 샘플은 무시합니다. 예를 들어, 18500개의 샘플로 구성된 신호는 9000개의 샘플로 구성된 신호 2개가 되고 나머지 500개의 샘플은 무시됩니다.

[Signals,Labels] = segmentSignals(Signals,Labels);

Signals 배열의 처음 5개 요소를 보고 각 항목의 길이가 이제 9000개 샘플이 된 것을 확인합니다.

Signals(1:5)'
ans=1×5 cell array
    {1×9000 double}    {1×9000 double}    {1×9000 double}    {1×9000 double}    {1×9000 double}

첫 번째 시도: 원시 신호 데이터를 사용하여 분류기 훈련시키기

분류기를 설계하려면 이전 섹션에서 생성된 원시 신호를 사용하십시오. 신호를 분류기 훈련을 위한 훈련 세트와 새 데이터에 대한 분류기의 정확도 테스트를 위한 테스트 세트로 분할합니다.

summary 함수를 사용하여 AFib 신호와 정상 신호의 비가 약 1:7임을 확인합니다.

summary(Labels)
     A       718 
     N      4937 

신호의 7/8이 정상 신호이기 때문에 분류기는 모든 신호를 정상으로 분류하면 높은 정확도를 달성할 수 있다고 학습하게 됩니다. 이러한 편향을 방지하려면 정상 신호와 AFib 신호의 개수가 같아지도록 데이터셋에서 AFib 신호를 복제하여 AFib 신호의 개수를 늘리십시오. 일반적으로 오버샘플링이라고 하는 이러한 복제 기법은 딥러닝에서 사용되는 데이터 증대의 한 가지 형태입니다.

신호를 클래스에 따라 분할합니다.

afibX = Signals(Labels=='A');
afibY = Labels(Labels=='A');

normalX = Signals(Labels=='N');
normalY = Labels(Labels=='N');

다음으로, dividerand를 사용하여 각 클래스의 목표값을 훈련 세트, 검증 세트, 테스트 세트로 무작위로 나눕니다.

rng("default")
[trainIndA,validIndA,testIndA] = dividerand(length(afibX),0.8,0.1,0.1);
[trainIndN,validIndN,testIndN] = dividerand(length(normalX),0.8,0.1,0.1);
XTrainA = afibX(trainIndA);
YTrainA = afibY(trainIndA);
XTrainN = normalX(trainIndN);
YTrainN = normalY(trainIndN);

XValidA = afibX(validIndA);
YValidA = afibY(validIndA);
XValidN = normalX(validIndN);
YValidN = normalY(validIndN);

XTestA = afibX(testIndA);
YTestA = afibY(testIndA);
XTestN = normalX(testIndN);
YTestN = normalY(testIndN);

데이터셋의 균형이 맞지 않습니다. AFib 신호와 정상 신호를 비슷한 개수로 만들기 위해 AFib 신호를 7번 반복합니다.

기본적으로, 신경망은 연속된 신호가 모두 동일한 레이블을 갖지 않도록 하기 위해 훈련 전에 데이터를 무작위로 섞습니다.

XTrain = [repmat(XTrainA,7,1); XTrainN];
YTrain = [repmat(YTrainA,7,1); YTrainN];

XValid = [repmat(XValidA,7,1); XValidN];
YValid = [repmat(YValidA,7,1); YValidN];

XTest = [repmat(XTestA,7,1); XTestN];
YTest = [repmat(YTestA,7,1); YTestN];

이제 훈련 세트, 검증 세트, 테스트 세트에서 정상 신호와 AFib 신호 사이의 분포가 사실상 균일합니다.

summary(YTrain)
     A      4018 
     N      3949 
summary(YValid)
     A      504 
     N      494 
summary(YTest)
     A      504 
     N      494 

LSTM 신경망 아키텍처 정의하기

LSTM 신경망은 시퀀스 데이터의 시간 스텝 간의 장기적인 종속성을 학습할 수 있습니다. 이 예제에서는 순방향과 역방향의 시퀀스를 모두 살펴보므로 양방향 LSTM 계층 bilstmLayer를 사용합니다.

입력 신호가 각각 1개의 차원을 가지므로 입력 크기가 크기 1로 구성된 시퀀스가 되도록 지정합니다. 출력 크기가 50인 양방향 LSTM 계층을 지정하고 시퀀스의 마지막 요소를 출력합니다. 다음 명령은 입력 시계열을 50개의 특징으로 매핑하도록 양방향 LSTM 계층에 지시하고 완전 연결 계층에 대한 출력을 준비합니다. 마지막으로, 크기가 2인 완전 연결 계층을 포함하여 2개의 클래스를 지정하고, 이어서 소프트맥스 계층을 지정합니다.

layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]
layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 1 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax

다음은 분류기의 훈련 옵션을 지정합니다. 신경망이 훈련 데이터를 100번 통과하도록 'MaxEpochs'를 100으로 설정합니다. 신경망이 한 번에 300개의 훈련 신호를 살펴보도록 'MiniBatchSize'를 300으로 설정합니다. 'InitialLearnRate'를 0.01로 설정하면 훈련 과정의 속도를 높이는 데 도움이 됩니다. 'Plots''training-progress'로 지정하여 반복 횟수가 늘어남에 따라 훈련 진행 상황을 그래픽으로 표시하는 플롯을 생성합니다. 'Verbose'false로 설정하여 플롯에 표시되는 데이터에 대응되는 표의 형태로 출력값이 표시되지 않도록 합니다. 이 표를 보려면 'Verbose'true로 설정하십시오. 훈련 데이터는 행과 열이 각각 채널과 시간 스텝에 대응하는 시퀀스를 가지므로 입력 데이터 형식 'CTB'(채널, 시간, 배치)를 지정합니다.

이 예제에서는 ADAM(적응적 모멘트 추정) 솔버를 사용합니다. ADAM은 LSTM과 같은 RNN에서 디폴트 값인 모멘텀을 사용한 확률적 경사하강법(SGDM) 솔버보다 더 나은 성능을 보입니다.

options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize', 200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'ExecutionEnvironment','auto', ...
    'plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValid,YValid}, ...
    'Verbose',false, ...
    'OutputNetwork','last-iteration');

LSTM 신경망 훈련시키기

trainnet를 사용하여 지정된 훈련 옵션과 계층 아키텍처로 LSTM 신경망을 훈련시킵니다. 훈련 세트가 크기 때문에 훈련 과정에 몇 분 정도 걸릴 수 있습니다.

net = trainnet(XTrain,YTrain,layers,"crossentropy",options);

훈련 진행 과정 플롯의 상단 서브플롯은 훈련 정확도, 즉 각 미니 배치의 분류 정확도를 나타냅니다. 훈련이 성공적으로 진행되면 이 값은 일반적으로 100%를 향해 늘어납니다. 하단 서브플롯에는 훈련 손실, 즉 각 미니 배치의 교차 엔트로피 손실이 표시됩니다. 훈련이 성공적으로 진행되면 이 값은 일반적으로 0을 향해 줄어듭니다.

훈련이 수렴되지 않을 경우 플롯이 특정 상향 또는 하향 방향으로 진행되는 대신 특정 값 사이를 진동할 수 있습니다. 이러한 진동은 훈련 정확도가 증가하지 않고 있으며 훈련 손실이 감소하지 않고 있음을 의미합니다. 이 상황은 훈련 시작부터 발생할 수도 있고, 또는 훈련 정확도가 초기에는 개선되다가 이어서 플롯이 일정해진 경우일 수도 있습니다. 대부분의 경우, 훈련 옵션을 변경하면 신경망이 수렴하는 데 도움이 됩니다. MiniBatchSize를 줄이거나 InitialLearnRate를 줄이면 훈련 시간이 늘어나긴 하지만 신경망이 더욱 잘 학습하는 데 도움이 될 수 있습니다.

이 경우 훈련 정확도는 매우 높지만 검증 정확도는 그에 비해 향상되지 않았습니다. 이는 과적합을 나타낼 수 있습니다. 즉, 모델이 일반화될 수 없으며 훈련 데이터셋에 너무 가깝게 피팅됩니다. 이에 대한 원인은 여러 가지가 있을 수 있습니다. 예를 들어, 훈련 데이터에 중복 데이터나 관련 없는 정보가 많이 포함되었거나, 신경망이 실제 분류에 핵심적인 인자를 학습하지 않은 것일 수 있습니다.

훈련 및 테스트 정확도 시각화하기

훈련이 진행된 신호에 대한 분류기의 정확도를 나타내는 훈련 정확도를 계산합니다. 먼저 훈련 데이터를 분류합니다.

여러 개의 관측값을 사용하여 예측을 수행하려면 minibatchpredict 함수를 사용합니다. 예측 점수를 레이블로 변환하려면 scores2label 함수를 사용합니다. minibatchpredict 함수는 GPU를 사용할 수 있으면 자동으로 GPU를 사용합니다. GPU를 사용하려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU를 사용할 수 없는 경우, 함수는 CPU를 사용합니다.

classNames = categories(YTrain);
scores = minibatchpredict(net,XTrain,"InputDataFormats","CTB");
trainPred = scores2label(scores,classNames);

분류 문제에서는 실제 값이 알려져 있는 데이터 세트에 대한 분류기의 성능을 시각화하기 위해 혼동행렬이 사용됩니다. 목표 클래스는 신호의 ground-truth 레이블이고, 출력 클래스는 신경망이 신호에 할당한 레이블입니다. 좌표축 레이블은 클래스 레이블, 즉 AFib(A)와 정상(N)을 나타냅니다.

confusionchart 명령을 사용하여 테스트 데이터 예측값에 대한 전체적인 분류 정확도를 계산합니다. RowSummary를 "row-normalized"로 지정하여 행 요약에 참양성률과 거짓양성률을 표시합니다. 또한, ColumnSummary를 "column-normalized"로 지정하여 열 요약에 양성예측도와 오발견율을 표시합니다.

LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 99.0335
figure
confusionchart(YTrain,trainPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

이번에는 같은 신경망을 사용하여 테스트 데이터를 분류합니다.

scores = minibatchpredict(net,XTest,InputDataFormats="CTB");
testPred = scores2label(scores,classNames);

테스트 정확도를 계산하고, 분류 성능을 혼동행렬로 시각화합니다.

LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 61.1222
figure
confusionchart(YTest,testPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

두 번째 시도: 특징 추출을 사용하여 성능 개선하기

데이터에서 특징을 추출하여 분류기의 성능을 개선할 수 있습니다. 어느 특징을 추출할지 판단하기 위해 이 예제에서는 스펙트로그램과 같은 시간-주파수 영상을 계산하는 방법을 조정하며, 이를 사용하여 컨벌루션 신경망(CNN)을 훈련시킵니다 [4], [5].

각 신호 유형의 스펙트로그램을 시각화합니다.

fs = 300;

figure
subplot(2,1,1);
pspectrum(normal,fs,'spectrogram','TimeResolution',0.5)
title('Normal Signal')

subplot(2,1,2);
pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5)
title('AFib Signal')

이 예제에서는 CNN 대신 LSTM을 사용하므로 접근 방법을 1차원 신호에 적용되도록 변환하는 것이 중요합니다. 시간-주파수(TF) 모멘트는 스펙트로그램에서 정보를 추출합니다. 각 모멘트는 LSTM에 입력할 1차원 특징으로 사용할 수 있습니다.

시간 영역에서 두 개의 TF 모멘트를 검토합니다.

  • 순시 주파수(instfreq)

  • 스펙트럼 엔트로피(pentropy)

instfreq 함수는 신호의 시간 종속 주파수를 파워 스펙트로그램의 1차 모멘트로 추정합니다. 함수는 시간 윈도우에 대해 단시간 푸리에 변환을 사용하여 스펙트로그램을 계산합니다. 이 예제에서는 255개의 시간 윈도우를 사용합니다. 함수의 시간 출력값은 시간 윈도우의 중앙에 대응됩니다.

각 신호 유형에 대해 순시 주파수를 시각화합니다.

[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);

figure
subplot(2,1,1);
plot(tN,instFreqN)
title('Normal Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

subplot(2,1,2);
plot(tA,instFreqA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

cellfun을 사용하여 훈련 세트와 테스트 세트의 모든 셀에 instfreq 함수를 적용합니다.

instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,'UniformOutput',false);
instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,'UniformOutput',false);
instfreqValid = cellfun(@(x)instfreq(x,fs)',XValid,'UniformOutput',false);

스펙트럼 엔트로피는 신호의 스펙트럼이 얼마나 뾰족하거나 평탄한지 측정합니다. 정현파의 합과 같이 스펙트럼이 뾰족한 신호는 스펙트럼 엔트로피가 낮습니다. 백색 잡음과 같이 스펙트럼이 평탄한 신호는 스펙트럼 엔트로피가 높습니다. pentropy 함수는 파워 스펙트로그램을 바탕으로 스펙트럼 엔트로피를 추정합니다. 순시 주파수 추정의 경우와 마찬가지로, pentropy는 255개의 시간 윈도우를 사용하여 스펙트로그램을 계산합니다. 함수의 시간 출력값은 시간 윈도우의 중앙에 대응됩니다.

각 신호 유형에 대해 스펙트럼 엔트로피를 시각화합니다.

[pentropyA,tA2] = pentropy(aFib,fs);
[pentropyN,tN2] = pentropy(normal,fs);

figure
subplot(2,1,1)
plot(tN2,pentropyN)
title('Normal Signal')
ylabel('Spectral Entropy')

subplot(2,1,2)
plot(tA2,pentropyA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Spectral Entropy')

cellfun을 사용하여 훈련 세트, 테스트 세트, 검증 세트의 모든 셀에 pentropy 함수를 적용합니다.

pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,'UniformOutput',false);
pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,'UniformOutput',false);
pentropyValid = cellfun(@(x)pentropy(x,fs)',XValid,'UniformOutput',false);

새 훈련 세트와 테스트 세트의 각 셀이 2개의 차원, 즉 2개의 특징을 갖도록 특징들을 결합합니다.

XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,'UniformOutput',false);
XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);
XValid2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);

새 입력값의 형식을 시각화합니다. 각 셀은 더 이상 길이가 샘플 9000개인 신호를 포함하지 않습니다. 이제 각 길이가 샘플 255개인 2개의 특징을 포함합니다.

XTrain2(1:5)
ans=5×1 cell array
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}

데이터 표준화하기

순시 주파수와 스펙트럼 엔트로피의 평균값은 거의 10배 차이가 납니다. 또한, 순시 주파수의 평균은 LSTM이 효과적으로 학습을 진행하기에는 지나치게 높을 수 있습니다. 신경망이 평균이 크고 값의 범위가 큰 데이터에 대해 피팅된 경우, 큰 입력값은 신경망의 학습과 수렴 속도를 저하할 수 있습니다 [6].

mean(instFreqN)
ans = 5.5551
mean(pentropyN)
ans = 0.6324

훈련 세트의 평균과 표준편차를 사용하여 훈련 세트, 테스트 세트, 검증 세트를 표준화합니다. 표준화(z-점수화)는 훈련 중에 신경망 성능을 개선하는 방법으로 널리 사용됩니다.

XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);

XTrainSD = XTrain2;
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,'UniformOutput',false);
XValidSD = XValid2;
XValidSD = cellfun(@(x)(x-mu)./sg,XValidSD,'UniformOutput',false);
XTestSD = XTest2;
XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,'UniformOutput',false);

표준화된 순시 주파수와 스펙트럼 엔트로피의 평균을 표시합니다.

instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);

mean(instFreqNSD)
ans = 0.1544
mean(pentropyNSD)
ans = 0.1935

LSTM 신경망 아키텍처 수정하기

이제 신호가 각각 2개의 차원을 갖습니다. 따라서 입력 시퀀스의 크기를 2로 지정하여 신경망 아키텍처를 수정해야 합니다. 출력 크기가 100인 양방향 LSTM 계층을 지정하고 시퀀스의 마지막 요소를 출력합니다. 크기가 2인 완전 연결 계층을 포함하여 2개의 클래스를 지정하고, 이어서 소프트맥스 계층을 지정합니다.

layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]
layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 2 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax

훈련 옵션을 지정합니다. 신경망이 훈련 데이터를 120번 통과하도록 최대 Epoch 횟수를 120으로 설정합니다.

options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize', 200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'ExecutionEnvironment','auto',...
    'plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValidSD,YValid}, ...
    'OutputNetwork','last-iteration', ...
    'Verbose',false);

시간-주파수 특징으로 LSTM 신경망 훈련시키기

trainnet를 사용하여 지정된 훈련 옵션과 계층 아키텍처로 LSTM 신경망을 훈련시킵니다.

net2 = trainnet(XTrainSD,YTrain,layers,"crossentropy",options);

TF 모멘트가 원시 시퀀스보다 짧기 때문에 훈련에 필요한 시간이 줄어듭니다.

훈련 및 테스트 정확도 시각화하기

업데이트된 LSTM 신경망을 사용하여 훈련 데이터를 분류합니다. 분류 성능을 혼동행렬로 시각화합니다.

scores = minibatchpredict(net2,XTrainSD,"InputDataFormats","CTB");
trainPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 96.3600
figure
confusionchart(YTrain,trainPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

업데이트된 신경망을 사용하여 테스트 데이터를 분류합니다. 혼동행렬을 플로팅하여 테스트 정확도를 검토합니다.

scores = minibatchpredict(net2,XTestSD,InputDataFormats="CTB");
testPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 93.2866
figure
confusionchart(YTest,testPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

결론

이 예제에서는 LSTM 신경망을 사용하여 심전도 신호에서 심방세동을 검출하는 분류기를 만드는 방법을 소개했습니다. 이 절차에서는 주로 건강한 환자로 구성된 데이터에서 이상 조건을 검출하고자 할 때 발생하는 분류 편향을 방지하기 위해 오버샘플링을 사용했습니다. 원시 신호 데이터를 사용하여 LSTM 신경망을 훈련시키면 분류 정확도가 매우 낮아집니다. 각 신호에 대해 2개의 시간-주파수 모멘트 특징을 사용하여 신경망을 훈련시키면 분류 성능이 크게 개선되고 훈련 시간도 줄어듭니다.

참고 문헌

[1] AF Classification from a Short Single Lead ECG Recording: the PhysioNet/Computing in Cardiology Challenge, 2017. https://physionet.org/challenge/2017/

[2] Clifford, Gari, Chengyu Liu, Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Qiao Li, Alistair Johnson, and Roger G. Mark. "AF Classification from a Short Single Lead ECG Recording: The PhysioNet Computing in Cardiology Challenge 2017." Computing in Cardiology (Rennes: IEEE). Vol. 44, 2017, pp. 1–4.

[3] Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals". Circulation. Vol. 101, No. 23, 13 June 2000, pp. e215–e220. http://circ.ahajournals.org/content/101/23/e215.full

[4] Pons, Jordi, Thomas Lidy, and Xavier Serra. "Experimenting with Musically Motivated Convolutional Neural Networks". 14th International Workshop on Content-Based Multimedia Indexing (CBMI). June 2016.

[5] Wang, D. "Deep learning reinvents the hearing aid," IEEE Spectrum, Vol. 54, No. 3, March 2017, pp. 32–37. doi: 10.1109/MSPEC.2017.7864754.

[6] Brownlee, Jason. How to Scale Data for Long Short-Term Memory Networks in Python. 7 July 2017. https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.

참고 항목

함수

관련 항목