Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

resetState

순환 신경망의 상태 재설정

설명

예제

updatedNet = resetState(recNet)은 순환 신경망(예: LSTM 신경망)의 상태를 초기 상태로 재설정합니다.

예제

모두 축소

시퀀스 예측마다 신경망 상태를 재설정합니다.

사전 훈련된 장단기 기억(LSTM) 신경망 JapaneseVowelsNet을 불러옵니다. 이것은 [1]과 [2]에서 설명한 Japanese Vowels 데이터 세트에서 훈련된 신경망입니다. 이 신경망은 미니 배치 크기 27을 가지며 시퀀스 길이를 기준으로 정렬된 시퀀스에서 훈련되었습니다.

load JapaneseVowelsNet

신경망 아키텍처를 표시합니다.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

테스트 데이터를 불러옵니다.

[XTest,YTest] = japaneseVowelsTestData;

시퀀스를 분류하고 신경망 상태를 업데이트합니다. 이 예제의 결과를 재현할 수 있도록 rng'shuffle'로 설정합니다.

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

업데이트된 신경망을 사용하여 다른 시퀀스를 분류합니다.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

최종 예측과 참 레이블을 비교합니다.

trueLabel = YTest(1)
trueLabel = categorical
     1 

신경망의 업데이트된 상태가 분류에 부정적인 영향을 주었을 수 있습니다. 신경망 상태를 재설정하고 시퀀스에 대해 예측을 다시 수행합니다.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

입력 인수

모두 축소

훈련된 순환 신경망으로, SeriesNetwork 또는 DAGNetwork 객체로 지정됩니다. 사전 훈련된 신경망을 가져오거나 trainNetwork 함수를 사용하여 자신만의 고유한 신경망을 훈련시켜 훈련된 신경망을 얻을 수 있습니다.

recNet은 순환 신경망입니다. 이 인수는 적어도 하나의 순환 계층을 가져야 합니다(예: LSTM 신경망). 입력 신경망이 순환 신경망이 아니면 함수는 영향을 미치지 않으며 입력 신경망을 반환합니다.

출력 인수

모두 축소

업데이트된 신경망. updatedNet은 입력 신경망과 동일한 유형의 신경망입니다.

입력 신경망이 순환 신경망이 아니면 함수는 영향을 미치지 않으며 입력 신경망을 반환합니다.

참고 문헌

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

확장 기능

R2017b에 개발됨