Main Content

Simulink에서 예측을 수행하고 신경망 상태 업데이트하기

이 예제에서는 Stateful Predict 블록을 사용하여 Simulink®에서 훈련된 순환 신경망의 응답 변수 예측을 수행하는 방법을 보여줍니다. 이 예제에서는 사전 훈련된 장단기 기억(LSTM) 신경망을 사용합니다.

사전 훈련된 신경망 불러오기

사전 훈련된 장단기 기억(LSTM) 신경망 JapaneseVowelsNet을 불러옵니다. 이것은 [1]과 [2]에서 설명한 Japanese Vowels 데이터 세트에서 훈련된 신경망입니다. 이 신경망은 미니 배치 크기 27을 가지며 시퀀스 길이를 기준으로 정렬된 시퀀스에서 훈련되었습니다.

load JapaneseVowelsNet

신경망 아키텍처를 표시합니다.

analyzeNetwork(net);

테스트 데이터 불러오기

Japanese Vowels 테스트 데이터를 불러옵니다. XTest는 12개 차원으로 된 서로 다른 길이의 시퀀스 370개를 포함하는 셀형 배열입니다. TTest는 9명의 화자에 대응하는 레이블 "1","2",..."9"로 구성된 categorical형 벡터입니다.

타임스탬프가 지정된 행과 X의 반복되는 복사본을 포함하는 timetable형 배열 simin을 생성합니다.

load JapaneseVowelsTestData
X = XTest{94};
numTimeSteps = size(X,2);
simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));

응답 변수 예측을 위한 Simulink 모델

응답 변수 예측을 위한 Simulink 모델은 점수를 예측하는 Stateful Predict 블록과 시간 스텝에 대한 입력 데이터 시퀀스를 불러오는 From Workspace 블록을 포함합니다.

시뮬레이션 중에 순환 신경망의 상태를 초기 상태로 재설정하려면 Stateful Predict 블록을 Resettable Subsystem 내부에 배치하고 Reset 제어 신호를 트리거로 사용하십시오.

open_system('StatefulPredictExample');

시뮬레이션을 위한 모델 구성하기

Stateful Predict 블록의 모델 구성 파라미터를 설정합니다.

set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample', 'SimulationMode', 'Normal');

시뮬레이션 실행하기

JapaneseVowelsNet 신경망에 대한 응답 변수를 계산하기 위해 시뮬레이션을 실행합니다. 예측 점수는 MATLAB® 작업 공간에 저장됩니다.

out = sim('StatefulPredictExample');

예측 점수를 플로팅합니다. 이 플롯은 시간 스텝 간에 예측 점수가 변화하는 것을 보여줍니다.

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

시간 스텝에 대한 예측 점수에서 올바른 클래스를 강조 표시합니다.

trueLabel = TTest(94);
lines(trueLabel).LineWidth = 3;

최종 시간 스텝 예측을 막대 차트로 표시합니다.

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")

참고 문헌

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

참고 항목

| | |

관련 항목