Main Content

문서 임베딩을 사용하여 문서 분류하기

이 예제에서는 문서 임베딩을 사용해 문서를 특징 벡터로 변환하여 문서 분류기를 훈련시키는 방법을 보여줍니다.

대부분의 머신러닝 기술에서는 분류기를 훈련시키기 위한 입력값으로 특징 벡터가 필요합니다.

문서 임베딩은 문서를 벡터에 매핑합니다. 레이블이 지정된 문서 벡터의 데이터 세트가 주어지면, 이러한 문서를 분류하도록 머신러닝 모델을 훈련시킬 수 있습니다.

사전 훈련된 문서 임베딩 불러오기

documentEmbedding 함수를 사용하여 사전 훈련된 문서 임베딩 "all-MiniLM-L6-v2"를 불러옵니다. 이 모델을 사용하려면 Text Analytics Toolbox™ Model for all-MiniLM-L6-v2 Network 지원 패키지가 필요합니다. 이 지원 패키지가 설치되어 있지 않으면 함수에서 다운로드 링크를 제공합니다.

emb = documentEmbedding(Model="all-MiniLM-L6-v2");

재현성을 위해 rng 함수를 "default" 옵션과 함께 사용합니다.

rng("default");

훈련 데이터 불러오기

다음으로 예제 데이터를 불러옵니다. factoryReports.csv 파일에는 각 이벤트에 대한 텍스트 설명과 범주 레이블이 포함된 공장 보고서가 들어 있습니다.

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

이 예제의 목표는 Category 열의 레이블을 기준으로 이벤트를 분류하는 것입니다. 데이터를 클래스별로 나누기 위해 레이블을 categorical형으로 변환합니다.

str = data.Description;
labels = categorical(data.Category);

다음으로 데이터를 훈련 파티션, 그리고 검증과 테스트를 위한 홀드아웃 파티션으로 분할합니다. 홀드아웃 백분율을 30%로 설정합니다.

cvp = cvpartition(labels,Holdout=0.3);
idxTrain = training(cvp);
idxTest = test(cvp);

훈련 파티션과 테스트 파티션을 위한 대상 레이블을 가져옵니다.

labelsTrain = labels(idxTrain,:);
labelsTest = labels(idxTest,:);

문서를 특징 벡터로 변환하기

공장 보고서를 벡터로 변환하려면 embed 함수를 사용합니다. 문서에서 텍스트 전처리를 수행할 필요가 없습니다.

embeddedDocumentsTrain = embed(emb,str(idxTrain,:));
embeddedDocumentsTest = embed(emb,str(idxTest,:));

임베디드 테스트 데이터의 크기를 표시합니다.

size(embeddedDocumentsTest)
ans = 1×2

   144   384

144개 문서 각각에 대한 출력값은 전체 문서의 의미론적 표현을 제공하는 384개 요소를 가진 단일 벡터입니다. 테스트 세트에서 첫 번째 문서의 임베딩 벡터를 표시합니다.

embeddedDocumentsTest(1,:)
ans = 1×384

   -0.0141   -0.0434    0.0271   -0.0302   -0.1098   -0.0431   -0.0311   -0.0633    0.0388   -0.0577    0.0328   -0.0112   -0.0293   -0.0755   -0.0539    0.0484    0.0798   -0.0112   -0.0152   -0.0711   -0.0854    0.0378    0.0026    0.0957    0.0080    0.0720    0.0196    0.0605    0.0109   -0.0186    0.0441   -0.0159   -0.0111   -0.0404    0.1344   -0.0472   -0.0102    0.0745    0.0056   -0.1010    0.0479   -0.0117    0.0843   -0.0471   -0.0217    0.0362   -0.0030   -0.0579    0.1073   -0.0383

임베딩 벡터를 시각화하려면 t-SNE 플롯을 만듭니다. 먼저 tsne를 사용하여 2차원 공간에 벡터를 임베딩합니다. 그런 다음 gscatter를 사용하여 레이블을 기준으로 그룹화된 테스트 임베딩 벡터의 산점도 플롯을 만듭니다.

Y = tsne(embeddedDocumentsTest,Distance="cosine");
gscatter(Y(:,1),Y(:,2),labelsTest)
title("Factory Report Embeddings")

문서 분류기 훈련시키기

fitcecoc를 사용하여 다중 클래스 선형 분류 모델을 훈련시킵니다.

mdl = fitcecoc(embeddedDocumentsTrain,labelsTrain,Learners="linear")
mdl = 
  CompactClassificationECOC
      ResponseName: 'Y'
        ClassNames: [Electronic Failure    Leak    Mechanical Failure    Software Failure]
    ScoreTransform: 'none'
    BinaryLearners: {6×1 cell}
      CodingMatrix: [4×6 double]


  Properties, Methods

모델 테스트하기

테스트 문서의 범주를 예측합니다. 정확도를 계산하고 혼동행렬 차트를 플로팅합니다.

labelPredict = predict(mdl,embeddedDocumentsTest);
acc = mean(labelPredict == labelsTest)
acc = 0.9444
confusionchart(labelPredict,labelsTest)

대각선상에 큰 값이 있으면 대응하는 클래스에 대한 예측이 정확하다는 것을 의미합니다. 비대각선상에 큰 값이 있으면 대응하는 클래스 간에 혼동이 심하다는 것을 의미합니다.

참고 항목

| | | | | | |

관련 항목