LSTM - Set special loss function
Based on this great MATLAB example, I would like to build a neural network that classifies each time step of a time series (x_i, y_i), i = 1:N, as 1 or 2. For training, I created 500 different time series and the corresponding target vectors. In reality, about 85% of each time series is in state 1 and the rest (15%) is in state 2. The training process looks like this:
The training accuracy stagnates at about 85%, because with the unweighted cross-entropy loss that classificationLayer uses, a policy that classifies every time step as 1 already achieves a "good" accuracy of 85%. As a result, nearly every time step is classified as 1. I am quite sure this can be avoided by using another loss function, but unfortunately I do not know how to create an arbitrary loss function.
I would like to adapt the loss function so that, whenever a true state 2 is falsely classified as 1, the loss is weighted by a factor f.
How can this be done?
This is how the training is set up:
featureDimension = 2;
numHiddenUnits = 100;
numClasses = 2;
layers = [ ...
    sequenceInputLayer(featureDimension)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
options = trainingOptions('adam', ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',20, ...
    'Verbose',0, ...
    'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);
Answers (2)
sii
18 May 2018
If this is still relevant, check the following documentation. As far as I know, it is only available in the R2018a release.
Stuart Whipp
10 Dec 2018
I think what is needed is a weighted classification output layer, so that you can account for the imbalance between your classes. A custom layer tutorial exists for this, but it seemingly only works for image classification problems.
Comments: 1
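To make the idea of a weighted classification output concrete, here is a minimal sketch of a class-weighted cross-entropy computed on toy data. The probabilities, targets, the factor f = 5 and the helper name weightedLoss are all assumptions made up for illustration; none of them come from the original post.

```matlab
% Hypothetical toy data: 2 classes, 4 time steps.
% Y: predicted class probabilities (rows = classes, columns = time steps).
Y = [0.9 0.8 0.7 0.6;
     0.1 0.2 0.3 0.4];
% T: one-hot targets; only the last time step is truly in state 2.
T = [1 1 1 0;
     0 0 0 1];

f = 5;          % assumed penalty factor for time steps whose true state is 2
w = [1; f];     % one weight per class, in the same order as the rows of Y

% Class-weighted cross-entropy, averaged over the time steps.
% w (2-by-1) is implicitly expanded across the columns of T (R2016b or later).
weightedLoss = @(Y, T, w) -sum(sum(w .* T .* log(Y))) / size(Y, 2);

lossPlain    = weightedLoss(Y, T, [1; 1]);   % ordinary cross-entropy
lossWeighted = weightedLoss(Y, T, w);        % errors on state 2 cost f times more
```

With two classes, weighting every true state-2 time step by f has the same effect as penalizing a state 2 that is mistaken for state 1, so a network that predicts state 1 everywhere no longer gets a low loss.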
Stuart Whipp
12 Dec 2018
Edited: Stuart Whipp
12 Dec 2018
Conor Daly (staff) kindly answered my question this morning with a custom output layer, which I can confirm has worked for my use case. Please take a look at this link, as I believe it also answers your question.
Regards, Stuart
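For readers who cannot follow the link, below is a rough sketch of what a class-weighted output layer for sequence-to-sequence classification could look like. It is not the layer from the linked answer. The class name weightedSequenceClassificationLayer and the ClassWeights property are hypothetical, and the sketch assumes that the layer receives Y and T as K-by-N-by-S arrays (classes by observations by time steps), as the custom classification output layer documentation describes for sequence-to-sequence problems. It would need to be saved in its own file, weightedSequenceClassificationLayer.m.

```matlab
classdef weightedSequenceClassificationLayer < nnet.layer.ClassificationLayer
    % Sketch of a class-weighted cross-entropy output layer.
    % The class name and the ClassWeights property are illustrative only.

    properties
        ClassWeights   % K-by-1 vector, one weight per class
    end

    methods
        function layer = weightedSequenceClassificationLayer(classWeights, name)
            layer.ClassWeights = classWeights(:);   % force a column vector
            if nargin > 1
                layer.Name = name;
            end
            layer.Description = 'Class-weighted cross-entropy';
        end

        function loss = forwardLoss(layer, Y, T)
            % Y, T: K-by-N-by-S (assumed layout, see the note above).
            N = size(Y, 2);
            W = layer.ClassWeights;                 % expands over N and S
            weighted = W .* T .* log(Y);
            loss = -sum(weighted(:)) / N;
        end

        function dLdY = backwardLoss(layer, Y, T)
            % Derivative of the loss above with respect to Y.
            N = size(Y, 2);
            W = layer.ClassWeights;
            dLdY = -(W .* T) ./ Y / N;
        end
    end
end
```

Under these assumptions, the layer would take the place of classificationLayer at the end of the layers array, for example weightedSequenceClassificationLayer([1; f]) with f chosen by the user. The weights should follow the order of the categories in the training labels, which is worth checking before training.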