loss, the classification error

Hyeongcheol Lee on 20 Nov 2019
Edited: Drew on 2 Oct 2024
I'm using the "loss" function to calculate a classification error.
Below is a confusion matrix with one misclassification.
[Image: 캡처4.PNG, confusion matrix]
In the above case, I think the loss should be 1/7*100 = 14.3%.
But the "loss" function returns 15.9%.
It seems the "loss" function uses some special logic to calculate the loss.
So I'd like to know what that logic is and, if possible, how to get a loss of 14.3% by changing an option of the "loss" function.

Answers (1)

Drew on 2 Oct 2024
Edited: Drew on 2 Oct 2024
As indicated at https://www.mathworks.com/help/stats/classreg.learning.classif.compactclassificationensemble.loss.html#bst1mt4-4, "The software normalizes the observation weights so that they sum to the corresponding prior class probability stored in the Prior property." So the unexpected result comes from the "loss" function using the prior class probabilities stored in the model.
Since you are looking for the unweighted classification error, with no dependence on the prior class probabilities stored in the model, you can get it from the "loss" function by specifying a custom loss function handle in the "LossFun" name-value argument. The custom loss function "function loss = unweighted_classiferror_LossFun(C, s, W, Cost)" given in the following code can be used with the name-value argument "LossFun=@unweighted_classiferror_LossFun".
The following code sets up a simple classifier that reproduces exactly the situation in your question, and then uses the "loss" function to compute the unweighted classification error, giving the desired 14.3% (that is, 1/7 or 0.1429).
% For reproducibility
rng(1);
% Two classes of data
% class 1, N(0,1) with 523 observations
% class 2, N(3,1) with 477 observations
Xtrain=[randn(523,1);(randn(477,1)+3)];
% Vector of target class labels, 1 and 2
Ytrain=[ones(523,1);2*ones(477,1)];
% Make a simple tree model with one split
mdl=fitctree(Xtrain,Ytrain,MaxNumSplits=1);
view(mdl)
Decision tree for classification
1  if x1<1.80105 then node 2 elseif x1>=1.80105 then node 3 else 1
2  class = 1
3  class = 2
model_priors_based_on_training_data = mdl.Prior
model_priors_based_on_training_data = 1×2
0.5230 0.4770
model_class_names = mdl.ClassNames
model_class_names = 2×1
1
2
% Create some fake test data that has the confusion matrix in the question
Xtest = [0 0 0 0 3 3 1.7]';
Ytest = [1 1 1 1 2 2 2]';
cm = confusionmat(Ytest,predict(mdl,Xtest))
cm = 2×2
4 0
1 2
% classiferror loss weighted by class priors from training data
weighted_classiferror_loss = loss(mdl,Xtest,Ytest)
weighted_classiferror_loss = 0.1590
% unweighted classiferror loss using a simple custom loss function
unweighted_classiferror_loss = loss(mdl,Xtest,Ytest,LossFun=@unweighted_classiferror_LossFun)
unweighted_classiferror_loss = 0.1429
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Calculate unweighted loss with custom function-handle loss function
function loss = unweighted_classiferror_LossFun(C, s, W, Cost)
% C is the N-by-K logical matrix for N observations and K classes
% indicating the class to which the corresponding observation belongs.
% The column order corresponds to the class order in mdl.ClassNames.
% s is the N-by-K matrix of predicted scores
% W is the N-by-1 vector of observation weights
% Cost is the K-by-K numeric matrix of misclassification costs
% This particular implementation ignores inputs "W" and "Cost"
% Find the class with the highest score for each observation
[~, predictedClass] = max(s, [], 2);
% Find the true class for each observation
[~, trueClass] = max(C, [], 2);
% Calculate the number of misclassified instances
misclassified = trueClass ~= predictedClass;
% Calculate the unweighted classification error
loss = sum(misclassified) / length(misclassified);
end
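
Both printed values can also be checked by hand from the confusion matrix and mdl.Prior shown above. The following is a minimal sketch, not the internal implementation of "loss": it assumes the default 0/1 misclassification cost and no user-supplied observation weights, and the variable names (perClassError, weightedLoss, unweightedLoss) are made up for illustration.

```matlab
% Hand check of the two loss values, using only numbers printed above.
prior = [0.5230 0.4770];   % mdl.Prior, ordered as mdl.ClassNames = [1; 2]
cm    = [4 0; 1 2];        % confusionmat output: rows = true class, columns = predicted

nPerClass     = sum(cm, 2);                 % test observations per true class: [4; 3]
perClassError = 1 - diag(cm) ./ nPerClass;  % error rate within each true class: [0; 1/3]

% Assumption: each class's observation weights sum to its prior, so the
% weighted error is the prior-weighted sum of the per-class error rates.
weightedLoss = prior * perClassError;       % 0.4770 * (1/3) = 0.1590

% Every observation counts equally: misclassified / total.
unweightedLoss = (sum(cm(:)) - trace(cm)) / sum(cm(:));   % 1/7 = 0.1429
```

The single misclassified observation belongs to class 2, which has three test observations sharing a total weight of 0.4770, so it contributes 0.4770/3 = 0.1590 instead of 1/7. Note also that the Name=Value call syntax used in this answer (for example MaxNumSplits=1) needs R2021a or later; on an older release such as the R2019b listed for this question, pass the same arguments as 'LossFun',@unweighted_classiferror_LossFun.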

Release

R2019b
