Main Content

오만과 편견 그리고 MATLAB

이 예제에서는 문자 임베딩을 사용하여 텍스트를 생성하도록 딥러닝 LSTM 신경망을 훈련시키는 방법을 보여줍니다.

텍스트 생성을 위해 딥러닝 신경망을 훈련시키려면 문자 시퀀스 내의 다음 문자를 예측하도록 sequence-to-sequence LSTM 신경망을 훈련시키십시오. 다음 문자를 예측하도록 신경망을 훈련시키려면 입력 시퀀스를 시간 스텝 하나만큼 이동한 값을 응답 변수로 지정하십시오.

문자 임베딩을 사용하려면 각 훈련 관측값을 정수 시퀀스로 변환합니다. 이들 정수는 문자들로 이루어진 단어집을 참조합니다. 문자 임베딩을 학습하고 정수를 벡터로 매핑하는 단어 임베딩 계층을 신경망에 포함시킵니다.

훈련 데이터 불러오기

Project Gutenberg에서 제공하는 Pride and Prejudice(Jane Austen 저) 전자책의 HTML 코드를 읽어 들이고 webreadhtmlTree를 사용하여 구문 분석합니다.

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

p 요소를 찾아 단락을 추출합니다. CSS 선택자 ':not(.toc)'를 사용하여 클래스 "toc"를 가지는 단락 요소를 무시하도록 지정합니다.

paragraphs = findElement(tree,'p:not(.toc)');

extractHTMLText를 사용하여 단락에서 텍스트 데이터를 추출하고 빈 문자열을 제거합니다.

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

문자 길이가 20자보다 짧은 문자열을 제거합니다.

idx = strlength(textData) < 20;
textData(idx) = [];

텍스트 데이터를 워드 클라우드로 시각화합니다.

figure
wordcloud(textData);
title("Pride and Prejudice")

텍스트 데이터를 시퀀스로 변환하기

텍스트 데이터를 예측 변수의 경우 문자 인덱스로 구성된 시퀀스로, 응답 변수의 경우 categorical형 시퀀스로 변환합니다.

categorical 함수는 새 줄 요소과 공백 요소를 정의되지 않은 요소로 취급합니다. 이러한 문자에 대한 categorical형 요소를 만들려면 각각을 특수 문자 ""(단락 기호, "\x00B6")와 "·"(중간 점, "\x00B7")으로 바꾸십시오. 모호성을 피하려면 텍스트에 표시되지 않는 특수 문자를 선택해야 합니다. 이들 문자는 훈련 데이터에 나타나지 않으므로 이 용도로 사용할 수 있습니다.

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

텍스트 데이터를 루프를 사용해 순회하여, 각 관측값의 문자를 나타내는 문자 인덱스 시퀀스와 응답 변수에 대한 categorical형 문자 시퀀스를 만듭니다. 각 관측값의 끝을 나타내려면 특수 문자 "␃"(텍스트의 끝, "\x2403")을 포함하십시오.

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

기본적으로 훈련 중에 훈련 데이터가 미니 배치로 분할되고 모든 시퀀스의 길이가 같아지도록 시퀀스가 채워집니다. 너무 많이 채우면 신경망 성능이 저하될 수 있습니다.

훈련 과정에서 너무 많이 채워지지 않도록 하려면 시퀀스 길이를 기준으로 시퀀스 데이터를 정렬한 다음 미니 배치 크기를 선택하여 하나의 미니 배치에 속한 시퀀스들이 비슷한 길이를 갖도록 합니다.

각 관측값에 대한 시퀀스 길이를 가져옵니다.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

시퀀스 길이를 기준으로 데이터를 정렬합니다.

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

LSTM 신경망 만들고 훈련시키기

LSTM 아키텍처를 정의합니다. 은닉 유닛 400개를 갖는 sequence-to-sequence LSTM 분류 신경망을 지정합니다. 입력 크기를 훈련 데이터의 특징 차원으로 설정합니다. 문자 인덱스로 구성된 시퀀스의 경우 특징 차원은 1입니다. 차원이 200인 단어 임베딩 계층을 지정하고, (문자에 대응하는) 단어의 개수가 입력 데이터에서 가장 높은 문자 값이 되도록 지정합니다. 완전 연결 계층의 출력 크기가 응답 변수의 범주 개수가 되도록 설정합니다. 과적합을 방지하기 위해 LSTM 계층 뒤에 드롭아웃 계층을 포함시킵니다.

단어 임베딩 계층은 문자의 임베딩을 학습하고 각 문자를 200차원 벡터로 매핑합니다.

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

훈련 옵션을 지정합니다. 미니 배치 크기 32, 초기 학습률 0.01로 훈련시키도록 지정합니다. 기울기가 한없이 증가하지 않도록 하려면 기울기 임계값을 1로 설정하십시오. 데이터가 정렬된 상태를 유지하도록 하려면 'Shuffle''never'로 설정하십시오. 훈련 진행 상황을 모니터링하려면 'Plots' 옵션을 'training-progress'로 설정하십시오. 세부 정보가 출력되지 않도록 하려면 'Verbose'false로 설정하십시오.

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

신경망을 훈련시킵니다.

net = trainNetwork(XTrain,YTrain,layers,options);

새 텍스트 생성하기

훈련 데이터의 텍스트에 있는 첫 번째 문자들의 확률 분포에서 문자를 추출하여 텍스트의 첫 번째 문자를 생성합니다. 현재 생성된 텍스트 시퀀스를 사용하여 다음 시퀀스를 예측하려면 훈련된 LSTM 신경망을 사용하여 나머지 문자를 생성하십시오. 신경망이 "텍스트 끝" 문자를 예측할 때까지 계속해서 문자를 하나씩 생성합니다.

훈련 데이터의 첫 번째 문자들의 분포에 따라 첫 번째 문자를 추출합니다.

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

첫 번째 문자를 숫자형 인덱스로 변환합니다.

X = double(char(firstCharacter));

나머지 예측에 대해서는 신경망의 예측 점수에 따라 다음 문자를 추출합니다. 예측 점수는 다음 문자의 확률 분포를 나타냅니다. 신경망의 출력 계층의 클래스 이름으로 주어지는 문자들의 단어집으로부터 문자를 추출합니다. 신경망의 분류 계층에서 단어집을 가져옵니다.

vocabulary = string(net.Layers(end).ClassNames);

predictAndUpdateState를 사용하여 문자를 하나씩 예측합니다. 각 예측에 대해 이전 문자의 인덱스를 입력합니다. 신경망이 텍스트 끝 문자를 예측하거나 생성된 텍스트의 길이가 500자가 되면 예측을 중단합니다. 대규모의 데이터 모음, 긴 시퀀스 또는 큰 신경망의 경우에는 일반적으로 GPU에서의 예측이 CPU에서의 예측보다 연산 속도가 빠릅니다. 그 밖의 경우에는 일반적으로 CPU에서의 예측이 연산 속도가 빠릅니다. 단일 시간 스텝 예측에는 CPU를 사용하십시오. 예측에 CPU를 사용하려면 predictAndUpdateState'ExecutionEnvironment' 옵션을 'cpu'로 설정하십시오.

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

특수 문자를 각각에 대응하는 공백 문자와 새 줄 문자로 바꾸어 텍스트를 재구성합니다.

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

여러 개의 텍스트를 생성하려면 resetState를 사용하여 생성과 생성 사이에 신경망 상태를 재설정하십시오.

net = resetState(net);

참고 항목

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

관련 항목