이 페이지의 최신 내용은 아직 번역되지 않았습니다. 최신 내용은 영문으로 볼 수 있습니다.

장단기 기억 네트워크를 사용하여 심전도 신호 분류하기

이 예제에서는 심층 학습과 신호 처리를 사용하여 PhysioNet 2017 Challenge의 심전도(ECG) 데이터를 분류하는 방법을 보여줍니다. 특히, 이 예제에서는 장단기 기억(LSTM) 네트워크와 시간-주파수 분석을 사용합니다.

소개

심전도는 일정 시간 동안 사람 심장의 전기적 활성을 기록합니다. 의사들은 심전도를 사용하여 환자의 심박이 정상인지 불규칙적인지를 시각적으로 판단합니다.

심방세동(AFib)은 심장의 심방이 심실과 다른 주기로 뛸 때 발생하는 불규칙적인 심박입니다.

이 예제에서는 PhysioNet 2017 Challenge [1], [2], [3]의 심전도 데이터를 사용합니다. 이 데이터는 https://physionet.org/challenge/2017/에서 확인할 수 있습니다. 데이터는 300Hz에서 샘플링되고 전문가 그룹이 네 가지 클래스로 분류한 심전도 신호로 구성되어 있습니다. 이들 네 가지 클래스는 정상(N), AFib(A), 기타 리듬(O), 잡음이 있는 기록(~)입니다. 이 예제에서는 심층 학습을 사용하여 분류 공정을 자동화하는 방법을 보여줍니다. 이 절차에서는 정상 심전도 신호를 AFib 증상을 보이는 신호와 구분할 수 있는 이진 분류기를 살펴봅니다.

이 예제에서는 시퀀스 데이터와 시계열 데이터를 연구하는 데 적합한 일종의 순환 신경망(RNN)인 장단기 기억(LSTM) 네트워크를 사용합니다. LSTM 네트워크는 시퀀스의 시간 스텝 간의 장기적인 종속성을 학습할 수 있습니다. LSTM 계층(lstmLayer)은 순방향의 시간 시퀀스를 살펴볼 수 있고, 양방향 LSTM 계층(bilstmLayer)은 순방향과 역방향의 시간 시퀀스를 살펴볼 수 있습니다. 이 예제에서는 양방향 LSTM 계층을 사용합니다.

훈련 과정을 가속화하려면 GPU가 있는 시스템에서 이 예제를 실행하십시오. 시스템에 GPU와 Parallel Computing Toolbox™가 있는 경우 MATLAB®은 자동으로 훈련에 GPU를 사용하고, 그렇지 않은 경우 CPU를 사용합니다.

데이터를 불러오고 검토하기

ReadPhysionetData 스크립트를 실행하여 PhysioNet 웹 사이트에서 데이터를 다운로드하고 적절한 형식의 심전도 신호를 포함하는 MAT 파일(PhysionetData.mat)을 생성합니다. 데이터를 다운로드하는 데 몇 분 정도 걸릴 수 있습니다.

ReadPhysionetData
load PhysionetData

불러오기 작업은 작업 공간에 두 개의 변수 SignalsLabels를 추가합니다. Signals는 심전도 신호를 포함하는 셀형 배열입니다. Labels는 신호의 대응되는 실측 레이블을 포함하는 categorical형 배열입니다.

Signals(1:5)
ans = 5×1 cell array
    {1×9000  double}
    {1×9000  double}
    {1×18000 double}
    {1×9000  double}
    {1×18000 double}

Labels(1:5)
ans = 5×1 categorical array
     N 
     N 
     N 
     A 
     A 

summary 함수를 사용하여 738개의 AFib 신호와 5050개의 정상 신호가 있는 것을 확인합니다.

summary(Labels)
     A       738 
     N      5050 

신호 길이에 대한 히스토그램을 생성합니다. 대부분의 신호의 길이가 9000개 샘플인 것을 볼 수 있습니다.

L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')

각 클래스별로 신호 하나의 세그먼트를 시각화합니다. AFib 심박은 불규칙한 간격으로 샘플이 배치되는 반면 정상 심박은 규칙적인 간격을 갖습니다. AFib 심박 신호에는 정상 심박 신호에서 QRS 복합파 전에 뛰는 P파가 없는 경우가 많습니다. 정상 신호의 플롯에서는 P파와 QRS 복합파를 볼 수 있습니다.

normal = Signals{1};
aFib = Signals{4};

