Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

딥러닝을 사용하여 텍스트 데이터 분류하기

이 예제에서는 딥러닝 장단기 기억(LSTM) 신경망을 사용하여 텍스트 데이터를 분류하는 방법을 보여줍니다.

텍스트 데이터는 본질적으로 순차적입니다. 텍스트 조각은 단어로 이루어진 시퀀스로서, 각 단어 사이에는 종속성이 있을 수 있습니다. 장기적인 종속성을 학습하여 이를 시퀀스 데이터를 분류하는 데 사용하려면 LSTM 신경망을 사용하십시오. LSTM 신경망은 시퀀스 데이터의 시간 스텝 간의 장기적인 종속성을 학습할 수 있는 순환 신경망(RNN)의 일종입니다.

LSTM 신경망에 텍스트를 입력하려면 먼저 텍스트 데이터를 숫자형 시퀀스로 변환하십시오. 이렇게 하려면 문서를 숫자형 인덱스 시퀀스로 매핑하는 단어 인코딩을 사용하면 됩니다. 더 나은 결과를 위해 신경망에 단어 임베딩 계층을 포함시킵니다. 단어 임베딩은 단어집에 있는 단어를 스칼라형 인덱스가 아닌 숫자형 벡터로 매핑합니다. 이러한 임베딩은 비슷한 의미를 갖는 단어들이 비슷한 벡터를 갖도록 단어의 의미 체계 정보를 캡처합니다. 벡터 연산을 통해 단어 사이의 관계도 모델링합니다. 예를 들어, "로마와 이탈리아의 관계는 파리와 프랑스의 관계와 같다"는 이탈리아 로마 + 파리 = 프랑스라는 식으로 설명됩니다.

이 예제에서는 다음과 같은 네 단계를 사용하여 LSTM 신경망을 훈련시키고 사용합니다.

  • 데이터를 가져오고 전처리합니다.

  • 단어 인코딩을 사용하여 단어를 숫자 시퀀스로 변환합니다.

  • 단어 임베딩 계층을 사용하여 LSTM 신경망을 만들고 훈련시킵니다.

  • 훈련된 LSTM 신경망을 사용하여 새로운 텍스트 데이터를 분류합니다.

데이터 가져오기

공장 보고서 데이터를 가져옵니다. 이 데이터는 공장 이벤트에 대한 텍스트로 된 설명을 포함합니다. 텍스트 데이터를 문자열로 가져오도록 텍스트 유형을 'string'으로 지정하십시오.

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

이 예제의 목표는 Category 열의 레이블을 기준으로 이벤트를 분류하는 것입니다. 데이터를 클래스별로 나누기 위해 레이블을 categorical형으로 변환합니다.

data.Category = categorical(data.Category);

히스토그램을 사용하여 데이터의 클래스 분포를 표시합니다.

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

다음 단계는 데이터를 훈련 세트 및 검증 세트로 분할하는 것입니다. 데이터를 훈련 파티션, 그리고 검증과 테스트를 위한 홀드아웃 파티션으로 분할합니다. 홀드아웃 백분율을 20%로 지정합니다.

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

분할된 테이블에서 텍스트 데이터와 레이블을 추출합니다.

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;

데이터를 올바르게 가져왔는지 확인하려면 워드 클라우드를 사용하여 훈련 텍스트 데이터를 시각화하십시오.

figure
wordcloud(textDataTrain);
title("Training Data")

텍스트 데이터 전처리하기

텍스트 데이터를 토큰화하고 전처리하는 함수를 만듭니다. 이 예제의 마지막에 나오는 함수 preprocessText는 다음 단계를 수행합니다.

  1. tokenizedDocument를 사용하여 텍스트를 토큰화합니다.

  2. lower를 사용하여 텍스트를 소문자로 변환합니다.

  3. erasePunctuation을 사용하여 문장 부호를 지웁니다.

preprocessText 함수를 사용하여 훈련 데이터와 검증 데이터를 전처리합니다.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

전처리된 처음 몇 개의 훈련 문서를 표시합니다.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses

문서를 시퀀스로 변환하기

문서를 LSTM 신경망에 입력하려면 단어 인코딩을 사용하여 문서를 숫자형 인덱스로 구성된 시퀀스로 변환하십시오.

단어 인코딩을 만들려면 wordEncoding 함수를 사용하십시오.

enc = wordEncoding(documentsTrain);

