Cross-validation improvement

Views: 2 (last 30 days)
Qiang Wang
Qiang Wang on 22 Sep 2020
Edited: Qiang Wang on 5 Oct 2020
Hi, all.
I want to improve my cross-validation results, so I followed this suggestion from @Greg Heath on using multiple nets to improve cross-validation results:
a. For each i of i = 1:k, design multiple nets differing by the assignment of random initial weights. Discard those with poor performance and average the performance of the rest.
Is this code right?
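For reference, step (a) as quoted above might be sketched like this. It assumes the same x (features x samples), t (1 x samples), cvp, h, and j as in my code below; nTrials and the 1.25*min keep-threshold are arbitrary choices, not part of the original suggestion:

```matlab
% Sketch of step (a): per fold, train several nets that differ only in
% their random initial weights, discard the poor ones, average the rest.
nTrials = 10;                                   % nets per fold (arbitrary)
cvrmsePerFold = zeros(1, KFolds);
for i = 1:KFolds
    xtr = x(:, cvp.training(i));  ttr = t(:, cvp.training(i));
    xte = x(:, cvp.test(i));      tte = t(:, cvp.test(i));
    foldRmse = zeros(1, nTrials);
    for k = 1:nTrials
        net = fitnet([h, j], 'trainbr');
        net = configure(net, xtr, ttr);
        net = init(net);                        % new random initial weights each trial
        net.trainParam.showWindow = 0;
        net = train(net, xtr, ttr);
        yPred = net(xte);
        foldRmse(k) = sqrt(mean((yPred - tte).^2));
    end
    keep = foldRmse <= 1.25 * min(foldRmse);    % "discard those with poor performance"
    cvrmsePerFold(i) = mean(foldRmse(keep));    % "average the performance of the rest"
end
```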
clear all;
close all;
clc

%% input data
xp = predictors;                 % table of predictors
xt = table2array(predictors);
x = xt';                         % input  (features x samples)
t = response';                   % target (1 x samples)

trainFcn = 'trainbr';            % Bayesian regularization backpropagation
KFolds = 10;                     % K value
cvp = cvpartition(size(response, 1), 'KFold', KFolds);   % generate cv partition

counter = 0;
fid = fopen('sum-r.txt', 'w');   % open once, outside the loops, so results are not overwritten
for h = 5:20
for j = 5:20
hiddenLayerSize = [h, j];
net = fitnet(hiddenLayerSize, trainFcn);
net.layers{1}.transferFcn = 'tansig';    % or 'logsig'
net.layers{2}.transferFcn = 'tansig';
net.layers{3}.transferFcn = 'purelin';
net.divideFcn = 'divideind';
net.trainParam.showWindow = 0;
net.performFcn = 'mse';          % MSE

sum_rmse = 0;
for i = 1:KFolds
trInd = find(cvp.training(i));   % indices of training data
testInd = find(cvp.test(i));     % indices of held-out data
net.divideParam.trainInd = trInd;
net.divideParam.testInd = testInd;   % (was misspelled as tstInd)
%rng('default')
% Train on the full data set; 'divideind' selects the folds.
% Passing only x(:,trInd) here would make the divideind indices wrong.
nets{i} = train(net, x, t);
end

%% evaluate each fold's net on its own held-out fold
for i = 1:KFolds
neti = nets{i};
xtrain = x(:, cvp.training(i));  ytrain = t(:, cvp.training(i));
xtest = x(:, cvp.test(i));       ytest = t(:, cvp.test(i));
ypt = neti(xtrain);              % estimate targets with the trained network
yPred = neti(xtest);
Train_RMSE = sqrt(sum((ypt - ytrain).^2) / numel(ytrain));
vrmse = sqrt(sum((yPred - ytest).^2) / numel(ytest));
train_t(i) = Train_RMSE;
ttrr(i) = vrmse;
R = corrcoef(ytest, yPred);
Rv(i) = R(1,2);
trmse = mse(neti, ytest, yPred); % toolbox MSE for this fold
sum_rmse = sum_rmse + vrmse;
counter = counter + 1;
save(strcat(num2str(counter), '.mat'));   % saves the whole workspace each iteration
%plotregression(ytest, yPred, 'Test')
end
fprintf(fid, '%6.2f\n', sum_rmse/KFolds);

average_rmse(h,j) = sum_rmse/KFolds;
cvrmse(h,j) = mean(ttrr);
ttrmse(h,j) = mean(train_t);
R_value(h,j) = mean(Rv);
end
end
fclose(fid);

R_value = R_value';
ttrmse = ttrmse';
cvrmse = cvrmse';
save('test.txt', 'cvrmse', '-ascii')
save('train.txt', 'ttrmse', '-ascii')
save('r.txt', 'R_value', '-ascii')

Answers (0)
