이 페이지의 내용은 이전 릴리스에 관한 것입니다. 해당 영문 페이지는 최신 릴리스에서 제거되었습니다.

딥러닝을 사용하여 텍스트 데이터 분류하기

이 예제에서는 딥러닝 장단기 기억(LSTM) 신경망을 사용하여 일기 예보의 텍스트 설명을 분류하는 방법을 보여줍니다.

텍스트 데이터는 본질적으로 순차적입니다. 텍스트 조각은 단어로 이루어진 시퀀스로서, 각 단어 사이에는 종속성이 있을 수 있습니다. 장기적인 종속성을 학습하여 이를 시퀀스 데이터를 분류하는 데 사용하려면 LSTM 신경망을 사용하십시오. LSTM 신경망은 시퀀스 데이터의 시간 스텝 간의 장기적인 종속성을 학습할 수 있는 순환 신경망(RNN)의 일종입니다.

LSTM 네트워크에 텍스트를 입력하려면 먼저 텍스트 데이터를 숫자형 시퀀스로 변환하십시오. 이렇게 하려면 문서를 숫자형 인덱스 시퀀스로 매핑하는 단어 인코딩을 사용하면 됩니다. 더 나은 결과를 위해 네트워크에 단어 임베딩 계층을 포함시킵니다. 단어 임베딩은 어휘에 있는 단어를 스칼라형 인덱스가 아닌 숫자형 벡터로 매핑합니다. 이러한 임베딩은 비슷한 의미를 갖는 단어들이 비슷한 벡터를 갖도록 단어의 의미 체계 정보를 캡처합니다. 벡터 연산을 통해 단어 사이의 관계도 모델링합니다. 예를 들어, "왕 대 여왕은 남자 대 여자와 같다"라는 관계는 남자 + 여자 = 여왕이라는 식으로 설명됩니다.

이 예제에서는 다음과 같은 네 단계를 사용하여 LSTM 네트워크를 훈련시키고 사용합니다.

  • 데이터를 가져오고 전처리합니다.

  • 단어 인코딩을 사용하여 단어를 숫자 시퀀스로 변환합니다.

  • 단어 임베딩 계층을 사용하여 LSTM 네트워크를 만들고 훈련시킵니다.

  • 훈련된 LSTM 네트워크를 사용하여 새로운 텍스트 데이터를 분류합니다.

데이터 가져오기

일기 예보 데이터를 가져옵니다. 이 데이터는 날씨 이벤트에 대한 텍스트로 된 설명을 포함합니다. 텍스트 데이터를 문자열로 가져오도록 텍스트 유형을 'string'으로 지정하십시오.

filename = "weatherReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×16 table
            Time             event_id          state              event_type         damage_property    damage_crops    begin_lat    begin_lon    end_lat    end_lon                                                                                             event_narrative                                                                                             storm_duration    begin_day    end_day    year       end_timestamp    
    ____________________    __________    ________________    ___________________    _______________    ____________    _________    _________    _______    _______    _________________________________________________________________________________________________________________________________________________________________________________________________    ______________    _________    _______    ____    ____________________

    22-Jul-2016 16:10:00    6.4433e+05    "MISSISSIPPI"       "Thunderstorm Wind"       ""                "0.00K"         34.14        -88.63     34.122     -88.626    "Large tree down between Plantersville and Nettleton."                                                                                                                                                  00:05:00          22          22       2016    22-Jul-0016 16:15:00
    15-Jul-2016 17:15:00    6.5182e+05    "SOUTH CAROLINA"    "Heavy Rain"              "2.00K"           "0.00K"         34.94        -81.03      34.94      -81.03    "One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water."       00:00:00          15          15       2016    15-Jul-0016 17:15:00
    15-Jul-2016 17:25:00    6.5183e+05    "SOUTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.01        -80.93      35.01      -80.93    "NWS Columbia relayed a report of trees blown down along Tom Hall St."                                                                                                                                  00:00:00          15          15       2016    15-Jul-0016 17:25:00
    16-Jul-2016 12:46:00    6.5183e+05    "NORTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.64        -82.14      35.64      -82.14    "Media reported two trees blown down along I-40 in the Old Fort area."                                                                                                                                  00:00:00          16          16       2016    16-Jul-0016 12:46:00
    15-Jul-2016 14:28:00    6.4332e+05    "MISSOURI"          "Hail"                    ""                ""              36.45        -89.97      36.45      -89.97    ""                                                                                                                                                                                                      00:07:00          15          15       2016    15-Jul-0016 14:35:00
    15-Jul-2016 16:31:00    6.4332e+05    "ARKANSAS"          "Thunderstorm Wind"       ""                "0.00K"         35.85         -90.1     35.838     -90.087    "A few tree limbs greater than 6 inches down on HWY 18 in Roseland."                                                                                                                                    00:09:00          15          15       2016    15-Jul-0016 16:40:00
    15-Jul-2016 16:03:00    6.4343e+05    "TENNESSEE"         "Thunderstorm Wind"       "20.00K"          "0.00K"        35.056       -89.937      35.05     -89.904    "Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins."                                                                                     00:07:00          15          15       2016    15-Jul-0016 16:10:00
    15-Jul-2016 17:27:00    6.4344e+05    "TENNESSEE"         "Hail"                    ""                ""             35.385        -89.78     35.385      -89.78    "Quarter size hail near Rosemark."                                                                                                                                                                      00:05:00          15          15       2016    15-Jul-0016 17:32:00

