Conflict between SVM classifier and perfcurve()
I want to do binary classification with an SVM and evaluate its performance with an ROC curve computed by perfcurve(). I also want to find the optimal threshold that separates the two classes, and whether feature > threshold or feature < threshold is classified into class 1.
Mdl = fitcsvm(X,Y); % train a model
[Label,Scores] = predict(Mdl, X_new); % predict new data; Scores has one column per class
[X,Y,T,AUC,OPTROCPT] = perfcurve(labels,Scores(:, 2), posclass); % ROC
My data X is a 100-by-1 vector, i.e. 100 observations and each observation has only one feature.
The predict() function assigns a label based on the score. The score is peculiar compared with the scores from other classifiers. Generally, a classifier gives a single score per observation and applies a threshold to that score to classify. predict() does not. In my binary classification case, the score is an n-by-2 matrix. According to the MATLAB documentation, "For each observation in X, the predicted class label corresponds to the maximum score among all classes."
In my case, I found that
Scores(:, 2) = -Scores(:, 1)
So
Label = Scores(:, 2) > Scores(:, 1);
or
Label = Scores(:, 2) < Scores(:, 1);
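Which of the two comparisons holds can be checked directly. The following is a minimal sketch, assuming the default fitcsvm settings and the fisheriris data used in the reproduction script further down; the variable names (Xf, yf, isPos, and so on) are made up for illustration.

```matlab
% Sketch: check how predict() turns the two score columns into a label.
load fisheriris
inds = ~strcmp(species, 'setosa');   % keep versicolor and virginica
Xf = meas(inds, 1);                  % one feature only (sepal length)
yf = species(inds);

Mdl = fitcsvm(Xf, yf);               % default kernel for two classes is linear
[Label, Score] = predict(Mdl, Xf);

% The columns of Score follow the order of Mdl.ClassNames.
columnsMirrored = max(abs(Score(:, 1) + Score(:, 2))) < 1e-12;

% The label is ClassNames{2} where column 2 is the larger one, which for
% mirrored columns is the same as Score(:, 2) > 0.
isPos   = strcmp(Label, Mdl.ClassNames{2});
nonzero = Score(:, 2) ~= 0;          % leave out exact ties, if there are any
ruleHolds = isequal(isPos(nonzero), Score(nonzero, 2) > 0);

fprintf('columns mirrored: %d, label matches Score(:,2) > 0: %d\n', ...
    columnsMirrored, ruleHolds);
```

If both flags come out true, the first form (Scores(:, 2) > Scores(:, 1)) is the one that marks the second class in ClassNames.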
Here is the problem: I cannot get the optimal threshold that separates the two classes. You might say that 0 is the threshold, so that score > 0 is classified into one class and score < 0 into the other. That seems plausible, but it contradicts perfcurve().
To find the optimal threshold, I use perfcurve(). One of its outputs, OPTROCPT, is the "Optimal operating point of the ROC curve, returned as a 1-by-2 array with false positive rate (FPR) and true positive rate (TPR) values for the optimal ROC operating point."
So I calculate the optimal threshold as follows:
optIndex = find(X==OPTROCPT(1) & Y==OPTROCPT(2));
optThresh = T(optIndex);
This optThresh is different from the threshold of 0 that the SVM itself uses.
The code below reproduces the problem:
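The threshold lookup can be tried in isolation on a toy example. This is only a sketch: the labels and scores are invented, and it assumes the documented perfcurve convention that an observation counts as positive when its score is greater than or equal to the threshold.

```matlab
% Sketch: look up the threshold at perfcurve's optimal operating point
% and confirm that applying it reproduces the reported FPR and TPR.
labels = [1 1 0 1 0 1 0 0 1 0]';     % invented ground truth, 1 = positive
scores = [0.9 0.8 0.7 0.6 0.55 0.5 0.4 0.3 0.2 0.1]';   % invented scores

[fpr, tpr, thr, ~, opt] = perfcurve(labels, scores, 1);

% Take the first matching curve point so the index is always a scalar.
idx = find(fpr == opt(1) & tpr == opt(2), 1);
optThresh = thr(idx);

% Apply the threshold by hand: positive when score >= optThresh.
predPos  = scores >= optThresh;
tprCheck = sum(predPos & labels == 1) / sum(labels == 1);
fprCheck = sum(predPos & labels == 0) / sum(labels == 0);

fprintf('threshold %.3g gives FPR %.2f and TPR %.2f\n', ...
    optThresh, fprCheck, tprCheck);
```

Every entry of thr is one of the scores that occur in the data, so an exact value of 0 can only appear there if some observation happens to score exactly 0.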
% conflict between SVM classifier and perfcurve()
%% load data
load fisheriris
inds = ~strcmp(species,'setosa'); % use two species to do binary classification
X = meas(inds, 1); % use one feature only
y = species(inds);
%% get train data and validation data
train_inds = true([100, 1]);
train_inds(1:25) = false;
train_inds(51:75) = false;
val_inds = ~train_inds ;
X_train = X(train_inds);
y_train = y(train_inds);
X_val = X(val_inds);
y_val = y(val_inds);
%% SVM
SVMModel = fitcsvm(X_train,y_train);
[Label,Score] = predict(SVMModel, X_val); % predict on new data
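[fpr,tpr,T,AUC,OPTROCPT] = perfcurve(y_val,Score(:, 2), SVMModel.ClassNames{2}); % ROC; fpr/tpr avoid overwriting the data X
optIndex = find(fpr==OPTROCPT(1) & tpr==OPTROCPT(2), 1); % first curve point at the optimal operating point
optThresh = T(optIndex);
isequal(optThresh, 0) % compare with the SVM decision threshold of 0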
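For the second part of the question (which side of the feature threshold goes to which class), the score threshold of 0 can be translated back into feature units. This is a sketch, assuming the default linear kernel without standardization, where as far as I know the documented score for the second class is f(x) = (x/s)*Beta + Bias with s = KernelParameters.Scale; featThresh and rule are illustrative names.

```matlab
% Sketch: convert the SVM score threshold (0) into a feature threshold.
load fisheriris
inds = ~strcmp(species, 'setosa');
Xf = meas(inds, 1);                  % single feature
yf = species(inds);

Mdl = fitcsvm(Xf, yf);               % assumes defaults: linear kernel, no standardization

% Solve (x / scale) * Beta + Bias = 0 for x.
scale = Mdl.KernelParameters.Scale;
featThresh = -Mdl.Bias * scale / Mdl.Beta;

% The sign of Beta tells which side of the threshold is ClassNames{2}.
if Mdl.Beta > 0
    rule = sprintf('feature > %.4g -> %s', featThresh, Mdl.ClassNames{2});
else
    rule = sprintf('feature < %.4g -> %s', featThresh, Mdl.ClassNames{2});
end
disp(rule)

% Sanity check: the score at the feature threshold should be about 0.
[~, scoreAtThresh] = predict(Mdl, featThresh);
```

With a nonlinear kernel or more than one feature this closed form no longer applies, and the boundary has to be read from the scores instead.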
Answers (1)
Song Gao
27 May 2021
I think the problem here is that you used the validation dataset to determine the optimal point.
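One way to look at this is to compute the operating point once from the training scores and once from the validation scores, then apply both thresholds, and the SVM's own threshold of 0, to the validation set. This is a rough sketch under the same split as in the question; names such as thrFromTrain and accAt are made up, and no particular outcome is implied.

```matlab
% Sketch: operating point from training scores vs. validation scores.
load fisheriris
inds = ~strcmp(species, 'setosa');
Xf = meas(inds, 1);
yf = species(inds);

isVal = false(100, 1);
isVal([1:25, 51:75]) = true;         % same validation rows as in the question

Mdl = fitcsvm(Xf(~isVal), yf(~isVal));
posClass = Mdl.ClassNames{2};

[~, sTrain] = predict(Mdl, Xf(~isVal));
[~, sVal]   = predict(Mdl, Xf(isVal));

[fprT, tprT, thrT, ~, optT] = perfcurve(yf(~isVal), sTrain(:, 2), posClass);
[fprV, tprV, thrV, ~, optV] = perfcurve(yf(isVal),  sVal(:, 2),   posClass);

thrFromTrain = thrT(find(fprT == optT(1) & tprT == optT(2), 1));
thrFromVal   = thrV(find(fprV == optV(1) & tprV == optV(2), 1));

% Validation accuracy when score >= t is called positive.
isPosVal = strcmp(yf(isVal), posClass);
accAt = @(t) mean((sVal(:, 2) >= t) == isPosVal);

fprintf('threshold 0:      accuracy %.2f\n', accAt(0));
fprintf('from training:    %.3g, accuracy %.2f\n', thrFromTrain, accAt(thrFromTrain));
fprintf('from validation:  %.3g, accuracy %.2f\n', thrFromVal, accAt(thrFromVal));
```

The threshold picked from the validation scores is tuned to those same 50 observations, so it is not expected to coincide with the value 0 that the SVM fixed during training.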