MATLAB Answers

Trying to make a RNN for 2D signal data classification to 2D classified matrix output. Error using trainNetwork (line 165) Invalid training data. Responses must be a vector of categorical responses, or a cell array of categorical response sequences.

Orestis Marantos 24 Mar 2021
Commented: Orestis Marantos 22 Apr 2021 17:20
Hello Everyone,
I'm trying to implement a neural network classification algorithm for signal data of shape (800,500), where 800 is the number of time steps and 500 is the number of observations. I want to train a network to identify, for every time step, whether it belongs to class 0, 1, or -1. So my responses should have the same shape as the XTrain data, (800,500).
After trying to use XTrain and YTrain as plain two-dimensional arrays, I understood that this is not possible, so I changed them to cell arrays as the documentation requires.
My XTrain data is a cell array of double sequences, and so is my YTrain data. Out of 500 observations I used 70%, that is 350 observations, as training data.
As the MATLAB trainNetwork documentation suggests:
But I'm still getting the same error:
The layers and options for my RNN so far are the following; suggestions are welcome :)
Please help

Answers (1)

Srivardhan Gadila 29 Mar 2021
According to the documentation for the sequences and responses input arguments of trainNetwork, for the syntax net = trainNetwork(sequences,responses,layers,options), the input data should be an N-by-1 cell array of numeric arrays, where N is the number of observations. For vector sequences, each observation must be a c-by-s matrix, where c is the number of features and s is the sequence length. The responses should be an N-by-1 cell array of categorical label sequences, where each observation is a 1-by-s sequence of categorical labels and s is the sequence length of the corresponding predictor sequence. Since your YTrain holds double sequences rather than categorical ones, that is most likely what triggers the error.
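For data stored as plain matrices, as in the question (800 time steps by 500 observations), the conversion to this format might look like the sketch below. The matrices X and Y are random stand-ins, and all variable names are illustrative assumptions, not the asker's actual code.

```matlab
% Hypothetical stand-ins for the real data: 800 time steps x 500 observations.
numTimeSteps = 800;
numObservations = 500;
X = rand(numTimeSteps, numObservations);
Y = randi([-1 1], numTimeSteps, numObservations);

% One cell per observation; each cell holds a 1-by-800 row (c = 1 feature).
XCell = arrayfun(@(k) X(:,k)', (1:numObservations)', 'UniformOutput', false);

% Same layout for the labels, converted to categorical with a fixed class list
% so that every sequence carries the same three categories.
classValues = [-1 0 1];
YCell = arrayfun(@(k) categorical(Y(:,k)', classValues), ...
    (1:numObservations)', 'UniformOutput', false);

% 70/30 split by observation (350 training sequences).
idx = randperm(numObservations);
numTrain = round(0.7 * numObservations);
XTrain = XCell(idx(1:numTrain));
YTrain = YCell(idx(1:numTrain));
XTest  = XCell(idx(numTrain+1:end));
YTest  = YCell(idx(numTrain+1:end));
```

XTrain and YTrain are then both 350-by-1 cell arrays whose elements are 1-by-800, which matches the shapes described above.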
The following code may help you:
%% Create network.
inputSize = 800;        % sequence length (time steps per observation)
numClasses = 3;         % classes -1, 0 and 1
numHiddenUnits = 100;
layers = [ ...
    sequenceInputLayer(1,'Name','Sequence Input')   % one feature per time step
    lstmLayer(numHiddenUnits,'Name','LSTM Layer')   % default OutputMode is 'sequence'
    fullyConnectedLayer(numClasses,'Name','FC')
    softmaxLayer('Name','Softmax')
    classificationLayer('Name','Classification Layer')];
lgraph = layerGraph(layers);
analyzeNetwork(lgraph)

%% Create random training data.
numTrainSamples = 50;
% N-by-1 cell array of 1-by-inputSize numeric sequences.
trainData = arrayfun(@(x)rand([1 inputSize]),1:numTrainSamples,'UniformOutput',false)';
% N-by-1 cell array of 1-by-inputSize categorical sequences.
% The explicit value set [-1 0 1] keeps the categories identical in every sequence.
trainLabels = arrayfun(@(x)categorical(randi([-1 1],1,inputSize),[-1 0 1]),1:numTrainSamples,'UniformOutput',false)';
size(trainData)
size(trainLabels)

%% Train the network.
options = trainingOptions('adam', ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'Verbose',1, ...
    'Plots','training-progress');
net = trainNetwork(trainData,trainLabels,lgraph,options);
Comments: 1
Orestis Marantos 22 Apr 2021 17:20
Hello Srivardhan, thank you so much for your thorough explanation. Could you also recommend a way to overcome the strong class imbalance in my problem? I have generated more samples through simulation.
Specifically, I have 2000 samples, each with 800 time steps. Each sample contains some points of interest (signal goes up (on), signal goes down (off)). I would like to detect those points, so I thought that MIMO multi-label classification with an LSTM could solve my problem. I want to classify each time step of each sample as 0 if it is not a point of interest, or 1 if it is (I changed the -1 labels to 1 as well). Ultimately I'm only interested in the ones, so I'm not sure this is the best method to follow.
Because more than 98% of the time steps are zeros, the LSTM keeps giving me only zeros after training. I have implemented a custom weighted binary cross-entropy function that gives a weight of 0.97 to the ones and 0.03 to the zeros, but the algorithm still gives me only zeros for each prediction.
Do you have any suggestions on how to solve this highly skewed problem? I read about over-sampling and under-sampling, but they won't help here, as it is natural for the data to contain so few active points (ones).
Thanks a lot in advance
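The weighting described in this comment can also be expressed with built-in functionality, which may help rule out a bug in the custom loss. The following is a minimal sketch, not a confirmed fix for the all-zeros predictions. It assumes Deep Learning Toolbox R2020b or later, where classificationLayer accepts the ClassWeights option. The label data is synthetic (2% ones), and all variable names are illustrative.

```matlab
% Synthetic stand-in labels: 20 sequences of 800 steps, 2% ones (hypothetical data).
numSequences = 20;
numTimeSteps = 800;
numOnes = 16;                                   % 16/800 = 2% per sequence
YTrain = cell(numSequences, 1);
for k = 1:numSequences
    y = zeros(1, numTimeSteps);
    y(randperm(numTimeSteps, numOnes)) = 1;
    YTrain{k} = categorical(y, [0 1]);
end

% Pool all time steps and count how often each class occurs.
allLabels = [YTrain{:}];
classNames = categories(allLabels);             % {'0'; '1'}
counts = countcats(allLabels(:));               % column vector, one count per class

% Inverse-frequency weights, normalised to sum to 1.
classWeights = (1 ./ counts) / sum(1 ./ counts);

% Weighted output layer (requires Deep Learning Toolbox R2020b or later).
outputLayer = classificationLayer( ...
    'Name', 'Weighted Classification Layer', ...
    'Classes', classNames, ...
    'ClassWeights', classWeights');
```

With 2% ones the weights come out as 0.02 for class 0 and 0.98 for class 1, close to the hand-picked 0.03 and 0.97. The resulting layer would take the place of the final classificationLayer in the layers array.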
