Deep learning with weighted multiple labels?
Hello All,
I'd like to train a network whose responses are the tags (labels) that users assigned to a book. The tags are counted, and I want the model to take the counts into account: e.g., 5 people said the book is "fiction" while 3 said it's "sci-fi". I'm using the book's textual description as input and an LSTM to predict these tags. See the attached picture for an illustration of the data structure.
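To make the data structure concrete, here is a minimal sketch (in Python, since the idea is language-agnostic) of turning one book's tag counts into a target vector. The tag names and counts are hypothetical and mirror the "fiction"/"sci-fi" example. Dividing by the total turns the counts into proportions that sum to 1, which is the form a softmax output can be compared against. Taking the tag with the largest count recovers the single-label target that the script below trains on.

```python
# Hypothetical tag vocabulary (the real data has 39 tags)
TAGS = ["fiction", "sci-fi", "fantasy"]

def counts_to_distribution(counts):
    """Turn raw per-tag vote counts into a probability vector that sums to 1."""
    total = sum(counts)
    if total == 0:
        raise ValueError("book has no tag votes")
    return [c / total for c in counts]

def top_label(counts, tags):
    """The single-label target: the tag with the most votes (first one wins ties)."""
    best = max(range(len(counts)), key=lambda i: counts[i])
    return tags[best]

counts = [5, 3, 0]  # 5 x "fiction", 3 x "sci-fi", 0 x "fantasy"
print(counts_to_distribution(counts))  # [0.625, 0.375, 0.0]
print(top_label(counts, TAGS))         # fiction
```

Normalizing per book means a book tagged by 80 people and one tagged by 8 people with the same proportions get the same target; if the number of votes should also matter, it could be used as a per-book weight on the loss instead.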
I can get this working for predicting the top label only, but how do I use the information from all the labels? I assume the output layer needs to change somehow, but I can't figure out how. I think what I need is an output layer that compares the predicted vector to the target vector and minimizes the distance between them.
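One common way to compare a predicted vector with a target vector here is to keep the softmax output and replace the one-hot cross-entropy with cross-entropy against the normalized tag distribution ("soft labels"). The sketch below assumes the targets are proportions that sum to 1. With a one-hot target it reduces to the ordinary classification loss, and its gradient with respect to the pre-softmax scores is simply p - q. Minimizing it is equivalent to minimizing the KL divergence KL(q || p), because the two differ only by the entropy of q, which does not depend on the network. The function names are illustrative and not taken from any toolbox.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def soft_cross_entropy(pred, target, eps=1e-12):
    """Cross-entropy -sum_k q_k * log(p_k) of prediction p against target distribution q.

    eps only guards against log(0); tags with q_k = 0 contribute nothing.
    """
    return -sum(q * math.log(p + eps) for p, q in zip(pred, target))

def soft_cross_entropy_grad(logits, target):
    """Gradient of soft_cross_entropy(softmax(logits), target) w.r.t. the logits.

    Equals p - q as long as the target sums to 1 (eps is ignored here).
    """
    return [p - q for p, q in zip(softmax(logits), target)]

# Usage: compare a prediction with the normalized counts from the example
q = [0.625, 0.375, 0.0]      # 5 x "fiction", 3 x "sci-fi", 0 x "fantasy"
logits = [2.0, 1.0, 0.1]     # hypothetical network scores for one book
p = softmax(logits)
loss = soft_cross_entropy(p, q)
grad = soft_cross_entropy_grad(logits, q)
```

A plain squared distance between p and q would also "compare a vector to a vector", but with a softmax output the cross-entropy form gives the simpler p - q gradient and is the usual choice for distributions.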
Below is the script that works for the single-label version:
Any help would be highly appreciated,
Balazs
```matlab
% Preprocess the book descriptions
textstokenized = tokenizedDocument(texts);
textstokenized = removeStopWords(textstokenized);
textstokenized = removeShortWords(textstokenized,2);
textstokenized = removeLongWords(textstokenized,15);
%textstokenized = normalizeWords(textstokenized);

% Map words to integer indices and pad/truncate every document to 100 tokens
enc = wordEncoding(textstokenized);
encoded_text = doc2sequence(enc,textstokenized,'Length',100);

% Hold out 10% of the data, then split the held-out part 50/50 into validation and test
cvp = cvpartition(topcat,'Holdout',0.1);
dataTrain = topcat(training(cvp),:);
dataHeldOut = topcat(test(cvp),:);
textDataTrain = encoded_text(training(cvp),:);
textheldout = encoded_text(test(cvp),:);

cvp2 = cvpartition(dataHeldOut,'HoldOut',0.5);
dataValidation = dataHeldOut(training(cvp2),:);
dataTest = dataHeldOut(test(cvp2),:);
textDataValidation = textheldout(training(cvp2));
textDataTest = textheldout(test(cvp2));

YTrain = categorical(dataTrain);
YValidation = categorical(dataValidation);
YTest = categorical(dataTest);

% Network: word indices -> embedding -> BiLSTM (last step) -> one score per tag
inputSize = 1;
embeddingDimension = 100;
numWords = enc.NumWords;                  % vocabulary size, not a number of hidden units
hiddenSize = 100;
numClasses = numel(categories(YTrain));   % hard-coded as 39 in the original script

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    bilstmLayer(hiddenSize,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('adam', ...
    'MaxEpochs',1, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'ValidationData',{textDataValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',1);

% trainNetwork2 is not a built-in function; presumably a local wrapper around trainNetwork
net = trainNetwork2(textDataTrain,YTrain,layers,options);
```
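To check that the final layers of a network like this can learn the whole tag distribution rather than only the top tag, the sketch below trains just a fully connected layer plus softmax (the counterpart of fullyConnectedLayer followed by softmaxLayer) with the soft-label cross-entropy, using plain full-batch gradient descent. The two feature vectors are hypothetical stand-ins for the BiLSTM's last output, and the tag counts are made up. In Deep Learning Toolbox the same loss would presumably have to live in a custom output layer (for example a class derived from nnet.layer.RegressionLayer that implements forwardLoss) in place of classificationLayer, with the responses given as an N-by-numClasses matrix of proportions; that wiring is an assumption and is not shown here.

```python
import math

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def predict(W, b, x):
    """Tag distribution for one feature vector: softmax(W x + b)."""
    return softmax([sum(wj * xj for wj, xj in zip(row, x)) + bk for row, bk in zip(W, b)])

def train_soft_label_layer(features, targets, lr=1.0, epochs=2000):
    """Fit logits = W x + b by gradient descent on the mean soft-label cross-entropy."""
    n_in, n_out, n = len(features[0]), len(targets[0]), len(features)
    W = [[0.0] * n_in for _ in range(n_out)]
    b = [0.0] * n_out
    for _ in range(epochs):
        gW = [[0.0] * n_in for _ in range(n_out)]
        gb = [0.0] * n_out
        for x, q in zip(features, targets):
            p = predict(W, b, x)
            for k in range(n_out):
                d = (p[k] - q[k]) / n  # dLoss/dlogit_k, averaged over the batch
                gb[k] += d
                for j in range(n_in):
                    gW[k][j] += d * x[j]
        for k in range(n_out):
            b[k] -= lr * gb[k]
            for j in range(n_in):
                W[k][j] -= lr * gW[k][j]
    return W, b

# Usage with two hypothetical books and made-up counts for three tags
features = [[1.0, 0.0], [0.0, 1.0]]
counts = [[5, 3, 2], [1, 2, 7]]
targets = [[c / sum(row) for c in row] for row in counts]
W, b = train_soft_label_layer(features, targets)
predictions = [predict(W, b, x) for x in features]
```

After training, each book's predicted distribution is close to its normalized counts, and the top tag is still just the largest entry of the prediction, so single-label accuracy can be reported alongside. Tags with zero votes are only approached asymptotically (their logits drift toward minus infinity), which is one reason some setups add a little label smoothing to the targets.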