일기 예보가 비어 있는 행은 테이블에서 제거합니다.

idxEmpty = strlength(data.event_narrative) == 0;
data(idxEmpty,:) = [];

이 예제의 목표는 event_type 열의 레이블을 기준으로 이벤트를 분류하는 것입니다. 데이터를 클래스별로 나누기 위해 레이블을 categorical형으로 변환합니다.

data.event_type = categorical(data.event_type);

히스토그램을 사용하여 데이터의 클래스 분포를 표시합니다. 레이블을 읽기 쉽도록 하려면 Figure의 너비를 늘리십시오.

f = figure;
f.Position(3) = 1.5*f.Position(3);

h = histogram(data.event_type);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

많은 클래스가 적은 수의 관측값을 포함하고 있어서 데이터의 클래스 간에 균형이 맞지 않습니다. 이런 식으로 클래스 간에 균형이 맞지 않으면 네트워크가 덜 정확한 모델로 수렴할 수 있습니다. 이 문제를 방지하려면 10회보다 적게 나타나는 클래스를 모두 제거하십시오.

히스토그램에서 클래스의 빈도 수와 클래스 이름을 가져옵니다.

classCounts = h.BinCounts;
classNames = h.Categories;

관측값이 10개보다 작은 클래스를 찾습니다.

idxLowCounts = classCounts < 10;
infrequentClasses = classNames(idxLowCounts)
infrequentClasses = 1×8 cell array
    {'Freezing Fog'}    {'Hurricane'}    {'Lakeshore Flood'}    {'Marine Dense Fog'}    {'Marine Strong Wind'}    {'Marine Tropical Depression'}    {'Seiche'}    {'Sneakerwave'}

빈도가 적은 이러한 클래스를 데이터에서 제거합니다. removecats를 사용하여 categorical형 데이터에서 사용되지 않는 범주를 제거합니다.

idxInfrequent = ismember(data.event_type,infrequentClasses);
data(idxInfrequent,:) = [];
data.event_type = removecats(data.event_type);

데이터는 이제 적당한 크기의 클래스로 정렬되어 있습니다. 다음 단계는 데이터를 훈련 세트, 검증 세트, 테스트 세트로 분할하는 것입니다. 데이터를 훈련 파티션, 그리고 검증과 테스트를 위한 홀드아웃 파티션으로 분할합니다. 홀드아웃 백분율을 30%로 지정합니다.

cvp = cvpartition(data.event_type,'Holdout',0.3);
dataTrain = data(training(cvp),:);
dataHeldOut = data(test(cvp),:);

홀드아웃 세트를 다시 분할하여 검증 세트를 얻습니다. 홀드아웃 백분율을 50%로 지정합니다. 이렇게 하면 훈련 관측값 70%, 검증 관측값 15%, 테스트 관측값 15%로 데이터가 분할됩니다.

cvp = cvpartition(dataHeldOut.event_type,'HoldOut',0.5);
dataValidation = dataHeldOut(training(cvp),:);
dataTest = dataHeldOut(test(cvp),:);

분할된 테이블에서 텍스트 데이터와 레이블을 추출합니다.