subplot(2,1,1)
plot(normal)
title('Normal Rhythm')
xlim([4000,5200])
ylabel('Amplitude (mV)')
text(4330,150,'P','HorizontalAlignment','center')
text(4370,850,'QRS','HorizontalAlignment','center')

subplot(2,1,2)
plot(aFib)
title('Atrial Fibrillation')
xlim([4000,5200])
xlabel('Samples')
ylabel('Amplitude (mV)')

훈련을 위해 데이터 준비하기

훈련 중에 trainNetwork 함수는 데이터를 미니 배치로 분할합니다. 이 함수는 그런 다음 동일한 미니 배치의 신호가 모두 동일한 길이를 갖도록 채우거나 자릅니다. 추가되었거나 제거된 정보를 기반으로 하여 네트워크에서 신호를 올바르지 않게 해석할 수 있기 때문에 너무 많이 채우거나 자르면 네트워크 성능이 저하될 수 있습니다.

과도한 채우기나 자르기를 방지하려면 심전도 신호의 길이가 모두 9000개 샘플이 되도록 심전도 신호에 segmentSignals 함수를 적용하십시오. 이 함수는 샘플이 9000개 미만인 신호를 무시합니다. 신호의 샘플이 9000개보다 많은 경우 segmentSignals는 신호를 분할해 9000개 샘플 크기의 세그먼트가 가능한 한 가장 많이 생성되게 하고 나머지 샘플은 무시합니다. 예를 들어, 18500개의 샘플로 구성된 신호는 9000개의 샘플로 구성된 신호 2개가 되고 나머지 500개의 샘플은 무시됩니다.

[Signals,Labels] = segmentSignals(Signals,Labels);

Signals 배열의 처음 5개 요소를 보고 각 항목의 길이가 이제 9000개 샘플이 된 것을 확인합니다.

Signals(1:5)
ans = 5×1 cell array
    {1×9000 double}
    {1×9000 double}
    {1×9000 double}
    {1×9000 double}
    {1×9000 double}

원시 신호 데이터를 사용하여 분류기 훈련시키기

분류기를 설계하려면 이전 섹션에서 생성된 원시 신호를 사용하십시오. 신호를 분류기 훈련을 위한 훈련 세트와 새 데이터에 대한 분류기의 정확도 테스트를 위한 테스트 세트로 분할합니다.

summary 함수를 사용하여 718개의 AFib 신호와 4937개의 정상 신호가 있는 것을 표시합니다. AFib 신호와 정상 신호의 비는 1:7입니다.

summary(Labels)
     A       718 
     N      4937 

신호의 87.3%가 정상 신호이므로 분류기는 모든 신호를 정상으로 분류하면 높은 정확도를 달성할 수 있음을 학습하게 됩니다. 이러한 편향을 방지하려면 정상 신호와 AFib 신호의 개수가 같아지도록 데이터셋에서 AFib 신호를 복제하여 AFib 신호의 개수를 늘리십시오. 일반적으로 오버샘플링이라고 하는 이러한 복제 기법은 심층 학습에서 사용되는 데이터 증대의 한 가지 형태입니다.

신호를 클래스에 따라 분할합니다.

afibX = Signals(Labels=='A');
afibY = Labels(Labels=='A');

normalX = Signals(Labels=='N');
normalY = Labels(Labels=='N');

다음으로, dividerand를 사용하여 각 클래스의 목표값을 훈련 세트와 세트로 무작위로 나눕니다.

[trainIndA,~,testIndA] = dividerand(718,0.9,0.0,0.1);
[trainIndN,~,testIndN] = dividerand(4937,0.9,0.0,0.1);

XTrainA = afibX(trainIndA);
YTrainA = afibY(trainIndA);

XTrainN = normalX(trainIndN);
YTrainN = normalY(trainIndN);

XTestA = afibX(testIndA);
YTestA = afibY(testIndA);

XTestN = normalX(testIndN);
YTestN = normalY(testIndN);

이제 훈련에 사용할 646개의 AFib 신호와 4443개의 정상 신호가 있습니다. 각 클래스에서 동일한 개수의 신호가 있도록 하려면 처음 4438개의 정상 신호를 사용한 다음 repmat를 사용하여 처음 634개의 AFib 신호가 7번 반복되도록 복제하십시오.

