How to get R-square, mean absolute error, and mean squared error after training a neural network?

Hi all, I train a neural network with the following commands:
net.divideFcn = 'dividerand'
net.divideParam.trainRatio= 0.6;
net.divideParam.testRatio= 0.2;
net.divideParam.valRatio= 0.2;
[net,tr]=train(net,input,target);
I want to get the R-square, mean absolute error, and mean squared error for the training, test, and validation data.
Could you please advise?

Answers (1)

Paras Gupta, 18 July 2024, 9:07
Hi Ninlawat,
I understand that you want to compute different network performance metrics on the train, test, and validation data after training a neural network object in MATLAB.
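For reference, these are the standard definitions of the three metrics, written to match the R-square, MAE, and MSE lines in the code. Here $t_i$ are the targets, $y_i$ the network outputs, $\bar{t}$ the mean target, and $n$ the number of samples in the subset (training, validation, or test):

$$
R^2 = 1 - \frac{\sum_{i=1}^{n} (t_i - y_i)^2}{\sum_{i=1}^{n} (t_i - \bar{t})^2}, \qquad
\mathrm{MAE} = \frac{1}{n} \sum_{i=1}^{n} \lvert t_i - y_i \rvert, \qquad
\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (t_i - y_i)^2
$$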
The following code illustrates one way to do this:
% dummy data
input = rand(1, 100); % 1 feature, 100 samples
target = 2 * input + 1 + 0.1 * randn(1, 100); % Linear relation with some noise
% Define the feedforward network
net = feedforwardnet(10); % 10 hidden neurons
% Set up the data division
net.divideFcn = 'dividerand';
net.divideParam.trainRatio = 0.6;
net.divideParam.valRatio = 0.2;
net.divideParam.testRatio = 0.2;
% Train the network
[net, tr] = train(net, input, target);
% Get the network outputs
outputs = net(input);
% Separate the outputs for training, validation, and testing
trainOutputs = outputs(tr.trainInd);
valOutputs = outputs(tr.valInd);
testOutputs = outputs(tr.testInd);
% Separate the targets for training, validation, and testing
trainTargets = target(tr.trainInd);
valTargets = target(tr.valInd);
testTargets = target(tr.testInd);
% Calculate and display R-square, MAE, and MSE for each dataset
datasets = {'train', 'val', 'test'};
outputsList = {trainOutputs, valOutputs, testOutputs};
targetsList = {trainTargets, valTargets, testTargets};
for i = 1:numel(datasets)
    name = datasets{i};
    y = outputsList{i};   % network outputs for this subset (does not overwrite 'outputs')
    t = targetsList{i};   % targets for this subset
    % R-square: 1 - (residual sum of squares) / (total sum of squares)
    SS_res = sum((t - y).^2);
    SS_tot = sum((t - mean(t)).^2);
    R_square = 1 - SS_res / SS_tot;
    % Mean Absolute Error (MAE), computed directly from the residuals
    MAE = mean(abs(t - y));
    % Mean Squared Error (MSE), computed directly from the residuals
    MSE = mean((t - y).^2);
    % Display the results
    fprintf('%s R-square: %.4f\n', name, R_square);
    fprintf('%s MAE: %.4f\n', name, MAE);
    fprintf('%s MSE: %.4f\n', name, MSE);
    fprintf('\n');
end
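If you need these metrics in several places, the per-subset computation can be wrapped in a small function. This is a minimal sketch: the name `regressionMetrics` and its struct fields `R2`, `MAE`, and `MSE` are hypothetical, not part of any toolbox. Save it as `regressionMetrics.m` (or place it at the end of a script file).

```matlab
function m = regressionMetrics(t, y)
% regressionMetrics  Hypothetical helper (not a toolbox function).
% Returns R-square, MAE and MSE for targets t and outputs y,
% which must contain the same number of elements.
    e = t(:) - y(:);                       % residuals as a column vector
    ssRes = sum(e.^2);                     % residual sum of squares
    ssTot = sum((t(:) - mean(t(:))).^2);   % total sum of squares
    m.R2  = 1 - ssRes / ssTot;             % coefficient of determination
    m.MAE = mean(abs(e));                  % mean absolute error
    m.MSE = mean(e.^2);                    % mean squared error
end
```

Usage with the training record from the code above, recomputing the outputs so the full output vector is used:

y = net(input);
testMetrics = regressionMetrics(target(tr.testInd), y(tr.testInd));

As a quick check, t = [1 2 3 4] and y = [1 2 3 5] give R2 = 0.8, MAE = 0.25, and MSE = 0.25.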
You can refer to the MATLAB documentation for more information on the properties and functions used in the code above (for example, feedforwardnet, train, and the divideParam properties).
Hope this helps.

Categories

Find more on Sequence and Numeric Feature Data Workflows in Help Center and File Exchange.
