Main Content

딥러닝을 사용한 sequence-to-one 회귀

이 예제에서는 장단기 기억(LSTM) 신경망을 사용하여 파형의 주파수를 예측하는 방법을 보여줍니다.

LSTM 신경망과 시퀀스 및 목표값으로 구성된 훈련 세트를 사용하여 시퀀스의 숫자형 응답 변수를 예측할 수 있습니다. LSTM 신경망은 루프를 사용하여 시간 스텝을 순회하고 신경망 상태를 업데이트하여 입력 데이터를 처리하는 순환 신경망(RNN)입니다. 신경망 상태에는 이전 시간 스텝에서 기억한 정보가 포함됩니다. 시퀀스의 숫자형 응답 변수의 예는 다음과 같습니다.

  • 주파수, 최댓값, 평균 같은 시퀀스의 속성.

  • 시퀀스의 과거 또는 미래 시간 스텝 값.

이 예제는 파형 데이터 세트를 사용하여 sequence-to-one 회귀 LSTM 신경망을 훈련시킵니다. 이 데이터 세트는 3개 채널로 이루어진 다양한 길이로 생성된 1000개의 합성 파형 데이터를 포함합니다. 일반적인 방법을 사용하여 파형의 주파수를 결정하려면 fft 항목을 참조하십시오.

시퀀스 데이터 불러오기

WaveformData.mat에서 예제 데이터를 불러옵니다. 이 데이터는 시퀀스로 구성된 numObservations×1 셀형 배열이며, 여기서 numObservations는 시퀀스 개수입니다. 각 시퀀스는 numChannels×numTimeSteps 숫자형 배열이며, 여기서 numChannels는 시퀀스의 채널 개수이고 numTimeSteps는 시퀀스의 시간 스텝 개수입니다. 대응하는 목표값은 파형 주파수로 구성된 numObservations×numResponses 숫자형 배열에 있습니다. 여기서 numResponses는 목표값의 채널 개수입니다.

load WaveformData

관측값 개수를 확인합니다.

numObservations = numel(data)
numObservations = 1000

처음 몇 개 시퀀스의 크기와 이에 대응하는 주파수를 확인합니다.

data(1:4)
ans=4×1 cell array
    {3×103 double}
    {3×136 double}
    {3×140 double}
    {3×124 double}

freq(1:4,:)
ans = 4×1

    5.8922
    2.2557
    4.5250
    4.4418

시퀀스의 채널 개수를 확인합니다. 신경망 훈련의 경우, 각 시퀀스의 채널 개수가 같아야 합니다.

numChannels = size(data{1},1)
numChannels = 3

응답 변수의 개수(목표값의 채널 개수)를 확인합니다.

numResponses = size(freq,2)
numResponses = 1

처음 몇 개의 시퀀스를 플롯으로 시각화합니다.

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{i}', DisplayLabels="Channel " + (1:numChannels))

    xlabel("Time Step")
    title("Frequency: " + freq(i))
end

훈련을 위해 데이터 준비하기

검증과 테스트를 위해 데이터를 남겨 둡니다. 데이터의 80%가 포함된 훈련 세트, 데이터의 10%가 포함된 검증 세트, 데이터의 나머지 10%가 포함된 테스트 세트로 데이터를 분할합니다.

[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations, [0.8 0.1 0.1]);

XTrain = data(idxTrain);
XValidation = data(idxValidation);
XTest = data(idxTest);

TTrain = freq(idxTrain);
TValidation = freq(idxValidation);
TTest = freq(idxTest);

LSTM 신경망 아키텍처 정의하기