테스트에 사용할 신호는 72개의 AFib 신호와 494개의 정상 신호가 있습니다. 처음 490개의 정상 신호를 사용한 다음 repmat를 사용하여 처음 70개의 AFib 신호가 7번 반복되도록 복제하십시오. 기본적으로, 신경망은 연속된 신호가 모두 동일한 레이블을 갖지 않도록 하기 위해 훈련 전에 데이터를 무작위로 섞습니다.

XTrain = [repmat(XTrainA(1:634),7,1); XTrainN(1:4438)];
YTrain = [repmat(YTrainA(1:634),7,1); YTrainN(1:4438)];

XTest = [repmat(XTestA(1:70),7,1); XTestN(1:490)];
YTest = [repmat(YTestA(1:70),7,1); YTestN(1:490);];

이제 훈련 세트와 테스트 세트에서 정상 신호와 AFib 신호 사이의 분포가 균일합니다.

summary(YTrain)
     A      4438 
     N      4438 
summary(YTest)
     A      490 
     N      490 

LSTM 네트워크 아키텍처 정의하기

LSTM 네트워크는 시퀀스 데이터의 시간 스텝 간의 장기적인 종속성을 학습할 수 있습니다. 이 예제에서는 순방향과 역방향의 시퀀스를 모두 살펴보므로 양방향 LSTM 계층 bilstmLayer를 사용합니다.

입력 신호가 각각 1개의 차원을 가지므로 입력 크기가 크기 1로 구성된 시퀀스가 되도록 지정합니다. 출력 크기가 100인 양방향 LSTM 계층을 지정하고 시퀀스의 마지막 요소를 출력합니다. 다음 명령은 입력 시계열을 100개의 특징으로 매핑하도록 양방향 LSTM 계층에 지시하고 완전 연결 계층에 대한 출력값을 준비합니다. 마지막으로, 크기가 2인 완전 연결 계층을 포함하여 2개의 클래스를 지정하고, 이어서 소프트맥스 계층과 분류 계층을 지정합니다.

layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(100,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         2 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

다음은 분류기의 훈련 옵션을 지정합니다. 네트워크가 훈련 데이터를 10번 통과하도록 'MaxEpochs'를 10으로 설정합니다. 네트워크가 한 번에 150개의 훈련 신호를 살펴보도록 'MiniBatchSize'를 150으로 설정합니다. 'InitialLearnRate'를 0.01로 설정하면 훈련 과정의 속도를 높이는 데 도움이 됩니다. 시스템이 한 번에 지나치게 많은 데이터를 살펴봄으로써 메모리가 부족해지지 않도록 'SequenceLength'를 1000으로 지정하여 신호를 더 작은 조각으로 분할합니다. 기울기가 지나치게 커지지 않도록 'GradientThreshold'를 1로 설정하여 훈련 과정을 안정화합니다. 'Plots''training-progress'로 지정하여 반복 횟수가 늘어남에 따라 훈련 진행 상황을 그래픽으로 표시하는 플롯을 생성합니다. 'Verbose'false로 설정하여 플롯에 표시되는 데이터에 대응되는 표의 형태로 출력값이 표시되지 않도록 합니다. 이 표를 보려면 'Verbose'true로 설정하십시오.

이 예제에서는 ADAM(적응적 모멘트 추정) 솔버를 사용합니다. ADAM은 LSTM과 같은 순환 신경망(RNN)에서 디폴트 값인 모멘텀을 사용한 확률적 경사하강법(SGDM) 솔버보다 더 나은 성능을 보입니다.

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'MiniBatchSize', 150, ...
    'InitialLearnRate', 0.01, ...
    'SequenceLength', 1000, ...
    'GradientThreshold', 1, ...
    'ExecutionEnvironment',"auto",...
    'plots','training-progress', ...
    'Verbose',false);

LSTM 네트워크 훈련시키기

trainNetwork를 사용하여 지정된 훈련 옵션과 계층 아키텍처로 LSTM 네트워크를 훈련시킵니다. 훈련 세트가 크기 때문에 훈련 과정에 몇 분 정도 걸릴 수 있습니다.

net = trainNetwork(XTrain,YTrain,layers,options);

훈련 진행 과정 플롯의 상단 서브플롯은 훈련 정확도, 즉 각 미니 배치의 분류 정확도를 나타냅니다. 훈련이 성공적으로 진행되면 이 값은 일반적으로 100%를 향해 늘어납니다. 하단 서브플롯에는 훈련 손실, 즉 각 미니 배치의 교차 엔트로피 손실이 표시됩니다. 훈련이 성공적으로 진행되면 이 값은 일반적으로 0을 향해 줄어듭니다.

