Using R^2 results to optimize number of neurons in hidden layer

KAE, 4 September 2018 (latest comment: KAE, 26 September 2018)
I am trying to find the optimal number of neurons in a hidden layer following Greg Heath's method of looping over candidate numbers of neurons, with several trials per number of neurons. The resulting R-squared statistic is plotted below. I would like some help using the R-squared results to pick the optimal number of neurons. Is this reasoning right?
  • 23 neurons is a good choice, since all the trials exceed the desired threshold of R-squared > 0.995. To make a prediction, I could pick any of the 10 trial nets that were generated with 23 neurons.
  • 15 neurons is a bad choice because sometimes the threshold is not met
  • More than 23 neurons is a bad choice because the network will be slower to run
  • More than 33 neurons is a very bad choice due to overfitting (the code below limits the number of neurons to avoid overfitting).
  • To save time, I could stop my loop after, say, 28 neurons, since by then I would have had six consecutive neuron counts (23-28) in which all 10 trials resulted in R-squared values above my threshold.
Here is the code, based on Greg's examples, with comments added for learning.
%%Get data
[x, t] = engine_dataset; % Existing dataset
% Described here, https://www.mathworks.com/help/deeplearning/ug/choose-a-multilayer-neural-network-training-function.html#bss4gz0-26
[ I N ] = size(x); % Size of network inputs, nVar x nExamples
[ O N ] = size(t); % Size of network targets, nVar x nExamples
fractionTrain = 70/100; % How much of data will go to training (the three fractions must sum to 1; 70/15/15 is the fitnet default)
fractionValid = 15/100; % How much of data will go to validation
fractionTest = 15/100; % How much of data will go to testing
Ntrn = N - round(fractionTest*N + fractionValid*N); % Number of training examples
Ntrneq = Ntrn*O; % Number of training equations
% MSE for a naïve constant output model that always outputs average of target data
MSE00 = mean(var(t',1));
%%Find a good value for number of neurons H in hidden layer
Hmin = 0; % Lowest number of neurons H to consider (a linear model)
% To avoid overfitting with too many neurons, require that Ntrneq > Nw==> H <= Hub (upper bound)
Hub = -1+ceil( (Ntrneq-O) / ( I+O+1) ); % Upper bound for the strict condition Ntrneq > Nw
Hmax = floor(Hub/10); % Stay well below the upper bound by dividing by 10
dH = 1;
% Number of random initial weight trials for each candidate h
Ntrials = 10;
% Randomize data division and initial weights
% Choose a repeatable initial random number seed
% so that you can reproduce any individual design
rng(0);
j=0; % Loop counter for number of neurons in hidden layer
for h=Hmin:dH:Hmax % Number of neurons in hidden layer
j = j+1;
if h == 0 % Linear model
net = fitnet([]); % Fits a linear model
Nw = (I+1)*O; % Number of unknown weights
else % Nonlinear neural network model
net = fitnet(h); % Fits a network with h nodes
Nw = (I+1)*h + (h+1)*O; % Number of unknown weights
end
%%Divide up data into training, validation, testing sets
% Data presented to network during training. Network is adjusted according to its error.
net.divideParam.trainRatio = fractionTrain;
% Data used to measure network generalization, and to halt training when generalization stops improving
net.divideParam.valRatio = fractionValid;
% Data used as independent measure of network performance during and after training
net.divideParam.testRatio = fractionTest;
%%Loop through several trials for each candidate number of neurons
% to increase statistical reliability of results
for i=1:Ntrials
% Configure network inputs and outputs to best match input and target data
net = configure(net, x, t);
% Train network
[net, tr, y, e] = train(net,x,t); % Error e is target - output
% R-squared for normalized MSE. Fraction of target variance modeled by net
R2(i, j) = 1 - mse(e)/MSE00;
end
end
%%Plot R-squared
figure;
% Grid data in preparation for plotting
[nNeuronMat, trialMat] = meshgrid(Hmin:dH:Hmax, 1:Ntrials);
% Color of dot indicates the R-squared value
h1 = scatter(nNeuronMat(:), trialMat(:), 20, R2(:), 'filled');
hold on;
% Design is successful if it can account for 99.5% of mean target variance
iGood = R2 > 0.995;
% Circle the successful values
h2 = plot(nNeuronMat(iGood), trialMat(iGood), 'ko', 'markersize', 10);
xlabel('# Neurons'); ylabel('Trial #');
h3 = colorbar;
set(get(h3,'Title'),'String','R^2')
title({'Fraction of target variance that is modeled by net', ...
'Circled if it exceeds 99.5%'});
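As a sketch of the selection rules discussed in the bullets above, the snippet below works on an R2 matrix laid out like the one the loop builds (rows are trials, columns are candidate H values). The 4 x 5 matrix and the names Hcand, R2goal and k are made up for illustration; they are not the results in the plot. It compares the question's rule (smallest H where every trial passes), a looser reading of "min(H) subject to R2 > goal" (smallest H where at least one trial passes, keeping the best trial), and the early-stopping idea of quitting after k consecutive all-pass values of H.

```matlab
% Hypothetical R2 values: 4 trials (rows) x 5 candidate H values (columns).
% In the script above, use its R2 and Hmin:dH:Hmax instead.
Hcand = 0:4;
R2 = [0.900 0.991 0.996 0.997 0.998
      0.850 0.996 0.994 0.998 0.999
      0.910 0.993 0.997 0.996 0.997
      0.880 0.996 0.998 0.999 0.996];
R2goal = 0.995;                 % threshold used in the question

pass     = R2 > R2goal;         % logical, same size as R2
fracGood = mean(pass, 1);       % fraction of trials that pass, per H

% Rule in the question: smallest H for which every trial passes
Hall = Hcand(find(all(pass, 1), 1, 'first'));

% Looser rule: smallest H for which at least one trial passes,
% keeping the best trial at that H
jAny = find(any(pass, 1), 1, 'first');
Hany = Hcand(jAny);
[~, iBest] = max(R2(:, jAny));

% Early stopping: first H that completes k consecutive all-pass values
k = 2;
nConsec = 0;
Hstop = [];
allPass = all(pass, 1);
for jj = 1:numel(Hcand)
    if allPass(jj)
        nConsec = nConsec + 1;
    else
        nConsec = 0;
    end
    if nConsec >= k
        Hstop = Hcand(jj);      % the H loop in the script could break here
        break
    end
end
fprintf('Hall = %d, Hany = %d (trial %d), Hstop = %d\n', Hall, Hany, iBest, Hstop);
% prints "Hall = 3, Hany = 1 (trial 2), Hstop = 4"
```

With these made-up numbers the all-trials rule picks H = 3 while the any-trial rule picks H = 1, which shows that the choice of rule, not only the threshold, drives the answer.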
KAE, 25 September 2018 (edited 25 September 2018)
You say you typically run through a reasonable value, here H = 34, before seeking to minimize H. What do the 'reasonable value' results tell you? Perhaps a good R-squared threshold to use later in the H minimization, so that for your example we might require the smallest H with Rsq >= 0.9975?
Below, for other learners, is Greg's code with comments I added as I figured out what it was doing.
%%Load data
[x,t] = engine_dataset;
% Data description, https://www.mathworks.com/help/deeplearning/ug/choose-a-multilayer-neural-network-training-function.html#bss4gz0-26
help engine_dataset % Learn about dataset
[I N ] = size(x) % [2 1199] Size of network inputs, nVar x nExamples
% Inputs are engine speed and fuel rate
[O N ] = size(t) % [2 1199] Size of network targets, nVar x nExamples
% Outputs are torque and nitrous oxide emission levels
% Separate out inputs and outputs for later plotting
x1 = x(1,:); x2 = x(2,:);
t1 = t(1,:); t2 = t(2,:);
%%Mean square error for later normalization
vart1 = var(t',1) % 1e5*[3.0079 2.1709]
% MSE for a naïve constant output model that always outputs average of target data
MSEref = mean(vart1) % 2.5894e+05
%%Look for obvious relations between inputs and outputs
figure(1)
subplot(2,2,1), plot(x1,t1,'.') % Looks linear
title('Torque vs. Engine Speed'); xlabel('Input 1'); ylabel('Target 1');
subplot(2,2,2), plot(x1,t2,'.') % Looks like a cubic
title('Nitrous Oxide Emission vs. Engine Speed'); xlabel('Input 1'); ylabel('Target 2');
subplot(2,2,3), plot(x2,t1,'.') % On/off behavior, maybe cubic
title('Torque vs. Fuel Rate'); xlabel('Input 2'); ylabel('Target 1');
subplot(2,2,4), plot(x2,t2,'.') % On/off behavior
title('Nitrous Oxide Emission vs. Fuel Rate'); xlabel('Input 2'); ylabel('Target 2');
Ntrn = N-round(0.3*N) % 839 Number of training examples
Ntrneq = Ntrn*O % 1678 Number of training equations
% Nw = (I+1)*H+(H+1)*O
% Ntrneq >= Nw ==> H <= Hmax <= Hub
Hub = floor((Ntrneq-O)/(I+O+1)) % 335 Upper bound for Ntrneq >= Nw
H = round(Hub/10) % 34 Stay well below the upper bound by dividing by 10
net = fitnet(H); % Fit a network with H hidden neurons
[net tr y e ] = train(net,x,t); % Train network
NMSE = mse(e)/MSEref % 0.0025 normalized MSE
Rsq = 1-NMSE % 0.9975 R-squared. Fraction of target variance modeled by net
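The arithmetic behind Hub can be checked directly. With Nw = (I+1)*H + (H+1)*O = H*(I+O+1) + O, requiring Ntrneq >= Nw gives H <= (Ntrneq-O)/(I+O+1). The sketch below hard-codes the engine_dataset sizes reported above (I = 2, O = 2, N = 1199) so it runs without the toolbox, and confirms that Hub is the largest H whose weight count still fits; the strict -1+ceil(...) variant from the question's script gives the same value here.

```matlab
% engine_dataset sizes as reported above (hard-coded so no toolbox is needed)
I = 2; O = 2; N = 1199;
Ntrn   = N - round(0.3*N);                       % 839 training examples (70% split)
Ntrneq = Ntrn*O;                                 % 1678 training equations
NwFun  = @(H) (I+1)*H + (H+1)*O;                 % unknown weights of an I-H-O net
Hub       = floor((Ntrneq - O)/(I + O + 1));     % largest H with Nw <= Ntrneq
HubStrict = -1 + ceil((Ntrneq - O)/(I + O + 1)); % largest H with Nw <  Ntrneq
H = round(Hub/10);                               % stay well below the bound
fprintf('Hub = %d (strict %d), Nw(Hub) = %d, Nw(Hub+1) = %d, H = %d, Nw(H) = %d\n', ...
        Hub, HubStrict, NwFun(Hub), NwFun(Hub+1), H, NwFun(H));
% prints "Hub = 335 (strict 335), Nw(Hub) = 1677, Nw(Hub+1) = 1682, H = 34, Nw(H) = 172"
```

So H = 34 uses 172 weights against 1678 training equations, roughly a tenth of the available equations.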
Greg Heath, 26 September 2018
KAE: You say you typically run through a reasonable value, here H=34, before seeking to minimize H. What do the 'reasonable value' results tell you? Maybe a good choice of a R-square threshold to use later in the H minimization, so for your example we might require the smallest H with Rsq>= 0.9975?
GREG: NO.
My training goal for all of the MATLAB examples has always been the more realistic
NMSE = mse(t-y)/mean(var(t',1)) <= 0.01
or, equivalently,
Rsq = 1 - NMSE >= 0.99
That is not to say better values cannot be achieved, just that this is a more reasonable way to begin before trying to optimize the final result.
%Using the command
>> help nndatasets
% yields
simplefit_dataset - Simple fitting dataset.
abalone_dataset - Abalone shell rings dataset.
bodyfat_dataset - Body fat percentage dataset.
building_dataset - Building energy dataset.
chemical_dataset - Chemical sensor dataset.
cho_dataset - Cholesterol dataset.
engine_dataset - Engine behavior dataset.
vinyl_dataset - Vinyl bromide dataset.
% then
>> SIZE1 = size(simplefit_dataset)
SIZE2 = size(abalone_dataset)
SIZE3 = size(bodyfat_dataset)
SIZE4 = size(building_dataset)
SIZE5 = size(chemical_dataset)
SIZE6 = size(cho_dataset)
SIZE7 = size(engine_dataset)
SIZE8 = size(vinyl_dataset)
% yields
SIZE1 = 1 94
SIZE2 = 8 4177
SIZE3 = 13 252
SIZE4 = 14 4208
SIZE5 = 8 498
SIZE6 = 21 264
SIZE7 = 2 1199
SIZE8 = 16 68308
% Therefore, FOR OBVIOUS REASONS, most of my MATLAB tutorial-type investigations are restricted to N < 500 (i.e., simplefit, bodyfat, chemical and cho)
Hope this helps.
Greg
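Greg's goal NMSE = mse(t-y)/mean(var(t',1)) <= 0.01 can be computed by hand. The sketch below uses small made-up t and y matrices (not engine_dataset) and replaces the toolbox mse with an explicit mean of squared errors over all elements; that this matches mse for a plain numeric error matrix is an assumption worth checking against your release.

```matlab
% Made-up targets t (O x N) and outputs y, for illustration only
t = [1  2  3  4
     10 20 30 40];
y = t + [0.1 -0.1 0.1 -0.1
         1   -1   1   -1];
MSEref  = mean(var(t', 1));                 % MSE of the constant (mean) model: 63.125
NMSE    = mean((t(:) - y(:)).^2) / MSEref;  % squared error averaged over all elements, normalized
Rsq     = 1 - NMSE;
goalMet = NMSE <= 0.01;                     % equivalently Rsq >= 0.99
fprintf('NMSE = %.4f, Rsq = %.4f, goal met: %d\n', NMSE, Rsq, goalMet);
% prints "NMSE = 0.0080, Rsq = 0.9920, goal met: 1"
```

Because MSEref averages the row variances, the high-variance second row dominates both numerator and denominator, so the overall goal can be met while a low-variance output is fitted less well.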


Accepted Answer

Greg Heath, 23 September 2018
The answer is H = 12. H >= 13 is overfitting.
If that makes you queasy, average the outputs of the H = 12 net and the two H = 13 nets.
NOTE: Overfitting is not "THE" problem. "THE PROBLEM" is:
OVERTRAINING AN OVERFIT NET.
Hope this helps.
*Thank you for formally accepting my answer*
Greg
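Averaging, as suggested in the accepted answer, amounts to averaging the output matrices of the chosen nets. In the sketch below the three output vectors are made-up stand-ins for net12(x), net13a(x) and net13b(x) (hypothetical names for trained nets), so it runs without training anything. For squared error, the MSE of the averaged output can never exceed the average of the individual MSEs, although it is not guaranteed to beat the best single net.

```matlab
% Made-up outputs standing in for net12(x), net13a(x), net13b(x)
t    = [1   2   3   4  ];
y12  = [1.2 1.9 3.1 3.8];
y13a = [0.9 2.2 2.8 4.1];
y13b = [1.0 2.0 3.2 4.0];

MSEref = mean(var(t', 1));                 % constant-model MSE, 1.25 here
mseOf  = @(y) mean((t(:) - y(:)).^2);      % mean squared error over all elements
R2of   = @(y) 1 - mseOf(y)/MSEref;         % fraction of target variance modeled

yAvg = (y12 + y13a + y13b) / 3;            % simple ensemble average
fprintf('R2: %.4f %.4f %.4f, averaged: %.4f\n', ...
        R2of(y12), R2of(y13a), R2of(y13b), R2of(yAvg));
% prints "R2: 0.9800 0.9800 0.9920, averaged: 0.9991"
```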
Greg Heath, 26 September 2018 (edited 26 September 2018)
There are TWO MAIN POINTS:
1. H <= Hub PREVENTS OVERFITTING, i.e. having more unknown weights than training equations.
2. min(H) subject to R2 > 0.99 STABILIZES THE SOLUTION by REDUCING the uncertainty in weight estimates.
Hope this helps.
Greg
KAE, 26 September 2018
Thanks for the education!
