How to save a neural network to test on a new dataset?

Joana on 16 Dec 2019
Commented: Dheeraj Singh on 23 Dec 2019
Hi,
I am using the following code to train and test a NN for 2-class classification. I need to save the trained network so I can test it on a different dataset. I tried the `save net` command, but it only saved the results, not the trained model.
Can anyone please help me with that?
load iris.mat; % Matlab also provides this dataset (load fisheriris.mat)
% Call features & labels
feat=f; label=l;
% Programmer: Jingwei Too
function NN=jNN(feat,label,kfold,Hiddens,Maxepochs)
    % Layer
    if length(Hiddens)==1
        h1=Hiddens(1); net=patternnet(h1);
    elseif length(Hiddens)==2
        h1=Hiddens(1); h2=Hiddens(2); net=patternnet([h1 h2]);
    elseif length(Hiddens)==3
        h1=Hiddens(1); h2=Hiddens(2); h3=Hiddens(3);
        net=patternnet([h1 h2 h3]);
    end
    % rng('default');
    % Divide data into k-folds
    fold=cvpartition(label,'kfold',kfold,'stratify',true);
    % Pre
    pred2=[]; ytest2=[]; Afold=zeros(kfold,1);
    % Neural network start
    for i=1:kfold
        % Call index of training & testing sets
        trainIdx=fold.training(i); testIdx=fold.test(i);
        % Call training & testing features and labels
        xtrain=feat(trainIdx,:); ytrain=label(trainIdx);
        xtest=feat(testIdx,:); ytest=label(testIdx);
        % Set Maximum epochs
        net.trainParam.epochs= Maxepochs;
        % to prevent early stopping
        net.trainParam.max_fail = 500;
        net.trainParam.min_grad = 0.000000000000001;
        net.plotFcns = {'plotperform','plottrainstate','ploterrhist', ...
            'plotconfusion', 'plotroc'};
        % Training model
        net=train(net,xtrain',dummyvar(ytrain)');
        % Perform testing
        pred=net(xtest');
        % Confusion matrix
        [~,con]=confusion(dummyvar(ytest)',pred);
        % Get accuracy for each fold
        Afold(i)=100*sum(diag(con))/sum(con(:));
        % Store temporary result for each fold
        pred2=[pred2,pred]; ytest2=[ytest2;ytest];
    end
    % Overall confusion matrix
    save net
    [~,confmat]=confusion(dummyvar(ytest2)',pred2); confmat=transpose(confmat);
    % Average accuracy over k-folds
    acc=mean(Afold);
    % Store results
    NN.fold=Afold; NN.acc=acc; NN.con=confmat;
    fprintf('\n Classification Accuracy (NN): %g %%',acc);
    % figure, plotperform(tr)
    % figure, plottrainstate(tr)
    % figure, ploterrhist(e)
    % figure, plotconfusion(ytest2,pred)
    % figure, plotroc(Labels,y)
end

Accepted Answer

Dheeraj Singh on 20 Dec 2019
`save` generally saves all the variables present in the current workspace at that point in time.
So, instead of saving the model inside the function, you can return the model from the function:
function [net,NN]=jNN(feat,label,kfold,Hiddens,Maxepochs)
and then use the save command in the calling workspace:
save net
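A minimal sketch of what the two forms of `save` write, using a throwaway file name `demo.mat` and placeholder variables (both hypothetical, chosen so an existing net.mat is not overwritten). Command syntax with only a file name stores every variable in the current workspace; listing a variable name stores just that variable.

```matlab
% Sketch: command syntax vs. function syntax for save.
% demo.mat is a hypothetical throwaway file; net and other are placeholders,
% not a trained network.
net   = 'placeholder for the trained network';
other = magic(3);

save demo                         % same as save('demo'): ALL workspace variables -> demo.mat
info = whos('-file', 'demo.mat');
disp({info.name})                 % includes both 'net' and 'other'

save('demo.mat', 'net')           % only the variable net (overwrites demo.mat)
info = whos('-file', 'demo.mat');
disp({info.name})                 % only 'net'
```

So `save net` run after `[net,NN]=jNN(...)` does store the network, together with everything else in that workspace; `save('net.mat','net')` keeps the file to the network alone.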
Also, if you are trying to use the iris dataset that comes with MATLAB, use iris.dat:
load iris.dat
feat=iris(:,1:4); label=iris(:,5);
I used the following code and it works for me:
load iris.dat
feat=iris(:,1:4); label=iris(:,5);
% Programmer: Jingwei Too
[net,NN]=jNN(feat,label,5,[10 10 10],10000)
save net
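A minimal sketch of reusing the saved network on another dataset in a later session. It assumes net.mat contains the variable `net` returned by the modified jNN, that the new data has the same four feature columns (one observation per row), and that labels are the integers 1..K used with `dummyvar` during training. `newFeat` and `newLabel` are hypothetical names, and iris.dat only stands in for genuinely new data. The Deep Learning Toolbox is needed because the loaded object is a patternnet network.

```matlab
% Sketch: test a previously saved network on another dataset.
% Assumptions: net.mat holds the variable net returned by jNN; labels are 1..K.
% newFeat/newLabel are hypothetical names; iris.dat is only a stand-in.
S = load('net.mat', 'net');           % load just the network variable
net = S.net;

data = load('iris.dat');              % stand-in "new" dataset (ASCII matrix)
newFeat  = data(:, 1:4);              % one observation per row
newLabel = data(:, 5);

scores = net(newFeat');               % the network expects features x samples
[~, predLabel] = max(scores, [], 1);  % row of the largest output = class index
predLabel = predLabel';               % column vector, like newLabel

acc = 100 * mean(predLabel == newLabel);
fprintf('Accuracy on new data: %g %%\n', acc);
```

Because the training targets came from `dummyvar`, output row k corresponds to label k, so the row index of the largest score is the predicted label. If the labels were not consecutive integers starting at 1, a mapping back to the original label values would be needed.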
2 Comments
Joana on 20 Dec 2019
Hi Dheeraj,
Thanks for your answer.
I have one more question about preparing the input data. I have EEG data for 2 classes, recorded at 1200 Hz with 32 EEG channels. I extracted 1 second of data per trial, 100 trials per class, so the data has the format channels x samples x trials = 32x1200x200.
I tried the above program after converting the data to 2D as 38,400x200.
So the input layer has 38,400 neurons, followed by 3 hidden layers with [10 10 10] neurons and one output neuron.
Is this the right way to do it, or should I try something else? (It gives acceptable results, but I am unsure whether the input layer can have this many neurons.)
I tried using a cell array, but it gives the error 'Data distribution doesn't have equal number of time steps'.
I would highly appreciate it if you could guide me on this one. :)
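A minimal sketch of the flattening described above, with random numbers standing in for the recordings and the assumption (not stated in the comment) that trials 1 to 100 are class 1 and 101 to 200 are class 2. jNN treats rows as observations (`feat(trainIdx,:)`), so the 38,400 x 200 matrix has to be transposed to 200 x 38,400 before it is passed in.

```matlab
% Sketch: flatten 3-D EEG epochs into one feature row per trial.
% Assumptions: eeg is channels x samples x trials (32 x 1200 x 200);
% trials 1-100 are class 1, 101-200 class 2. Random data is a stand-in.
nCh = 32; nSamp = 1200; nTrials = 200;
eeg   = randn(nCh, nSamp, nTrials);
label = [ones(100, 1); 2 * ones(100, 1)];

% reshape is column-major: within each trial, channels vary fastest,
% i.e. all 32 channels at sample 1, then all 32 at sample 2, and so on.
X = reshape(eeg, nCh * nSamp, nTrials);   % 38400 x 200, one column per trial

% jNN indexes rows as observations, so transpose to one row per trial.
feat = X';                                % 200 x 38400
```

With hidden layers [10 10 10], the first hidden layer alone then has 38,400 x 10 weights plus 10 biases, i.e. 384,010 parameters, learned from 200 trials. That imbalance is one reason a more compact feature set, such as the spatial-filter features suggested in the reply, may generalize better.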
Dheeraj Singh on 23 Dec 2019
This approach looks fine.
But to get better results, you may try feature extraction by means of spatial filtering (Common Spatial Patterns), as done in the following blog:
You can also refer to these File Exchange links for EEG data analysis:
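A minimal sketch of two-class Common Spatial Patterns feature extraction in base MATLAB (R2016b or later, for implicit expansion). This is the textbook CSP recipe, not necessarily the one in the blog mentioned above: average trace-normalized spatial covariances per class, solve the generalized eigenproblem C1 w = λ(C1 + C2) w, keep the m filters at each end of the sorted eigenvalues, and use the log-variance of the filtered signals as features. m = 3 is a hypothetical choice, and random data stands in for the recordings.

```matlab
% Sketch: two-class Common Spatial Patterns (CSP) features in base MATLAB.
% Assumptions: eeg is channels x samples x trials and label is 1 or 2 per
% trial; random data stands in for the real 32 x 1200 x 200 recordings.
eeg   = randn(32, 1200, 200);
label = [ones(100, 1); 2 * ones(100, 1)];
m = 3;                                  % filters kept from each end (hypothetical)

% 1) Class-average, trace-normalized spatial covariance matrices
nCh = size(eeg, 1);
C = zeros(nCh, nCh, 2);
for k = 1:2
    idx = find(label == k);
    for t = idx(:)'
        X = eeg(:, :, t);
        X = X - mean(X, 2);             % remove each channel's mean
        S = X * X';
        C(:, :, k) = C(:, :, k) + S / trace(S);
    end
    C(:, :, k) = C(:, :, k) / numel(idx);
end

% 2) Generalized eigenproblem C1*w = lambda*(C1 + C2)*w
A = C(:, :, 1);
B = C(:, :, 1) + C(:, :, 2);
A = (A + A') / 2;                       % guard against round-off asymmetry
B = (B + B') / 2;
[V, D] = eig(A, B);
[~, order] = sort(diag(D), 'descend');
V = V(:, order);
W = V(:, [1:m, end-m+1:end]);           % most discriminative filters: 32 x 2m

% 3) Log-variance of the spatially filtered trials as features
nTrials = size(eeg, 3);
feat = zeros(nTrials, 2 * m);           % 200 x 6, one row per trial
for t = 1:nTrials
    Z = W' * eeg(:, :, t);              % spatially filtered trial
    v = var(Z, 0, 2);                   % variance of each filtered signal
    feat(t, :) = log(v / sum(v))';      % normalized log-variance
end
```

The resulting 200 x 6 `feat` matrix can be passed to jNN in place of the 38,400-column one. For an unbiased cross-validation estimate, W should be computed from each fold's training trials only and then applied to that fold's test trials; the sketch fits it on all trials for brevity.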


Category: Biomedical Signal Processing
