Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

문서 임베딩을 사용하여 문서 분류하기

이 예제에서는 단어 임베딩을 사용해 문서를 특징 벡터로 변환하여 문서 분류기를 훈련시키는 방법을 보여줍니다.

대부분의 머신러닝 기술에서는 분류기를 훈련시키기 위한 입력값으로 특징 벡터가 필요합니다.

단어 임베딩은 개별 단어를 벡터에 매핑합니다. 단어 임베딩을 사용하면 단어 벡터를 결합하여 문서를 단일 벡터에 매핑할 수 있습니다(예: 평균 벡터를 계산하여 문서 벡터 생성).

레이블이 지정된 문서 벡터의 데이터 세트가 주어지면, 이러한 문서를 분류하도록 머신러닝 모델을 훈련시킬 수 있습니다.

사전 훈련된 단어 임베딩 불러오기

fastTextWordEmbedding 함수를 사용하여 사전 훈련된 단어 임베딩을 불러옵니다. 이 함수를 사용하려면 Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding 지원 패키지가 필요합니다. 이 지원 패키지가 설치되어 있지 않으면 함수에서 다운로드 링크를 제공합니다.

emb = fastTextWordEmbedding
emb = 
  wordEmbedding with properties:

     Dimension: 300
    Vocabulary: [","    "the"    "."    "and"    "of"    "to"    "in"    "a"    """    ":"    ")"    "that"    "("    "is"    "for"    "on"    "*"    "with"    "as"    "it"    "The"    "or"    "was"    "'"    "'s"    "by"    "from"    "at"    …    ]

재현성을 위해 rng 함수를 "default" 옵션과 함께 사용합니다.

rng("default");

훈련 데이터 불러오기

다음 단계는 예제 데이터를 불러오는 것입니다. factoryReports.csv 파일에는 각 이벤트에 대한 텍스트 설명과 범주 레이블이 포함된 공장 보고서가 들어 있습니다.

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

이 예제의 목표는 Category 열의 레이블을 기준으로 이벤트를 분류하는 것입니다. 데이터를 클래스별로 나누기 위해 레이블을 categorical형으로 변환합니다.

data.Category = categorical(data.Category);

다음 단계는 데이터를 훈련 세트와 테스트 세트로 분할하기 위한 파티션을 만드는 것입니다. 데이터를 훈련 파티션, 그리고 검증과 테스트를 위한 홀드아웃 파티션으로 분할합니다. 홀드아웃 백분율을 30%로 지정합니다.

cvp = cvpartition(data.Category,Holdout=0.3);

파티션을 사용하여 훈련 및 테스트를 위한 대상 레이블을 가져옵니다. 예제의 뒷부분에서, 문서의 벡터를 생성한 후에 입력 데이터를 훈련 데이터와 테스트 데이터로 분할하는 데에도 파티션이 사용됩니다.

TTrain = data.Category(training(cvp),:);
TTest = data.Category(test(cvp),:);

서로 다른 텍스트 데이터 모음을 동일한 방식으로 준비할 수 있기 때문에, 전처리를 수행하는 함수를 만드는 것이 유용할 수 있습니다.

분석에 사용할 수 있도록 텍스트 데이터를 토큰화하고 전처리하는 함수를 만듭니다. 이 예제의 예제 전처리 함수 섹션에 나오는 preprocessText 함수는 다음 단계를 수행합니다.

  1. tokenizedDocument를 사용하여 텍스트를 토큰화합니다.

  2. erasePunctuation을 사용하여 문장 부호를 지웁니다.

  3. removeStopWords를 사용하여 불용어 목록(예: "and", "of", "the")을 제거합니다.

  4. 모든 단어를 소문자로 변환합니다.

documents = preprocessText(data.Description);

문서를 특징 벡터로 변환하기

토큰에서 벡터로 변환하는 데는 함수 word2vec가 사용됩니다. 더 큰 문서의 경우, 여러 단어 벡터에서 단어에 대한 평균을 계산하여 단일 단어 벡터로 결합할 수 있습니다. 이 예제에서는 문서 벡터를 구하는 방법을 보여줍니다.

각 문서에 있는 모든 단어에 대한 평균을 계산하여 각 문서의 벡터를 구합니다. word2vec 함수를 사용하여 단어 벡터를 계산하고 rmissing 함수를 사용하여 임베딩 단어집에서 누락된 단어를 제거합니다. 문서의 단어(데이터의 첫 번째 차원)에 대한 평균을 계산합니다.

meanEmbedding = zeros(numel(documents),emb.Dimension);
for k=1:numel(documents)
    words = string(documents(k));
    wordVectors = word2vec(emb,words);
    wordVectors = rmmissing(wordVectors);
    meanEmbedding(k,:) = mean(wordVectors,1);
end
meanEmbeddingTrain = meanEmbedding(training(cvp),:);
meanEmbeddingTest = meanEmbedding(test(cvp),:);

임베디드 테스트 데이터의 크기를 표시합니다. 배열은 numObservations×embeddingDimension 배열입니다. 여기서 numObservations는 테스트 문서의 개수이고 embeddingDimension은 임베딩 차원입니다.

size(meanEmbeddingTest)
ans = 1×2

   144   300

각 문서에 대한 출력값은 문서에 포함된 단어 벡터의 모든 특징을 요약하는 하나의 300차원 배열입니다. 테스트 세트에서 첫 번째 문서의 벡터는 다음과 같이 구합니다.

meanEmbeddingTest(1,:)
ans = 1×300

   -0.1367   -0.0284   -0.1061   -0.0034    0.0577   -0.0662   -0.0845   -0.0606    0.0117   -0.0614    0.1074   -0.0814    0.0160   -0.0101   -0.0419   -0.0108   -0.0433   -0.0334   -0.0192   -0.0640   -0.1802   -0.0926    0.0291   -0.0787    0.1210   -0.0796    0.1160   -0.0278   -0.0243   -0.0577    0.0851    0.0354    0.0002    0.0060    0.0887    0.0491    0.0312   -0.0865   -0.0867    0.0378   -0.0794   -0.1174    0.0331    0.0432    0.0372   -0.0873   -0.0050   -0.0515    0.0382    0.0283

문서 벡터를 구한 후에는 차원 수를 2로 지정해 tsne를 사용하여 2차원 공간에 문서 벡터를 임베드할 수도 있습니다. t-SNE 플롯은 데이터의 군집을 표시하는 데 도움이 될 수 있으며, 이는 머신러닝 모델을 만들 수 있음을 의미합니다.

Y = tsne(meanEmbeddingTest);
gscatter(Y(:,1),Y(:,2), categorical(TTest))
title("Factory Report Embeddings")

문서 분류기 훈련시키기

문서 벡터와 각 군집을 시각화한 후에 fitcecoc를 사용하여 다중클래스 선형 분류 모델을 훈련시킬 수 있습니다.

mdl = fitcecoc(meanEmbeddingTrain,TTrain,Learners="linear")
mdl = 
  CompactClassificationECOC
      ResponseName: 'Y'
        ClassNames: [Electronic Failure    Leak    Mechanical Failure    Software Failure]
    ScoreTransform: 'none'
    BinaryLearners: {6×1 cell}
      CodingMatrix: [4×6 double]


  Properties, Methods

모델 테스트하기

평균 벡터의 점수를 계산하여 정확도 결과와 혼동행렬을 시각화합니다.

YTest = predict(mdl,meanEmbeddingTest);
acc = mean(YTest == TTest)
acc = 0.9444
confusionchart(YTest,TTest)

대각선상에 큰 숫자가 있으면 해당하는 클래스의 예측 정확도가 좋다는 것을 의미합니다. 비대각선상에 큰 숫자가 있으면 해당하는 클래스 간에 혼동이 심하다는 것을 의미합니다.

예제 전처리 함수

함수 preprocessText는 다음 단계를 순서대로 수행합니다.

  1. tokenizedDocument를 사용하여 텍스트를 토큰화합니다.

  2. erasePunctuation을 사용하여 문장 부호를 지웁니다.

  3. removeStopWords를 사용하여 불용어 목록(예: "and", "of", "the")을 제거합니다.

  4. 모든 단어를 소문자로 변환합니다.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Lowercase all words.
documents = lower(documents);

end

참고 항목

| | | | |

관련 항목