Simulink에서 분류하고 신경망 상태 업데이트하기
이 예제에서는 Stateful Classify
블록을 사용하여 Simulink®에서 훈련된 순환 신경망의 데이터를 분류하는 방법을 보여줍니다. 이 예제에서는 사전 훈련된 장단기 기억(LSTM) 신경망을 사용합니다.
사전 훈련된 신경망 불러오기
사전 훈련된 장단기 기억(LSTM) 신경망 JapaneseVowelsNet
을 불러옵니다. 이것은 [1]과 [2]에서 설명한 Japanese Vowels 데이터 세트에서 훈련된 신경망입니다. 이 신경망은 미니 배치 크기 27을 가지며 시퀀스 길이를 기준으로 정렬된 시퀀스에서 훈련되었습니다.
load JapaneseVowelsNet
신경망 아키텍처를 표시합니다.
analyzeNetwork(net);
테스트 데이터 불러오기
Japanese Vowels 테스트 데이터를 불러옵니다. XTest
는 12개 차원으로 된 서로 다른 길이의 시퀀스 370개를 포함하는 셀형 배열입니다. TTest
는 9명의 화자에 대응하는 레이블 "1","2",..."9"로 구성된 categorical형 벡터입니다.
타임스탬프가 지정된 행과 X
의 반복되는 복사본을 포함하는 timetable형 배열 simin
을 생성합니다.
load JapaneseVowelsTestData; X = XTest{94}; numTimeSteps = size(X,2); simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));
데이터 분류를 위한 Simulink 모델
데이터 분류를 위한 Simulink 모델은 레이블을 예측하는 Stateful Classify
블록과 시간 스텝에 대한 입력 데이터 시퀀스를 불러오는 From Workspace
블록을 포함합니다.
시뮬레이션 중에 순환 신경망의 상태를 초기 상태로 재설정하려면 Stateful Classify
블록을 Resettable Subsystem
내부에 배치하고 Reset
제어 신호를 트리거로 사용하십시오.
open_system('StatefulClassifyExample');
시뮬레이션을 위한 모델 구성하기
Stateful Classify
블록의 모델 구성 파라미터를 설정합니다.
set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulClassifyExample','SimulationMode','Normal');
시뮬레이션 실행하기
JapaneseVowelsNet
신경망에 대한 응답 변수를 계산하기 위해 시뮬레이션을 실행합니다. 예측 레이블은 MATLAB® 작업 공간에 저장됩니다.
out = sim('StatefulClassifyExample');
예측 레이블을 계단 플롯으로 플로팅합니다. 이 플롯은 시간 스텝 간의 예측 변화를 보여줍니다.
labels = squeeze(out.YPred.Data(1:numTimeSteps,1)); figure stairs(labels, '-o') xlim([1 numTimeSteps]) xlabel("Time Step") ylabel("Predicted Class") title("Classification Over Time Steps")
예측 레이블과 참 레이블을 비교합니다. 관측값의 참 레이블을 보여주는 가로선을 플로팅합니다.
trueLabel = double(TTest(94)); hold on line([1 numTimeSteps],[trueLabel trueLabel], ... 'Color','red', ... 'LineStyle','--') legend(["Prediction" "True Label"]) axis([1 numTimeSteps+1 0 9]);
참고 문헌
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
참고 항목
Stateful Predict | Stateful Classify | Predict | Image Classifier