textDataTrain = dataTrain.event_narrative;
textDataValidation = dataValidation.event_narrative;
textDataTest = dataTest.event_narrative;
YTrain = dataTrain.event_type;
YValidation = dataValidation.event_type;
YTest = dataTest.event_type;

데이터를 올바르게 가져왔는지 확인하려면 단어 구름을 사용하여 훈련 텍스트 데이터를 시각화하십시오.

figure
wordcloud(textDataTrain);
title("Training Data")

텍스트 데이터 전처리하기

텍스트 데이터를 토큰화하고 전처리하는 함수를 만듭니다. 이 예제의 마지막에 나오는 함수 preprocessText는 다음 단계를 수행합니다.

  1. tokenizedDocument를 사용하여 텍스트를 토큰화합니다.

  2. lower를 사용하여 텍스트를 소문자로 변환합니다.

  3. erasePunctuation을 사용하여 문장 부호를 지웁니다.

preprocessText 함수를 사용하여 훈련 데이터와 검증 데이터를 전처리합니다.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

전처리된 처음 몇 개의 훈련 문서를 표시합니다.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     7 tokens: large tree down between plantersville and nettleton
    37 tokens: one to two feet of deep standing water developed on a street on the winthrop university campus after more than an inch of rain fell in less than an hour one vehicle was stalled in the water
    13 tokens: nws columbia relayed a report of trees blown down along tom hall st
    13 tokens: media reported two trees blown down along i40 in the old fort area
    14 tokens: a few tree limbs greater than 6 inches down on hwy 18 in roseland

문서를 시퀀스로 변환하기

문서를 LSTM 네트워크에 입력하려면 단어 인코딩을 사용하여 문서를 숫자형 인덱스로 구성된 시퀀스로 변환하십시오.

단어 인코딩을 만들려면 wordEncoding 함수를 사용하십시오.

enc = wordEncoding(documentsTrain);

다음 변환 단계는 문서가 모두 같은 길이가 되도록 채우고 자르는 것입니다. trainingOptions 함수는 입력 시퀀스를 자동으로 채우고 자르는 옵션을 제공합니다. 그러나 이러한 옵션은 단어 벡터로 구성된 시퀀스에 적합하지 않습니다. 이러한 옵션을 사용하는 대신 시퀀스를 수동으로 채우고 자릅니다. 단어 벡터로 구성된 시퀀스를 왼쪽을 채우고 자르면 훈련이 향상될 수 있습니다.

문서를 채우고 자르려면 먼저 목표 길이를 선택하고, 목표 길이보다 긴 문서는 자르고 목표 길이보다 짧은 문서는 왼쪽을 채우십시오. 최상의 결과를 위해 목표 길이는 다량의 데이터가 버려지지 않을 만큼 짧아야 합니다. 적당한 목표 길이를 찾으려면 훈련 문서의 길이를 히스토그램으로 표시해 보십시오.

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

대부분의 훈련 문서가 75개 미만의 토큰을 갖습니다. 이 값을 자르기와 채우기의 목표 길이로 사용합니다.

doc2sequence를 사용하여 문서를 숫자형 인덱스로 구성된 시퀀스로 변환합니다. 시퀀스의 길이가 75가 되도록 자르거나 왼쪽을 채우려면 'Length' 옵션을 75로 설정하십시오.

XTrain = doc2sequence(enc,documentsTrain,'Length',75);
XTrain(1:5)
ans=5×1 cell
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

동일한 옵션을 사용하여 검증 문서를 시퀀스로 변환합니다.

XValidation = doc2sequence(enc,documentsValidation,'Length',75);

LSTM 네트워크 만들고 훈련시키기

LSTM 네트워크 아키텍처를 정의합니다. 네트워크에 시퀀스 데이터를 입력하려면 시퀀스 입력 계층을 포함시키고 입력 크기를 1로 설정하십시오. 다음으로, 차원이 100이고 단어 인코딩과 동일한 단어 개수를 갖는 단어 임베딩 계층을 포함시킵니다. 다음으로, LSTM 계층을 포함시키고 은닉 유닛의 개수를 180으로 설정합니다. sequence-to-label 분류 문제에서 LSTM 계층을 사용하려면 출력 모드를 'last'로 설정하십시오. 마지막으로, 클래스 개수와 동일한 크기를 갖는 완전 연결 계층, 소프트맥스 계층, 분류 계층을 추가합니다.