LSTM 회귀 신경망을 만듭니다.

  • 입력 데이터의 채널 개수와 일치하는 입력 크기를 가진 시퀀스 입력 계층을 사용합니다.

  • 더 적합한 피팅을 위해, 그리고 훈련의 발산을 방지하기 위해 시퀀스 입력 계층의 Normalization 옵션을 "zscore"로 설정합니다. 이는 시퀀스 데이터가 평균 0과 단위 분산을 갖도록 정규화합니다.

  • 100개의 은닉 유닛을 가진 LSTM 계층을 사용합니다. 은닉 유닛의 개수는 계층에서 학습한 정보의 양을 결정합니다. 값이 클수록 더 정확한 결과를 얻을 수 있지만 훈련 데이터에 과적합을 초래할 가능성이 더 커질 수 있습니다.

  • 각 시퀀스의 단일 시간 스텝을 출력하기 위해 LSTM 계층의 OutputMode 옵션을 "last"로 설정합니다.

  • 예측할 값의 개수를 지정하기 위해 예측 변수의 개수와 크기가 일치하는 완전 연결 계층을 포함시키고 그 뒤에 회귀 계층을 포함시킵니다.

numHiddenUnits = 100;

layers = [ ...
    sequenceInputLayer(numChannels, Normalization="zscore")
    lstmLayer(numHiddenUnits, OutputMode="last")
    fullyConnectedLayer(numResponses)
    regressionLayer]
layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input      Sequence input with 3 dimensions
     2   ''   LSTM                LSTM with 100 hidden units
     3   ''   Fully Connected     1 fully connected layer
     4   ''   Regression Output   mean-squared-error

훈련 옵션 지정하기

훈련 옵션을 지정합니다.

  • Adam 최적화 함수를 사용하여 훈련시킵니다.

  • 훈련을 Epoch 250회 수행합니다. 크기가 큰 데이터 세트의 경우에는 양호한 피팅을 위해 이렇게 많은 Epoch 횟수만큼 훈련시키지 않아도 될 수 있습니다.

  • 검증에 사용되는 시퀀스와 응답 변수를 지정합니다.

  • 최적, 즉 검증 손실이 가장 적은 신경망을 출력합니다.

  • 학습률을 0.005로 설정합니다.

  • 각 미니 배치의 시퀀스가 가장 짧은 시퀀스와 길이가 같아지도록 자릅니다. 시퀀스를 자르면 데이터를 버리게 되더라도 채우기를 추가하지 않습니다. 시퀀스의 모든 시간 스텝에 중요한 정보가 포함될 가능성이 있는 시퀀스의 경우 자르기를 사용하면 신경망이 적합한 피팅을 달성하지 못할 수 있습니다.

  • 플롯에 훈련 진행 상황을 표시합니다.

  • 상세 출력값을 비활성화합니다.

options = trainingOptions("adam", ...
    MaxEpochs=250, ...
    ValidationData={XValidation TValidation}, ...
    OutputNetwork="best-validation-loss", ...
    InitialLearnRate=0.005, ...
    SequenceLength="shortest", ...
    Plots="training-progress", ...
    Verbose= false);

LSTM 신경망 훈련시키기

trainNetwork 함수를 사용하여 지정된 훈련 옵션으로 LSTM 신경망을 훈련시킵니다.

net = trainNetwork(XTrain, TTrain, layers, options);

LSTM 신경망 테스트하기

테스트 데이터를 사용하여 예측을 수행합니다.

YTest = predict(net,XTest, SequenceLength="shortest");

처음 몇 개의 예측값을 플롯으로 시각화합니다.

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(XTest{i}',DisplayLabels="Channel " + (1:numChannels))

    xlabel("Time Step")
    title("Predicted Frequency: " + string(YTest(i)))
end

평균 제곱 오차를 히스토그램으로 시각화합니다.

figure
histogram(mean((TTest - YTest).^2,2))
xlabel("Error")
ylabel("Frequency")

전체 평균 제곱 오차의 제곱근을 계산합니다.

rmse = sqrt(mean((YTest-TTest).^2))
rmse = single
    0.6865

실제 주파수에 대해 예측된 주파수를 플로팅합니다.

figure
scatter(YTest,TTest, "b+");
xlabel("Predicted Frequency")
ylabel("Actual Frequency")
hold on

m = min(freq);
M=max(freq);
xlim([m M])
ylim([m M])
plot([m M], [m M], "r--")

참고 항목

| | | |

관련 항목