How can I improve the prediction accuracy of an ANN?
I am using an ANN to predict thawing rates by training the model on input features from my freeze-thaw experiments. Below is the neural network segment of my code. The dataset has 6 input features and a single output feature, the thawing rate. I have tried changing the number of hidden layers, the number of neurons per layer, and the proportion of training and testing data, but I cannot get good results (low MREs) no matter how many times we run the program. I am also unable to change the number of epochs: training only runs for 13 or 14 epochs, not 1000.
MRE % = (Target output - Predicted output) * 100 / (Target output)
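In code form, with hypothetical 1-by-N vectors target and predicted holding the target and predicted thawing rates, this is roughly:
relErrPct = (target - predicted) * 100 ./ target; % per-sample relative error, as in the formula above
MRE = mean(abs(relErrPct)); % mean over samples (taking the absolute value is assumed here)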
Below is the ANN part of the code I am using. We use the tangent sigmoid (tansig) transfer function, and we normalize the data beforehand by min-max scaling it between 0 and 1 based on the maximum and minimum values.
trainFcn = 'trainlm'; % Levenberg-Marquardt backpropagation, for updating the weights
out = [];
outNew = [];
% n > m > o
for n = 1:6 % neurons in hidden layer 1
    for m = 1:5 % neurons in hidden layer 2
        for o = 1:4 % neurons in hidden layer 3
            hiddenLayerSize = [n, m, o];
            net = fitnet(hiddenLayerSize, trainFcn);
            % These properties must be set after fitnet creates the network,
            % otherwise they are lost when the network is (re)created
            net.layers{1}.transferFcn = 'tansig'; % transfer function of the first hidden layer
            net.trainParam.epochs = 2000; % maximum number of training epochs
            net.input.processFcns = {}; % disable default input scaling (data already scaled)
            net.output.processFcns = {}; % disable default output scaling
            net.divideFcn = 'dividerand'; % Divide data randomly
            net.divideMode = 'sample'; % Divide up every sample
            net.divideParam.trainRatio = 90/100;
            % net.divideParam.valRatio = 5/100;
            % can we remove validation? - can try
            net.divideParam.testRatio = 10/100;
            net.performFcn = 'mse'; % Mean Squared Error
            net.plotFcns = {'plotperform','plottrainstate','ploterrhist', ...
                'plotregression','plotfit'};
            % Train the network
            [net,tr] = train(net, x_scaled, target_scaled);
            % Test the network
            y_scaled = net(x_scaled); % predictions for all of x_scaled
            x_train = x_scaled(:, tr.trainInd);
            target_train = target_scaled(:, tr.trainInd);
            x_test = x_scaled(:, tr.testInd);
            target_test = target_scaled(:, tr.testInd);
            y_train_pred = net(x_train); % training-set predictions
            y_test_pred = net(x_test); % testing-set predictions
            wts = getwb(net); % weights and biases as a single vector
            [b, IW, LW] = separatewb(net, wts); % separate biases, input weights, and layer weights
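The 0-1 scaling mentioned above is done outside this segment. A minimal sketch of that kind of min-max scaling, and of mapping predictions back to the original units, assuming mapminmax and raw 6-by-N inputs x and 1-by-N targets target (the actual scaling code is not shown):
[x_scaled, xSettings] = mapminmax(x, 0, 1); % scale each input feature to [0,1]
[target_scaled, tSettings] = mapminmax(target, 0, 1); % scale the thawing rate to [0,1]
% after training, map scaled predictions back to original thawing-rate units
y_test_orig = mapminmax('reverse', y_test_pred, tSettings);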
Below is an example of how we are using the code. This is just a small subset of the data; in reality we are using close to 40 rows. We have 6 input features and one target output feature. The predicted output is shown below as well. The last 4 rows are testing set predictions, and every other row belongs to the training set.
Answers (1)
Drew
26 September 2023
Some thoughts:
(1) The data set is very small ("close to 40 rows of data"). With such a small training set, try some small neural networks, some with only 1 or 2 hidden layers. Currently, it looks like all of your networks have three hidden layers. In general, the limited training data will lead to limited generalization capability.
(2) During training, the net.performFcn is set to mse, so mean squared error is being minimized during training. After training, you are looking at the MRE% (Mean Relative Error Percentage). Consider, for example, looking at MSE (Mean Squared Error) as well, since that was the metric used during training.
(3) Given that the data set is very small, use cross-validation for all experiments. In the limit, leave-one-out cross-validation could be done (a minimal sketch is shown after this list). It looks like your code may be getting a different single partition of the data for each candidate network topology. That will make the results across topologies not directly comparable. The small size of each test set also makes the results highly variable.
(4) To get another perspective easily, load the data into the Regression Learner app (part of the Statistics and Machine Learning Toolbox). In the session start dialog, specify cross-validation. To get leave-one-out cross-validation, increase the number of folds to equal the number of data points. Train various neural network models, including an optimizable neural network, which will optimize hyperparameters (including network size) on this data set. The app will show the cross-validation error for each model. A programmatic sketch of the same idea is shown below.
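Following up on points (2) and (3), here is a minimal leave-one-out cross-validation sketch. It assumes the x_scaled and target_scaled variables from the question and uses a small topology as suggested in point (1), and it reports the MSE that is also used during training:
N = size(x_scaled, 2); % number of samples (close to 40)
y_loo = zeros(1, N);
for k = 1:N
    trainIdx = setdiff(1:N, k); % hold out sample k
    net = fitnet([4 2], 'trainlm'); % small network, e.g. two hidden layers
    net.divideFcn = 'dividetrain'; % use all supplied samples for training
    net.trainParam.showWindow = false; % suppress the training GUI inside the loop
    net = train(net, x_scaled(:, trainIdx), target_scaled(:, trainIdx));
    y_loo(k) = net(x_scaled(:, k)); % predict the held-out sample
end
mseLOO = mean((target_scaled - y_loo).^2); % leave-one-out MSE, same metric as training
% For MRE %, map y_loo and target_scaled back to the original units first;
% relative error is not meaningful on 0-1 scaled values near zero
And for point (4), a programmatic counterpart to the Regression Learner workflow using fitrnet from the Statistics and Machine Learning Toolbox (observations are rows here; 'Leaveout' gives leave-one-out cross-validation):
mdl = fitrnet(x_scaled', target_scaled', 'LayerSizes', 5); % one small hidden layer
cvmdl = crossval(mdl, 'Leaveout', 'on');
lossLOO = kfoldLoss(cvmdl); % leave-one-out cross-validated MSE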