Formatting data for Deep Learning toolbox and the trainNetwork function

Hello,
I am struggling a lot with MATLAB, trying to train a neural network to recognise gait events in kinematic walking data (sensors on the patient's legs that give their positions in 3D). My dataset (for now) is a 10661-by-96 matrix: 10661 points with 96 features. The features are 3 coordinates (because the data is 3D) × 4 sensors × 2 feet × 4 experimental conditions (the patient walked at different speeds and inclines), which gives 96 features. My labels are in a 10661-by-1 vector with 1 for heel strike, -1 for toe off and 0 otherwise (for now, these are the only gait events I want the network to recognise). My code from that point on:
% making a categorical array out of my labels array
labels_cat = categorical(labels,[-1 0 1],{'toe off' 'else' 'heel strike'});
%% defining the neural net
inputSize = 3*4*2*4; % 96 features
hiddenSize = 100;
numClasses = 3; % the number of classes (heel strike, toe off, and other)
layers = [
imageInputLayer([1 1 inputSize], 'Name', 'input')
fullyConnectedLayer(50, 'Name', 'fc1')
reluLayer('Name', 'relu1')
fullyConnectedLayer(numClasses, 'Name', 'fc2')
softmaxLayer('Name', 'softmax')
classificationLayer('Name', 'classoutput')
];
%% training
% Specify training options
options = trainingOptions('adam', ...
'MaxEpochs', 20, ...
'MiniBatchSize', 32, ...
'InitialLearnRate', 0.001, ...
'Plots', 'training-progress');
net = trainNetwork(data, labels_cat, layers, options); % data is the 10661-by-96 feature matrix

I have tried every way I can think of to format the data (especially the labels), and I always get an error: either a supposed size mismatch between X and Y (even though they are exactly 10661x96 and 10661x1), or something else. I might be doing everything wrong; I have never used MATLAB for ML before, and I'm not an expert in PyTorch or TensorFlow either.
Thanks :)

Accepted Answer

Sivylla Paraskevopoulou on 4 Apr 2023
Hi Alexei,
You have a feature matrix, but your network expects images as input data, because its first layer is an imageInputLayer.
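A minimal sketch of two ways to make the shapes agree, assuming `data` is the 10661-by-96 numeric matrix and `labels` the 10661-by-1 vector of -1/0/1 codes from the question (random stand-ins are generated here, and `MaxEpochs` is cut to 2 just to keep the sketch quick). Option 1 keeps `imageInputLayer([1 1 96])` and reshapes the predictors into a 1-by-1-by-96-by-N array, since `trainNetwork` expects image predictors as height-by-width-by-channels-by-observations. Option 2 swaps in `featureInputLayer` (available since R2020b, so it exists in R2022b), which takes the N-by-96 matrix as-is. That second option is an alternative layer choice, not something the answer itself proposes. Both require Deep Learning Toolbox.

```matlab
% Hypothetical stand-in data with the shapes described in the question.
rng(0);
numObs = 10661;
numFeatures = 3*4*2*4;                 % 96 features
numClasses = 3;
data = randn(numObs, numFeatures);     % numObs-by-96 predictors
labels = randi([-1 1], numObs, 1);     % numObs-by-1 codes: -1, 0, 1
labels_cat = categorical(labels, [-1 0 1], {'toe off' 'else' 'heel strike'});

% Option 1: keep imageInputLayer; reshape predictors to 1-by-1-by-96-by-N.
% data' is 96-by-N, so each observation (column) becomes one 1x1x96 "image".
XImage = reshape(data', 1, 1, numFeatures, numObs);

imageLayers = [
    imageInputLayer([1 1 numFeatures], 'Name', 'input')
    fullyConnectedLayer(50, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(numClasses, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classoutput')
];

% Option 2 (alternative): featureInputLayer accepts the N-by-96 matrix directly.
featureLayers = [
    featureInputLayer(numFeatures, 'Name', 'input')
    fullyConnectedLayer(50, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(numClasses, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classoutput')
];

options = trainingOptions('adam', ...
    'MaxEpochs', 2, ...                % kept small for the sketch
    'MiniBatchSize', 32, ...
    'InitialLearnRate', 0.001, ...
    'Verbose', false);

% In both cases the responses are an N-by-1 categorical vector.
net1 = trainNetwork(XImage, labels_cat, imageLayers, options);
net2 = trainNetwork(data, labels_cat, featureLayers, options);
```

For a plain table of per-sample features, Option 2 is probably the more direct fit; Option 1 mainly shows why the original call failed, since the image layer never received data in image layout.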
You can classify your data with machine learning or deep learning, and I think you will find these two examples helpful:
If you choose deep learning, see the above example for how to create a very simple LSTM network for classifying sensor data:
layers = [ ...
sequenceInputLayer(numFeatures)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
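For the LSTM route, `trainNetwork` does not take an N-by-96 matrix either: for sequence-to-sequence classification the predictors are a cell array of numFeatures-by-T sequences, and the responses are a cell array of matching 1-by-T categorical label sequences. A minimal sketch, assuming the 10661 rows form one continuous recording (random stand-ins here) and cutting it into fixed-length windows so there is more than one training sequence. `windowLength`, `numHiddenUnits` and the reduced `MaxEpochs` are illustrative values, not taken from the thread.

```matlab
% Hypothetical stand-in data: one recording of numObs time steps.
rng(0);
numObs = 10661;
numFeatures = 96;
numClasses = 3;
numHiddenUnits = 100;                  % illustrative value
data = randn(numObs, numFeatures);
labels = randi([-1 1], numObs, 1);
labels_cat = categorical(labels, [-1 0 1], {'toe off' 'else' 'heel strike'});

% Split the recording into non-overlapping windows (the last one may be shorter).
windowLength = 500;                    % hypothetical window size
starts = 1:windowLength:numObs;
numSeq = numel(starts);
XSeq = cell(numSeq, 1);
YSeq = cell(numSeq, 1);
for k = 1:numSeq
    idx = starts(k):min(starts(k) + windowLength - 1, numObs);
    XSeq{k} = data(idx, :)';           % numFeatures-by-T: features in rows, time in columns
    YSeq{k} = labels_cat(idx)';        % 1-by-T categorical: one label per time step
end

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits, 'OutputMode', 'sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('adam', ...
    'MaxEpochs', 2, ...                % kept small for the sketch
    'MiniBatchSize', 8, ...
    'Verbose', false);

net = trainNetwork(XSeq, YSeq, layers, options);
```

Windowing a single recording is only one way to get several sequences; if the data come from separate walking trials, each trial could instead become its own cell.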

Release

R2022b
