
How to set the input for a fully connected layer

alex, 20 June 2020
Hello. I extracted features from a video dataset and saved them in a cell array. Because each video has a different number of frames, the feature matrices have different sizes (please see the attached picture). I want to use these features as the input to net_fc in the last line of the following code, but this error occurs:

Error using trainNetwork
Invalid training data. Responses must be a vector of categorical responses, or a cell array of categorical response sequences.

Can anyone please help me set up the input for net_fc? Thanks so much...
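Because the feature matrices differ only in their number of columns (frames), one way to give a plain stack of fully connected layers a fixed-size input is to average each matrix over its frames, so every video becomes a single numFeatures-by-1 vector. This sketch assumes, as `numFeatures = size(sequencesTrain{1},1)` in the posted code suggests, that rows are features and columns are frames. The random data, `numFeatures = 16`, and the layer sizes are hypothetical stand-ins. It also converts the char labels with `categorical`, which is what the quoted error asks for. Mean pooling discards frame order, so treat it as a simple baseline.

```matlab
% Hypothetical stand-in for DD.sequences: 48 videos, 16 features per
% frame, 20 to 60 frames per video (rows = features, columns = frames).
numFeatures = 16;
sequences = arrayfun(@(k) rand(numFeatures, randi([20 60])), (1:48)', ...
    'UniformOutput', false);

% Same labels as YTrain ('a' to 'h', six videos each), converted from a
% char column to categorical, as trainNetwork requires.
labels = categorical(cellstr(repelem(('a':'h')', 6)));
numClasses = numel(categories(labels));

% Average over frames (dimension 2): each video becomes numFeatures x 1.
pooled = cellfun(@(s) mean(s, 2), sequences, 'UniformOutput', false);
X = cat(2, pooled{:});                      % numFeatures x numObservations

% Image-style 4-D layout (height x width x channels x observations)
% so a plain imageInputLayer can feed the fully connected layers.
X4 = reshape(X, numFeatures, 1, 1, []);

layers = [
    imageInputLayer([numFeatures 1 1], 'Normalization', 'none', 'Name', 'input')
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(numClasses, 'Name', 'fc2')
    softmaxLayer('Name', 'soft1')
    classificationLayer('Name', 'classification')];

options = trainingOptions('adam', ...
    'MiniBatchSize', 16, ...
    'InitialLearnRate', 1e-4, ...
    'MaxEpochs', 5, ...
    'Verbose', false);

net_fc = trainNetwork(X4, labels, layers, options);
YPred = classify(net_fc, X4);
```

On R2020b or later, `featureInputLayer(numFeatures)` can take the pooled data directly as an observations-by-features matrix (`X'`) instead of the 4-D reshape.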
clc;clear;close
DD = load( 'D:\simulation\ALL1\TRAIN_DATA_taranhade.mat');
sequences = DD.sequences;
YTrain = ['a'; 'a'; 'a'; 'a'; 'a'; 'a';...
'b'; 'b'; 'b'; 'b'; 'b'; 'b';...
'c'; 'c'; 'c'; 'c'; 'c'; 'c';...
'd'; 'd'; 'd'; 'd'; 'd'; 'd';...
'e'; 'e'; 'e'; 'e'; 'e'; 'e';...
'f'; 'f'; 'f'; 'f'; 'f'; 'f';...
'g'; 'g'; 'g'; 'g'; 'g'; 'g';...
'h'; 'h'; 'h'; 'h'; 'h'; 'h'];
% trainNetwork needs categorical responses, not a char array
YTrain = categorical(cellstr(YTrain));
numObservations = numel(sequences);
idx = randperm(numObservations);
N = floor(0.9 * numObservations);
idxTrain = idx(1:N);
sequencesTrain = sequences(idxTrain);
labelsTrain = YTrain(idxTrain);
idxValidation = idx(N+1:end);
sequencesValidation = sequences(idxValidation);
labelsValidation = YTrain(idxValidation);
numFeatures = size(sequencesTrain{1},1);
numClasses = 8;%numel(categories((labelsTrain)));
% analyzeNetwork
layers = [
sequenceInputLayer(numFeatures,'Name','sequence')
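% fullyConnectedLayer acts on each frame separately, so this network is
% sequence-to-sequence and expects one label per frame, not per video
fullyConnectedLayer(128,'Name','fc1')
reluLayer('Name','relu1')
fullyConnectedLayer(numClasses,'Name','fc2')
softmaxLayer('Name','soft1')
classificationLayer('Name','classification')];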
miniBatchSize = 16;
numObservations = numel(sequencesTrain);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);
options = trainingOptions('adam', ...
'MiniBatchSize',miniBatchSize, ...
'InitialLearnRate',1e-4, ...
'GradientThreshold',2, ...
'Shuffle','every-epoch', ...
'ValidationData',{sequencesValidation,labelsValidation}, ...
'ValidationFrequency',numIterationsPerEpoch, ...
'Plots','training-progress', ...
'Verbose',false);
net_fc = trainNetwork(sequencesTrain,labelsTrain,layers,options);
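If frame order matters, another option is to keep the variable-length sequences and let a recurrent layer collapse each one into a single vector before the fully connected layers. With `'OutputMode','last'`, `lstmLayer` returns only its final output, so the network becomes sequence-to-label and accepts one categorical label per video; trainNetwork pads the sequences within each mini-batch. MATLAB's video classification example uses a similar pattern with `bilstmLayer`. The random data and the hidden size of 100 below are hypothetical; with the real data, `sequencesTrain` and the categorical `labelsTrain` would take their place.

```matlab
% Hypothetical stand-in data with the same shape as the posted
% sequences: 48 videos, 16 features per frame, 20 to 60 frames each.
numFeatures = 16;
sequences = arrayfun(@(k) rand(numFeatures, randi([20 60])), (1:48)', ...
    'UniformOutput', false);
labels = categorical(cellstr(repelem(('a':'h')', 6)));   % one label per video
numClasses = numel(categories(labels));

numHiddenUnits = 100;   % hypothetical size; tune for the real features

layers = [
    sequenceInputLayer(numFeatures, 'Name', 'sequence')
    lstmLayer(numHiddenUnits, 'OutputMode', 'last', 'Name', 'lstm')   % one vector per video
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(numClasses, 'Name', 'fc2')
    softmaxLayer('Name', 'soft1')
    classificationLayer('Name', 'classification')];

options = trainingOptions('adam', ...
    'MiniBatchSize', 16, ...
    'InitialLearnRate', 1e-4, ...
    'GradientThreshold', 2, ...
    'MaxEpochs', 5, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false);

net_fc = trainNetwork(sequences, labels, layers, options);
YPred = classify(net_fc, sequences, 'MiniBatchSize', 16);
```

Compared with mean pooling, the LSTM learns how to summarize the frames, at the cost of more parameters and slower training.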