inputSize = 1;
embeddingDimension = 100;
numWords = enc.NumWords;
numHiddenUnits = 180;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 100 dimensions and 16954 unique words
     3   ''   LSTM                    LSTM with 180 hidden units
     4   ''   Fully Connected         39 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

훈련 옵션을 지정합니다. 솔버를 'adam'으로 설정하고 훈련은 Epoch 10회, 기울기 임계값은 1로 설정합니다. 초기 학습률을 0.01로 설정합니다. 훈련 진행 상황을 모니터링하려면 'Plots' 옵션을 'training-progress'로 설정하십시오. 'ValidationData' 옵션을 사용하여 검증 데이터를 지정합니다. 세부 정보가 출력되지 않도록 하려면 'Verbose'false로 설정하십시오.

기본적으로 trainNetwork는 GPU를 사용할 수 있으면 GPU를 사용합니다(Parallel Computing Toolbox™와 Compute Capability 3.0 이상의 CUDA® 지원 GPU 필요). GPU가 없으면 CPU를 사용합니다. 실행 환경을 수동으로 지정하려면 trainingOptions'ExecutionEnvironment' 이름-값 쌍 인수를 사용하십시오. CPU에서 훈련시키면 GPU에서 훈련시키는 것보다 시간이 상당히 오래 걸릴 수 있습니다.

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...    
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

trainNetwork 함수를 사용하여 LSTM 네트워크를 훈련시킵니다.

net = trainNetwork(XTrain,YTrain,layers,options);

LSTM 네트워크 테스트하기

LSTM 네트워크를 테스트하려면 먼저 훈련 데이터와 같은 방식으로 테스트 데이터를 준비하십시오. 그런 다음 전처리된 테스트 데이터에 대해 훈련된 LSTM 네트워크 net을 사용하여 예측을 수행합니다.

훈련 문서와 같은 단계를 사용하여 테스트 데이터를 전처리합니다.

textDataTest = lower(textDataTest);
documentsTest = tokenizedDocument(textDataTest);
documentsTest = erasePunctuation(documentsTest);

훈련 시퀀스를 만들 때와 같은 옵션으로 doc2sequence를 사용하여 테스트 문서를 시퀀스로 변환합니다.

XTest = doc2sequence(enc,documentsTest,'Length',75);
XTest(1:5)
ans=5×1 cell
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

훈련된 LSTM 네트워크를 사용하여 테스트 문서를 분류합니다.

YPred = classify(net,XTest);

분류 정확도를 계산합니다. 정확도는 네트워크가 올바르게 예측한 레이블의 비율입니다.

accuracy = sum(YPred == YTest)/numel(YPred)
accuracy = 0.8684

새 데이터를 사용하여 예측하기

새 일기 예보 3개의 이벤트 유형을 분류합니다. 새 일기 예보를 포함하는 string형 배열을 만듭니다.

reportsNew = [ ...
    "Lots of water damage to computer equipment inside the office."
    "A large tree is downed and blocking traffic outside Apple Hill."
    "Damage to many car windshields in parking lot."];

전처리 단계를 사용하여 텍스트 데이터를 훈련 문서로 전처리합니다.

documentsNew = preprocessText(reportsNew);

훈련 시퀀스를 만들 때와 같은 옵션으로 doc2sequence를 사용하여 텍스트 데이터를 시퀀스로 변환합니다.

XNew = doc2sequence(enc,documentsNew,'Length',75);

훈련된 LSTM 네트워크를 사용하여 새 시퀀스를 분류합니다.

[labelsNew,score] = classify(net,XNew);

일기 예보를 예측된 레이블과 함께 표시합니다.

[reportsNew string(labelsNew)]
ans = 3×2 string array
    "Lots of water damage to computer equipment inside the office."      "Flash Flood"      
    "A large tree is downed and blocking traffic outside Apple Hill."    "Thunderstorm Wind"
    "Damage to many car windshields in parking lot."                     "Hail"             

전처리 함수

함수 preprocessText는 다음 단계를 수행합니다.

  1. tokenizedDocument를 사용하여 텍스트를 토큰화합니다.

  2. lower를 사용하여 텍스트를 소문자로 변환합니다.

  3. erasePunctuation을 사용하여 문장 부호를 지웁니다.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

참고 항목

| | | | | | | |

관련 항목