훈련이 수렴되지 않을 경우 플롯이 특정 상향 또는 하향 방향으로 진행되는 대신 특정 값 사이를 진동할 수 있습니다. 이러한 진동은 훈련 정확도가 증가하지 않고 있으며 훈련 손실이 감소하지 않고 있음을 의미합니다. 이 상황은 훈련 시작부터 발생할 수도 있고, 또는 훈련 정확도가 초기에는 개선되다가 이어서 플롯이 일정해진 경우일 수도 있습니다. 대부분의 경우, 훈련 옵션을 변경하면 네트워크가 수렴하는 데 도움이 됩니다. MiniBatchSize를 줄이거나 InitialLearnRate를 줄이면 훈련 시간이 늘어나긴 하지만 네트워크가 더욱 잘 학습하는 데 도움이 될 수 있습니다.

분류기의 훈련 정확도는 50%와 60% 사이를 진동합니다. 그리고 10회의 Epoch를 마친 후에 이미 훈련에 수 분이 소요된 것을 볼 수 있습니다.

훈련 및 테스트 정확도 시각화하기

훈련이 진행된 신호에 대한 분류기의 정확도를 나타내는 훈련 정확도를 계산합니다. 먼저 훈련 데이터를 분류합니다.

trainPred = classify(net,XTrain,'SequenceLength',1000);

분류 문제에서는 실제 값이 알려져 있는 데이터 세트에 대한 분류기의 성능을 시각화하기 위해 정오분류표가 사용됩니다. 목표 클래스는 신호의 실측 레이블이고, 출력 클래스는 네트워크가 신호에 할당한 레이블입니다. 좌표축 레이블은 클래스 레이블, 즉 AFib(A)와 정상(N)을 나타냅니다.

confusionchart 명령을 사용하여 테스트 데이터 예측값에 대한 전체적인 분류 정확도를 계산합니다. 'RowSummary''row-normalized'로 지정하여 행 요약에 참양성률과 거짓양성률을 표시합니다. 또한, 'ColumnSummary''column-normalized'로 지정하여 열 요약에 양성예측도와 오발견율을 표시합니다.

LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 61.2100
figure
ccLSTM = confusionchart(YTrain,trainPred);
ccLSTM.Title = 'Confusion Chart for LSTM';
ccLSTM.ColumnSummary = 'column-normalized';
ccLSTM.RowSummary = 'row-normalized';

정오분류표를 보면 실측 AFib 신호의 81.7%가 AFib로 올바르게 분류되었고 실측 정상 신호의 31.1%가 정상으로 올바르게 분류되었음을 알 수 있습니다. 아울러 AFib로 분류된 신호의 54.2%가 실제로 AFib이고, 정상으로 분류된 신호의 63.0%가 실제로 정상임을 알 수 있습니다. 전체적인 훈련 정확도는 56.4%입니다.

이번에는 같은 네트워크를 사용하여 테스트 데이터를 분류합니다.

testPred = classify(net,XTest,'SequenceLength',1000);

테스트 정확도를 계산하고, 분류 성능을 정오분류표로 시각화합니다.

LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 61.4286
figure
ccLSTM = confusionchart(YTest,testPred);
ccLSTM.Title = 'Confusion Chart for LSTM';
ccLSTM.ColumnSummary = 'column-normalized';
ccLSTM.RowSummary = 'row-normalized';

이 정오분류표는 훈련 정오분류표와 비슷합니다. 전체적인 테스트 정확도는 55.8%입니다.

특징 추출을 사용하여 성능 개선하기

데이터에서 특징을 추출하여 분류기의 훈련 및 테스트 정확도를 개선할 수 있습니다. 어느 특징을 추출할지 판단하기 위해 이 예제에서는 스펙트로그램과 같은 시간-주파수 이미지를 계산하는 방법을 따르며, 이를 사용하여 컨벌루션 신경망(CNN)을 훈련시킵니다. 이러한 응용 사례에 대한 자세한 내용은 [4]와 [5]를 참조하십시오.

각 신호 유형의 스펙트로그램을 시각화합니다.

fs = 300;

figure
subplot(2,1,1);
pspectrum(normal,fs,'spectrogram','TimeResolution',0.5)
title('Normal Signal')

subplot(2,1,2);
pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5)
title('AFib Signal')

