Main Content

trainBERTDocumentClassifier

BERT 문서 분류기 훈련시키기

R2023b 이후

    설명

    BERT(Bidirectional Encoder Representations from Transformer) 모델은 문서 분류 및 감성 분석 등의 자연어 처리 작업을 위해 미세 조정(파인 튜닝)이 가능한 트랜스포머 신경망입니다. 이 신경망은 어텐션 계층을 사용하여 문맥 내의 텍스트를 분석하고 단어 간의 장거리 종속성(long-range dependency)을 캡처합니다.

    예제

    mdlTrained = trainBERTDocumentClassifier(documents,targets,mdl,options)는 지정된 텍스트 데이터와 목표값을 사용하여 BERT 문서 분류기를 훈련시킵니다.

    mdlTrained = trainBERTDocumentClassifier(tbl,mdl,options)는 지정된 테이블의 텍스트 데이터와 목표값을 사용하여 BERT 문서 분류기를 훈련시킵니다.

    예제

    모두 축소

    factoryReports CSV 파일에서 훈련 데이터를 읽어옵니다. 이 파일에는 각 보고서에 대한 텍스트 설명과 범주 레이블이 포함된 공장 보고서가 들어 있습니다.

    filename = "factoryReports.csv";
    data = readtable(filename,TextType="string");
    head(data)
                                     Description                                       Category          Urgency          Resolution         Cost 
        _____________________________________________________________________    ____________________    ________    ____________________    _____
    
        "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
        "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
        "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
        "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
        "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
        "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
        "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
        "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38
    

    테이블의 Category 열에 있는 레이블을 categorical형 값으로 변환합니다.

    data.Category = categorical(data.Category);

    데이터를 훈련 세트와 테스트 세트로 분할합니다. 홀드아웃 백분율을 10%로 지정합니다.

    cvp = cvpartition(data.Category,Holdout=0.1);
    dataTrain = data(cvp.training,:);
    dataTest = data(cvp.test,:);

    테이블에서 텍스트 데이터와 레이블을 추출합니다.

    textDataTrain = dataTrain.Description;
    textDataTest = dataTest.Description;
    TTrain = dataTrain.Category;
    TTest = dataTest.Category;

    bertDocumentClassifier 함수를 사용하여, 사전 훈련된 BERT-Base 문서 분류기를 불러옵니다.

    classNames = categories(data.Category);
    mdl = bertDocumentClassifier(ClassNames=classNames)
    mdl = 
      bertDocumentClassifier with properties:
    
           Network: [1×1 dlnetwork]
         Tokenizer: [1×1 bertTokenizer]
        ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
    
    

    훈련 옵션을 지정합니다. 훈련 옵션 중에서 선택하려면 경험적 분석이 필요합니다. 실험을 실행하여 여러 훈련 옵션 구성을 살펴보려면 Experiment Manager 앱을 사용할 수 있습니다.

    • Adam 최적화 함수 알고리즘을 사용하여 훈련시킵니다.

    • Epoch 8회 동안 훈련시킵니다.

    • 세부 조정을 위해 학습률을 낮춥니다. 0.0001의 학습률을 사용하여 훈련시킵니다.

    • 매 Epoch마다 데이터를 섞습니다.

    • 플롯에서 훈련 진행 상황을 모니터링한 후 정확도 메트릭을 모니터링합니다.

    • 상세 출력을 비활성화합니다.

    options = trainingOptions("adam", ...
        MaxEpochs=8, ...
        InitialLearnRate=1e-4, ...
        Shuffle="every-epoch", ...  
        Plots="training-progress", ...
        Metrics="accuracy", ...
        Verbose=false);

    trainBERTDocumentClassifier 함수를 사용하여 신경망을 훈련시킵니다. 기본적으로 trainBERTDocumentClassifier 함수는 GPU를 사용할 수 있으면 GPU를 사용합니다. GPU에서 훈련시키려면 Parallel Computing Toolbox™ 라이선스와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. GPU가 없으면 trainBERTDocumentClassifier 함수는 CPU를 사용합니다. 실행 환경을 지정하려면 ExecutionEnvironment 훈련 옵션을 사용하십시오.

    mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

    테스트 데이터를 사용하여 예측을 수행합니다.

    YTest = classify(mdl,textDataTest);

    테스트 예측의 분류 정확도를 계산합니다.

    accuracy = mean(TTest == YTest)
    accuracy = 0.9375
    

    입력 인수

    모두 축소

    훈련 문서로, string형 배열, 문자형 벡터로 구성된 셀형 배열 또는 tokenizedDocument 배열로 지정됩니다.

    documents의 요소 개수와 targets의 요소 개수가 일치해야 합니다.

    훈련 목표값으로, categorical형 배열, string형 배열 또는 문자형 벡터로 구성된 셀형 배열로 지정됩니다.

    훈련 문서 및 훈련 목표값으로, 2개 열을 갖는 테이블로 지정됩니다. 첫 번째 열은 string형 또는 문자형 벡터로 지정된 텍스트 데이터를 포함합니다. 두 번째 열은 categorical형 값, string형 또는 문자형 벡터로 지정된 목표값을 포함합니다.

    데이터형: table

    BERT 문서 분류기 모델로, bertDocumentClassifier 객체로 지정됩니다.

    훈련 옵션으로, trainingOptions (Deep Learning Toolbox) 함수에서 반환된 TrainingOptionsSGDM 객체, TrainingOptionsRMSProp 객체 또는 TrainingOptionsADAM 객체로 지정됩니다. 이들 객체를 만들려면 솔버를 각각 "sgdm", "rmsprop", "adam"으로 설정합니다.

    출력 인수

    모두 축소

    BERT 문서 분류기 모델로, bertDocumentClassifier 객체로 반환됩니다.

    참고 문헌

    [1] Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. "BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding" Preprint, submitted May 24, 2019. https://doi.org/10.48550/arXiv.1810.04805.

    [2] Srivastava, Nitish, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. "Dropout: A Simple Way to Prevent Neural Networks from Overfitting." The Journal of Machine Learning Research 15, no. 1 (January 1, 2014): 1929–58

    [3] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "ImageNet Classification with Deep Convolutional Neural Networks." Communications of the ACM 60, no. 6 (May 24, 2017): 84–90. https://doi.org/10.1145/3065386

    버전 내역

    R2023b에 개발됨