Main Content

LSTM 신경망의 활성화 시각화

이 예제에서는 활성화를 추출하여 LSTM 신경망에서 학습한 특징을 검사하고 시각화하는 방법을 보여줍니다.

사전 훈련된 신경망을 불러옵니다. JapaneseVowelsNet은 [1]과 [2]에서 설명한 Japanese Vowels 데이터셋에서 사전 훈련된 LSTM 신경망입니다. 이 데이터 세트는 미니 배치 크기 27을 가지며 시퀀스 길이를 기준으로 정렬된 시퀀스에서 훈련되었습니다.

load JapaneseVowelsNet

신경망 아키텍처를 표시합니다.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

테스트 데이터를 불러옵니다.

[XTest,YTest] = japaneseVowelsTestData;

첫 번째 시계열을 플롯으로 시각화합니다. 선은 각각 하나의 특징에 대응됩니다.

X = XTest{1};

figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes object. The axes object with title Test Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

시퀀스의 각 시간 스텝마다 해당 시간 스텝의 LSTM 계층(계층 2)에 의한 활성화 출력을 가져오고 신경망 상태를 업데이트합니다.

sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    features(:,i) = activations(net,X(:,i),idxLayer);
    [net, YPred(i)] = classifyAndUpdateState(net,X(:,i));
end

히트맵을 사용하여 첫 번째 10개의 은닉 유닛을 시각화합니다.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure contains an object of type heatmap. The chart of type heatmap has title LSTM Activations.

히트맵은 각 은닉 유닛이 얼마나 강하게 활성화되는지, 그리고 시간이 지남에 따라 활성화가 어떻게 변하는지 보여줍니다.

참고 문헌

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

참고 항목

| | | | |

관련 항목