Using R^2 results to optimize number of neurons in hidden layer
I am trying to find the optimal number of neurons in a hidden layer, following Greg Heath's method of looping over candidate numbers of neurons with several trials per candidate. The resulting R-squared statistic is plotted below. I would like some help in using the R-squared results to pick the optimal number of neurons. Is this reasoning right:
- 23 neurons is a good choice, since all the trials exceed the desired threshold of R-squared > 0.995. To make a prediction, I could pick any of the 10 trial nets that were generated with 23 neurons.
- 15 neurons is a bad choice because sometimes the threshold is not met
- More than 23 neurons is a bad choice because the network will be slower to run
- More than 33 neurons is a very bad choice due to overfitting (the code below limits the number of neurons to avoid overfitting).
- To save time, I could stop my loop after, say, 28 neurons, since by then I would have had six consecutive candidates (23-28 neurons) in which all 10 trials resulted in R-squared values above my threshold (see the selection sketch after the plotting code below).
Here is the code, based on Greg's examples with comments for learning.
%%Get data
[x, t] = engine_dataset; % Existing dataset
% Described here, https://www.mathworks.com/help/deeplearning/ug/choose-a-multilayer-neural-network-training-function.html#bss4gz0-26
[ I N ] = size(x); % Size of network inputs, nVar x nExamples
[ O N ] = size(t); % Size of network targets, nVar x nExamples
fractionTrain = 70/100; % How much of data will go to training (the three fractions must sum to 1)
fractionValid = 15/100; % How much of data will go to validation
fractionTest = 15/100; % How much of data will go to testing
Ntrn = N - round(fractionTest*N + fractionValid*N); % Number of training examples
Ntrneq = Ntrn*O; % Number of training equations
% MSE for a naïve constant output model that always outputs average of target data
MSE00 = mean(var(t',1));
%%Find a good value for number of neurons H in hidden layer
Hmin = 0; % Lowest number of neurons H to consider (a linear model)
% To avoid overfitting with too many neurons, require that Ntrneq > Nw ==> H <= Hub (upper bound)
Hub = -1+ceil( (Ntrneq-O) / ( I+O+1) ); % Upper bound for Ntrneq >= Nw
Hmax = floor(Hub/10); % Stay well below the upper bound by dividing by 10
dH = 1;
% Number of random initial weight trials for each candidate h
Ntrials = 10;
% Randomize data division and initial weights
% Choose a repeatable initial random number seed
% so that you can reproduce any individual design
rng(0);
R2 = zeros(Ntrials, numel(Hmin:dH:Hmax)); % Preallocate results, one row per trial, one column per candidate H
j = 0; % Loop counter for number of neurons in hidden layer
for h = Hmin:dH:Hmax % Number of neurons in hidden layer
    j = j+1;
    if h == 0 % Linear model
        net = fitnet([]); % Fits a linear model
        Nw = (I+1)*O; % Number of unknown weights
    else % Nonlinear neural network model
        net = fitnet(h); % Fits a network with h nodes
        Nw = (I+1)*h + (h+1)*O; % Number of unknown weights
    end
    %%Divide up data into training, validation, testing sets
    % Data presented to network during training. Network is adjusted according to its error.
    net.divideParam.trainRatio = fractionTrain;
    % Data used to measure network generalization, and to halt training when generalization stops improving
    net.divideParam.valRatio = fractionValid;
    % Data used as independent measure of network performance during and after training
    net.divideParam.testRatio = fractionTest;
    %%Loop through several trials for each candidate number of neurons
    % to increase statistical reliability of results
    for i = 1:Ntrials
        % Configure network inputs and outputs to best match input and target data
        net = configure(net, x, t);
        % Train network
        [net, tr, y, e] = train(net, x, t); % Error e is target - output
        % R-squared for normalized MSE. Fraction of target variance modeled by net
        R2(i, j) = 1 - mse(e)/MSE00;
    end
end
%%Plot R-squared
figure;
% Grid data in preparation for plotting
[nNeuronMat, trialMat] = meshgrid(Hmin:dH:Hmax, 1:Ntrials);
% Color of dot indicates the R-squared value
h1 = scatter(nNeuronMat(:), trialMat(:), 20, R2(:), 'filled');
hold on;
% Design is successful if it can account for 99.5% of mean target variance
iGood = R2 > 0.995;
% Circle the successful values
h2 = plot(nNeuronMat(iGood), trialMat(iGood), 'ko', 'markersize', 10);
xlabel('# Neurons'); ylabel('Trial #');
h3 = colorbar;
set(get(h3,'Title'),'String','R^2')
title({'Fraction of target variance that is modeled by net', ...
'Circled if it exceeds 99.5%'});
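Given the R2 matrix produced by the loop above (rows are trials, columns are candidate H values), here is a minimal sketch of picking the smallest H for which every trial clears the threshold; Hcand, allGood, and Hbest are hypothetical names, not from the original code:
% Smallest candidate H whose trials all exceed the R-squared threshold
Hcand = Hmin:dH:Hmax;            % candidate H values (one per column of R2)
allGood = all(R2 > 0.995, 1);    % true where all Ntrials trials passed
Hbest = Hcand(find(allGood, 1)); % first, i.e. smallest, fully successful H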
3 Comments
Greg Heath
26 Sep 2018
KAE: You say you typically run through a reasonable value, here H = 34, before seeking to minimize H. What do the 'reasonable value' results tell you? Maybe a good choice of an R-squared threshold to use later in the H minimization, so for your example we might require the smallest H with Rsq >= 0.9975?
GREG: NO.
My training goal for all of the MATLAB examples has always been the more realistic
NMSE = mse(t-y)/mean(var(t',1)) <= 0.01
or, equivalently,
Rsq = 1 - NMSE >= 0.99
That is not to say better values cannot be achieved, just that this is a more reasonable way to begin before trying to optimize the final result.
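A minimal sketch of wiring this goal into training, assuming MSE00 = mean(var(t',1)) as in the question's code; trainlm, the default training function for fitnet, stops when the training MSE falls below trainParam.goal (H = 12 here is just an example size):
% Stop training once NMSE <= 0.01, i.e. once mse(t-y) <= 0.01*MSE00
net = fitnet(12);
net.trainParam.goal = 0.01*MSE00; % MSE00 is the naive constant-output-model MSE
[net, tr] = train(net, x, t);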
%Using the command
>> help nndatasets
% yields
simplefit_dataset - Simple fitting dataset.
abalone_dataset - Abalone shell rings dataset.
bodyfat_dataset - Body fat percentage dataset.
building_dataset - Building energy dataset.
chemical_dataset - Chemical sensor dataset.
cho_dataset - Cholesterol dataset.
engine_dataset - Engine behavior dataset.
vinyl_dataset - Vinyl bromide dataset.
% then
>> SIZE1 = size(simplefit_dataset)
SIZE2 = size(abalone_dataset)
SIZE3 = size(bodyfat_dataset)
SIZE4 = size(building_dataset)
SIZE5 = size(chemical_dataset)
SIZE6 = size(cho_dataset)
SIZE7 = size(engine_dataset)
SIZE8 = size(vinyl_dataset)
% yields
SIZE1 = 1 94
SIZE2 = 8 4177
SIZE3 = 13 252
SIZE4 = 14 4208
SIZE5 = 8 498
SIZE6 = 21 264
SIZE7 = 2 1199
SIZE8 = 16 68308
% Therefore, FOR OBVIOUS REASONS, most of my MATLAB tutorial-type investigations are restricted to N < 500 (i.e., simplefit, bodyfat, chemical and cho)
Hope this helps.
Greg
Accepted Answer
Greg Heath
23 Sep 2018
The answer is H = 12. H >= 13 is overfitting.
If that makes you queasy, average the output of the H = 12 net with the outputs of the two H = 13 nets.
NOTE: Overfitting is not "THE" problem. "THE PROBLEM" is:
OVERTRAINING AN OVERFIT NET.
Hope this helps.
*Thank you for formally accepting my answer*
Greg
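A minimal sketch of the averaging suggestion, assuming the three trained nets were saved from the trials; net12, net13a, and net13b are hypothetical names (trained networks are callable as functions in MATLAB):
% Ensemble-average the predictions of the H = 12 net and the two H = 13 nets
y12  = net12(x);  % output of the H = 12 net
y13a = net13a(x); % outputs of the two H = 13 nets
y13b = net13b(x);
yAvg = (y12 + y13a + y13b)/3; % averaged prediction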
3 Comments
Greg Heath
26 Sep 2018
Edited: Greg Heath on 26 Sep 2018
There are TWO MAIN POINTS:
1. H <= Hub PREVENTS OVERFITTING, i.e., having more unknown weights than training equations (see the sketch below).
2. min(H) subject to R2 > 0.99 STABILIZES THE SOLUTION by REDUCING the uncertainty in the weight estimates.
Hope this helps.
Greg
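For reference, point 1 is just the requirement Nw < Ntrneq solved for H. A minimal sketch of the algebra and a check, using the variable names from the question's code:
% Nw = (I+1)*h + (h+1)*O = h*(I+O+1) + O; requiring Nw < Ntrneq gives
% h < (Ntrneq - O)/(I + O + 1), hence the upper bound used in the code:
Hub = -1 + ceil((Ntrneq - O)/(I + O + 1));
NwAtHub = (I+1)*Hub + (Hub+1)*O; % number of unknown weights at H = Hub
assert(NwAtHub < Ntrneq)         % fewer unknowns than training equations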
More Answers (0)