Main Content

BERT 문서 분류기 훈련시키기

R2023b 이후

이 예제에서는 문서 분류에 대해 BERT 신경망을 훈련시키는 방법을 보여줍니다.

BERT(Bidirectional Encoder Representations from Transformer) 모델은 문서 분류 및 감성 분석 등의 자연어 처리 작업을 위해 미세 조정(파인 튜닝)이 가능한 트랜스포머 신경망입니다. 이 신경망은 어텐션 계층을 사용하여 문맥 내의 텍스트를 분석하고 단어 간의 장거리 종속성(long-range dependency)을 캡처합니다.

이 예제에서는 텍스트 설명을 사용하여 공장 보고서의 범주를 예측하도록, 사전 훈련된 BERT-Base 신경망을 세부 조정합니다.

훈련 데이터 불러오기

factoryReports CSV 파일에서 훈련 데이터를 읽어옵니다. 이 파일에는 각 보고서에 대한 텍스트 설명과 범주 레이블이 포함된 공장 보고서가 들어 있습니다.

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

테이블의 Category 열에 있는 레이블을 categorical형 값으로 변환하고 히스토그램을 사용하여 데이터의 클래스 분포를 표시합니다.

data.Category = categorical(data.Category);
figure
histogram(data.Category)
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

클래스 개수를 확인합니다.

classNames = categories(data.Category);
numClasses = numel(classNames)
numClasses = 4

데이터를 훈련 세트와 테스트 세트로 분할합니다. 홀드아웃 백분율을 10%로 지정합니다.

cvp = cvpartition(data.Category,Holdout=0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);

테이블에서 텍스트 데이터와 레이블을 추출합니다.

textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
TTrain = dataTrain.Category;
TTest = dataTest.Category;

사전 훈련된 BERT 문서 분류기 불러오기

bertDocumentClassifier 함수를 사용하여, 사전 훈련된 BERT-Base 문서 분류기를 불러옵니다. Text Analytics Toolbox™ Model for BERT-Base Network 지원 패키지가 설치되어 있지 않으면, 필요한 지원 패키지로 연결되는 애드온 탐색기 링크를 함수에서 제공합니다. 지원 패키지를 설치하려면 링크를 클릭한 다음 설치를 클릭하십시오.

mdl = bertDocumentClassifier(ClassNames=classNames)
mdl = 
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]

훈련 옵션 지정하기

훈련 옵션을 지정합니다. 훈련 옵션 중에서 선택하려면 경험적 분석이 필요합니다. 실험을 실행하여 여러 훈련 옵션 구성을 살펴보려면 Experiment Manager 앱을 사용할 수 있습니다.

  • Adam 최적화 함수 알고리즘을 사용하여 훈련시킵니다.

  • Epoch 8회 동안 훈련시킵니다.

  • 세부 조정을 위해 학습률을 낮춥니다. 0.0001의 학습률을 사용하여 훈련시킵니다.

  • 매 Epoch마다 데이터를 섞습니다.

  • 플롯에서 훈련 진행 상황을 모니터링한 후 정확도 메트릭을 모니터링합니다.

  • 상세 출력을 비활성화합니다.

options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...  
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

신경망 훈련시키기

trainBERTDocumentClassifier 함수를 사용하여 신경망을 훈련시킵니다. 기본적으로 trainBERTDocumentClassifier 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU에서 훈련시키려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU가 없으면 trainBERTDocumentClassifier 함수는 CPU를 사용합니다. 실행 환경을 지정하려면 ExecutionEnvironment 훈련 옵션을 사용하십시오.

mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

신경망 테스트하기

테스트 데이터를 사용하여 예측을 수행합니다.

YTest = classify(mdl,textDataTest);

혼동행렬에서 예측을 시각화합니다.

figure
confusionchart(TTest,YTest)

테스트 예측의 분류 정확도를 계산합니다.

accuracy = mean(TTest == YTest)
accuracy = 0.9375

새 데이터를 사용하여 예측하기

새 공장 보고서의 이벤트 유형을 분류합니다. 새 공장 보고서를 포함하는 string형 배열을 만듭니다.

strNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
labelsNew = classify(mdl,strNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

참고 항목

| | (Deep Learning Toolbox) | (Deep Learning Toolbox)

관련 항목