How to save a neural network to test on a new dataset?
Hi,
I am using the following code to train and test a neural network for 2-class classification. I need to save the trained network so I can test it on a different dataset. I tried the save net command, but it only saved the results and not the trained model.
Can anyone please help me with that?
load iris.mat; % Matlab also provides this dataset (load fisheriris.mat)
% Call features & labels
feat=f; label=l;
% Programmer: Jingwei Too
function NN=jNN(feat,label,kfold,Hiddens,Maxepochs)
% Layer: patternnet accepts a row vector with one size per hidden layer
net=patternnet(Hiddens);
% rng('default');
% Divide data into k-folds
fold=cvpartition(label,'kfold',kfold,'stratify',true);
% Pre
pred2=[]; ytest2=[]; Afold=zeros(kfold,1);
% Neural network start
for i=1:kfold
% Call index of training & testing sets
trainIdx=fold.training(i); testIdx=fold.test(i);
% Call training & testing features and labels
xtrain=feat(trainIdx,:); ytrain=label(trainIdx);
xtest=feat(testIdx,:); ytest=label(testIdx);
% Set Maximum epochs
net.trainParam.epochs= Maxepochs;
% to prevent early stopping
net.trainParam.max_fail = 500;
net.trainParam.min_grad = 0.000000000000001;
net.plotFcns = {'plotperform','plottrainstate','ploterrhist', ...
'plotconfusion', 'plotroc'};
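% Reset the weights so each fold is trained from scratch; otherwise every
% fold continues from the previous fold's weights, which have already
% been trained on the current fold's test samples
net=configure(net,xtrain',dummyvar(ytrain)'); net=init(net);
% Training model
net=train(net,xtrain',dummyvar(ytrain)');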
% Perform testing
pred=net(xtest');
% Confusion matrix
[~,con]=confusion(dummyvar(ytest)',pred);
% Get accuracy for each fold
Afold(i)=100*sum(diag(con))/sum(con(:));
% Store temporary result for each fold
pred2=[pred2(1:end,:),pred]; ytest2=[ytest2(1:end);ytest];
end
% Overall confusion matrix
save net
[~,confmat]=confusion(dummyvar(ytest2)',pred2); confmat=transpose(confmat);
% Average accuracy over k-folds
acc=mean(Afold);
% Store results
NN.fold=Afold; NN.acc=acc; NN.con=confmat;
fprintf('\n Classification Accuracy (NN): %g %%',acc);
% figure, plotperform(tr)
%figure, plottrainstate(tr)
% figure, ploterrhist(e)
% figure, plotconfusion(ytest2,pred)
% figure, plotroc(Labels,y)
end
Accepted Answer
Dheeraj Singh
20 Dec 2019
Instead of saving the model inside the function, return it as an output argument:
function [net,NN]=jNN(feat,label,kfold,Hiddens,Maxepochs)
and then call the save command from the workspace that receives it:
save net
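One detail that may explain the confusion in the question: in command syntax, `save net` is the same as `save('net.mat')`. The word `net` is the file name, not the variable to save, so MATLAB writes every variable in the current workspace to net.mat (inside jNN, that is the whole function workspace). The sketch below uses toy placeholder values instead of a trained network to show the difference from function syntax, which stores only the variables you name. The file name onlyNet.mat is just an illustrative choice.

```matlab
% Toy stand-ins: these are placeholders, not a real network or real results
NN.acc = 0;
net = struct('note','placeholder for a trained network');

% Command syntax: "save net" is save('net.mat'), which stores ALL variables
save net
disp(who('-file','net.mat'))       % lists both NN and net

% Function syntax with a variable name stores only that variable
save('onlyNet.mat','net')
disp(who('-file','onlyNet.mat'))   % lists only net
```

Either file restores the network with load; the function-syntax form simply keeps the file limited to the model.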
Also, if you are trying to use the iris dataset that ships with MATLAB, use iris.dat:
load iris.dat
feat=iris(:,1:4); label=iris(:,5);
I used the following code and it is working for me:
load iris.dat
feat=iris(:,1:4); label=iris(:,5);
% Programmer: Jingwei Too
[net,NN]=jNN(feat,label,5,[10 10 10],10000)
save net
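To then test the saved model on a different dataset, load the file in the new session and call the network on the new features. A minimal sketch, assuming Deep Learning Toolbox and Statistics and Machine Learning Toolbox (for dummyvar) are installed and that the new data has the same four features in the same column order. A small patternnet(10) trained on iris.dat stands in for the network jNN returns, and newFeat/newLabel are hypothetical placeholders (here just the first rows of iris) for your own data.

```matlab
% --- Session 1 (stand-in for the answer's script): train and save ---
load iris.dat                              % 4 feature columns + class label
net = patternnet(10);                      % stand-in for the network jNN returns
net.trainParam.showWindow = false;         % suppress the training window
net = train(net, iris(:,1:4)', dummyvar(iris(:,5))');
save('net.mat','net')                      % store the trained network object

% --- Session 2: restore the model and test it on another dataset ---
clear net
S = load('net.mat','net');                 % load returns a struct of variables
net = S.net;

% Placeholders: replace with your new dataset (N-by-4, same features)
newFeat  = iris(1:10,1:4);
newLabel = iris(1:10,5);

scores = net(newFeat');                    % network expects one sample per column
[~, predLabel] = max(scores, [], 1);       % class with the highest output
acc = 100*mean(predLabel(:) == newLabel);  % accuracy in percent
fprintf('Accuracy on new data: %g %%\n', acc);
```

Because dummyvar maps label k to column k, the row index returned by max is directly comparable with the original integer labels.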
Dheeraj Singh
23 Dec 2019
This approach looks fine.
But to get better results, you could try Feature Extraction by Means of Spatial Filtering (Common Spatial Patterns), as done in the following blog:
You can also refer to these File Exchange links for EEG data analysis:
Category: Biomedical Signal Processing