이 예제에서는 CNN 대신 LSTM을 사용하므로 접근 방법을 1차원 신호에 적용되도록 변환하는 것이 중요합니다. 시간-주파수(TF) 적률은 스펙트로그램에서 정보를 추출합니다. 각 적률은 LSTM에 입력할 1차원 특징으로 사용할 수 있습니다.

시간 영역에서 두 개의 TF 적률을 검토합니다.

  • 순시 주파수(instfreq)

  • 스펙트럼 엔트로피(pentropy)

instfreq 함수는 신호의 시간 종속 주파수를 전력 스펙트로그램의 첫 번째 적률로 추정합니다. 함수는 시간 윈도우에 대해 단시간 푸리에 변환을 사용하여 스펙트로그램을 계산합니다. 이 예제에서는 255개의 시간 윈도우를 사용합니다. 함수의 시간 출력값은 시간 윈도우의 중앙에 대응됩니다.

각 신호 유형에 대해 순시 주파수를 시각화합니다.

[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);

figure
subplot(2,1,1);
plot(tN,instFreqN)
title('Normal Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

subplot(2,1,2);
plot(tA,instFreqA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

cellfun 을 사용하여 훈련 세트와 테스트 세트의 모든 셀에 instfreq 함수를 적용합니다.

instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,'UniformOutput',false);
instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,'UniformOutput',false);

스펙트럼 엔트로피는 신호의 스펙트럼이 얼마나 뾰족하거나 평탄한지 측정합니다. 정현파의 합과 같이 스펙트럼이 뾰족한 신호는 스펙트럼 엔트로피가 낮습니다. 백색 잡음과 같이 스펙트럼이 평탄한 신호는 스펙트럼 엔트로피가 높습니다. pentropy 함수는 전력 스펙트로그램을 바탕으로 스펙트럼 엔트로피를 추정합니다. 순시 주파수 추정의 경우와 마찬가지로, pentropy는 255개의 시간 윈도우를 사용하여 스펙트로그램을 계산합니다. 함수의 시간 출력값은 시간 윈도우의 중앙에 대응됩니다.

각 신호 유형에 대해 스펙트럼 엔트로피를 시각화합니다.

[pentropyA,tA2] = pentropy(aFib,fs);
[pentropyN,tN2] = pentropy(normal,fs);

figure

subplot(2,1,1)
plot(tN2,pentropyN)
title('Normal Signal')
ylabel('Spectral Entropy')

subplot(2,1,2)
plot(tA2,pentropyA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Spectral Entropy')

cellfun을 사용하여 훈련 세트와 테스트 세트의 모든 셀에 pentropy 함수를 적용합니다.

pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,'UniformOutput',false);
pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,'UniformOutput',false);

새 훈련 세트와 테스트 세트의 각 셀이 2개의 차원, 즉 2개의 특징을 갖도록 특징들을 결합합니다.

XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,'UniformOutput',false);
XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);

새 입력값의 형식을 시각화합니다. 각 셀은 더 이상 길이가 샘플 9000개인 신호를 포함하지 않습니다. 이제 각 길이가 샘플 255개인 2개의 특징을 포함합니다.

XTrain2(1:5)
ans = 5×1 cell array
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}

데이터 표준화하기

순시 주파수와 스펙트럼 엔트로피의 평균값은 거의 10배 차이가 납니다. 또한, 순시 주파수의 평균은 LSTM이 효과적으로 학습을 진행하기에는 지나치게 높을 수 있습니다. 네트워크가 평균이 크고 값의 범위가 큰 데이터에 대해 피팅된 경우, 큰 입력값은 네트워크의 학습과 수렴 속도를 저하할 수 있습니다 [6].

mean(instFreqN)
ans = 5.5615
mean(pentropyN)
ans = 0.6326

훈련 세트의 평균과 표준편차를 사용하여 훈련 세트와 테스트 세트를 표준화합니다. 표준화(z-점수화)는 훈련 중에 네트워크 성능을 개선하는 방법으로 널리 사용됩니다.

XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);

XTrainSD = XTrain2;
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,'UniformOutput',false);

XTestSD = XTest2;
XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,'UniformOutput',false);

표준화된 순시 주파수와 스펙트럼 엔트로피의 평균을 표시합니다.

instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);

mean(instFreqNSD)
ans = -0.3210
mean(pentropyNSD)
ans = -0.2416

LSTM 네트워크 아키텍처 수정하기