다음 변환 단계는 문서가 모두 같은 길이가 되도록 채우고 자르는 것입니다. trainingOptions 함수는 입력 시퀀스를 자동으로 채우고 자르는 옵션을 제공합니다. 그러나 이러한 옵션은 단어 벡터로 구성된 시퀀스에 적합하지 않습니다. 이러한 옵션을 사용하는 대신 시퀀스를 수동으로 채우고 자릅니다. 단어 벡터로 구성된 시퀀스를 왼쪽을 채우고 자르면 훈련이 향상될 수 있습니다.

문서를 채우고 자르려면 먼저 목표 길이를 선택하고, 목표 길이보다 긴 문서는 자르고 목표 길이보다 짧은 문서는 왼쪽을 채우십시오. 최상의 결과를 위해 목표 길이는 다량의 데이터가 버려지지 않을 만큼 짧아야 합니다. 적당한 목표 길이를 찾으려면 훈련 문서의 길이를 히스토그램으로 표시해 보십시오.

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

대부분의 훈련 문서가 10개 미만의 토큰을 갖습니다. 이 값을 자르기와 채우기의 목표 길이로 사용합니다.

doc2sequence를 사용하여 문서를 숫자형 인덱스로 구성된 시퀀스로 변환합니다. 시퀀스의 길이가 10이 되도록 자르거나 왼쪽을 채우려면 'Length' 옵션을 10으로 설정하십시오.

sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}

동일한 옵션을 사용하여 검증 문서를 시퀀스로 변환합니다.

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);

LSTM 신경망 만들고 훈련시키기

LSTM 신경망 아키텍처를 정의합니다. 신경망에 시퀀스 데이터를 입력하려면 시퀀스 입력 계층을 포함시키고 입력 크기를 1로 설정하십시오. 다음으로, 차원이 50이고 단어 인코딩과 동일한 단어 개수를 갖는 단어 임베딩 계층을 포함시킵니다. 다음으로, LSTM 계층을 포함시키고 은닉 유닛의 개수를 80으로 설정합니다. sequence-to-label 분류 문제에서 LSTM 계층을 사용하려면 출력 모드를 'last'로 설정하십시오. 마지막으로, 클래스 개수와 동일한 크기를 갖는 완전 연결 계층, 소프트맥스 계층, 분류 계층을 추가합니다.

inputSize = 1;
embeddingDimension = 50;
numHiddenUnits = 80;

numWords = enc.NumWords;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 50 dimensions and 423 unique words
     3   ''   LSTM                    LSTM with 80 hidden units
     4   ''   Fully Connected         4 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

훈련 옵션 지정하기

다음과 같이 훈련 옵션을 지정합니다.

  • Adam 솔버를 사용하여 훈련시킵니다.

  • 미니 배치 크기를 16으로 지정합니다.

  • 매 Epoch마다 데이터를 섞습니다.

  • 'Plots' 옵션을 'training-progress'로 설정하여 훈련 진행 상황을 모니터링합니다.

  • 'ValidationData' 옵션을 사용하여 검증 데이터를 지정합니다.

  • 'Verbose' 옵션을 false로 설정하여 세부 정보가 출력되지 않도록 합니다.

기본적으로 trainNetwork는 GPU를 사용할 수 있으면 GPU를 사용합니다(Parallel Computing Toolbox™와 Compute Capability 3.0 이상의 CUDA® 지원 GPU 필요). GPU가 없으면 CPU를 사용합니다. 실행 환경을 수동으로 지정하려면 trainingOptions'ExecutionEnvironment' 이름-값 쌍 인수를 사용하십시오. CPU에서 훈련시키면 GPU에서 훈련시키는 것보다 시간이 상당히 오래 걸릴 수 있습니다.

options = trainingOptions('adam', ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

trainNetwork 함수를 사용하여 LSTM 신경망을 훈련시킵니다.

net = trainNetwork(XTrain,YTrain,layers,options);

새 데이터를 사용하여 예측하기

새 보고서 3개의 이벤트 유형을 분류합니다. 새 보고서를 포함하는 string형 배열을 만듭니다.

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

전처리 단계를 사용하여 텍스트 데이터를 훈련 문서로 전처리합니다.

documentsNew = preprocessText(reportsNew);

훈련 시퀀스를 만들 때와 같은 옵션으로 doc2sequence를 사용하여 텍스트 데이터를 시퀀스로 변환합니다.

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

훈련된 LSTM 신경망을 사용하여 새 시퀀스를 분류합니다.

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

전처리 함수

함수 preprocessText는 다음 단계를 수행합니다.

  1. tokenizedDocument를 사용하여 텍스트를 토큰화합니다.

  2. lower를 사용하여 텍스트를 소문자로 변환합니다.

  3. erasePunctuation을 사용하여 문장 부호를 지웁니다.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

참고 항목

| | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

관련 항목