이제 신호가 각각 2개의 차원을 갖습니다. 따라서 입력 시퀀스의 크기를 2로 지정하여 네트워크 아키텍처를 수정해야 합니다. 출력 크기가 100인 양방향 LSTM 계층을 지정하고 시퀀스의 마지막 요소를 출력합니다. 크기가 2인 완전 연결 계층을 포함하여 2개의 클래스를 지정하고, 이어서 소프트맥스 계층과 분류 계층을 지정합니다.

layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(100,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 2 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         2 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

훈련 옵션을 지정합니다. 네트워크가 훈련 데이터를 30번 통과하도록 최대 Epoch 횟수를 30으로 설정합니다.

options = trainingOptions('adam', ...
    'MaxEpochs',30, ...
    'MiniBatchSize', 150, ...
    'InitialLearnRate', 0.01, ...
    'GradientThreshold', 1, ...
    'ExecutionEnvironment',"auto",...
    'plots','training-progress', ...
    'Verbose',false);

시간-주파수 특징으로 LSTM 네트워크 훈련시키기

trainNetwork를 사용하여 지정된 훈련 옵션과 계층 아키텍처로 LSTM 네트워크를 훈련시킵니다.

net2 = trainNetwork(XTrainSD,YTrain,layers,options);

훈련 정확도가 크게 개선되어 이제 90%를 상회합니다. 교차 엔트로피 손실은 0을 향해가고 있습니다. 또한, TF 적률이 원시 시퀀스보다 짧기 때문에 훈련에 필요한 시간이 줄어듭니다.

훈련 및 테스트 정확도 시각화하기

업데이트된 LSTM 네트워크를 사용하여 훈련 데이터를 분류합니다. 분류 성능을 정오분류표로 시각화합니다.

trainPred2 = classify(net2,XTrainSD);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 83.5174
figure
ccLSTM = confusionchart(YTrain,trainPred2);
ccLSTM.Title = 'Confusion Chart for LSTM';
ccLSTM.ColumnSummary = 'column-normalized';
ccLSTM.RowSummary = 'row-normalized';

업데이트된 네트워크를 사용하여 테스트 데이터를 분류합니다. 정오분류표를 플로팅하여 테스트 정확도를 검토합니다.

testPred2 = classify(net2,XTestSD);

LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 82.5510
figure
ccLSTM = confusionchart(YTest,testPred2);
ccLSTM.Title = 'Confusion Chart for LSTM';
ccLSTM.ColumnSummary = 'column-normalized';
ccLSTM.RowSummary = 'row-normalized';

결론

이 예제에서는 LSTM 네트워크를 사용하여 심전도 신호에서 심방세동을 검출하는 분류기를 만드는 방법을 소개했습니다. 이 절차에서는 주로 건강한 환자로 구성된 데이터에서 이상 조건을 검출하고자 할 때 발생하는 분류 편향을 방지하기 위해 오버샘플링을 사용했습니다. 원시 신호 데이터를 사용하여 LSTM 네트워크를 훈련시키면 분류 정확도가 매우 낮아집니다. 각 신호에 대해 2개의 시간-주파수 적률 특징을 사용하여 네트워크를 훈련시키면 분류 성능이 크게 개선되고 훈련 시간도 줄어듭니다.

참고 문헌

[1] AF Classification from a Short Single Lead ECG Recording: the PhysioNet/Computing in Cardiology Challenge, 2017. https://physionet.org/challenge/2017/

[2] Clifford, Gari, Chengyu Liu, Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Qiao Li, Alistair Johnson, and Roger G. Mark. "AF Classification from a Short Single Lead ECG Recording: The PhysioNet Computing in Cardiology Challenge 2017." Computing in Cardiology (Rennes: IEEE). Vol. 44, 2017, pp. 1–4.

[3] Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals". Circulation. Vol. 101, No. 23, 13 June 2000, pp. e215–e220. http://circ.ahajournals.org/content/101/23/e215.full

[4] Pons, Jordi, Thomas Lidy, and Xavier Serra. "Experimenting with Musically Motivated Convolutional Neural Networks". 14th International Workshop on Content-Based Multimedia Indexing (CBMI). June 2016.

[5] Wang, D. "Deep learning reinvents the hearing aid," IEEE Spectrum, Vol. 54, No. 3, March 2017, pp. 32–37. doi: 10.1109/MSPEC.2017.7864754.

[6] Brownlee, Jason. How to Scale Data for Long Short-Term Memory Networks in Python. 7 July 2017. https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.

참고 항목

함수

관련 항목