Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

베이즈 최적화를 사용하여 SVM 분류기 피팅 최적화하기

이 예제에서는 fitcsvm 함수와 OptimizeHyperparameters 이름-값 쌍을 사용하여 SVM 분류를 최적화하는 방법을 보여줍니다. 이 분류는 가우스 혼합 모델의 점 위치를 사용한 것입니다. The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009)의 17페이지에 이 모델에 대한 설명이 나와 있습니다. 이 모델은 평균 (1,0)과 단위 분산의 2차원 독립 정규분포로 분산되고 "녹색" 클래스에 속하는 10개 기준점을 생성하는 것으로 시작합니다. 또한, 평균 (0,1)과 단위 분산의 2차원 독립 정규분포로 분산되고 "빨간색" 클래스에 속하는 10개 기준점도 생성합니다. 다음과 같이 각 클래스(녹색과 빨간색)에 대해 임의의 100개 점을 생성합니다.

  1. 해당 색상별로 임의로 균일하게 분포하는 기준점 m을 선택합니다.

  2. 평균이 m이고 분산이 I/5인(여기서 I는 2×2 단위 행렬임) 2차원 정규분포를 띠는 독립적인 임의의 점을 생성합니다. 이 예제에서는 분산 I/50을 사용하여 최적화의 이점을 더 확실하게 보여줍니다.

점과 분류기 생성하기

각 클래스에 대해 10개 기준점을 생성합니다.

rng default % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

기준점을 표시합니다.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

일부 빨간색 기준점이 녹색 기준점에 가깝기 때문에 위치만 기준으로 하여 데이터 점을 분류하는 것이 어려울 수 있습니다.

각 클래스에 대해 100개 데이터 점을 생성합니다.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

데이터 점을 표시합니다.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

분류를 위한 데이터 준비하기

데이터를 한 행렬에 저장하고, 각 점의 클래스에 레이블을 지정하는 벡터 grp를 생성합니다.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

교차 검증 준비하기

교차 검증에 사용할 분할을 설정합니다. 이 단계에서는 최적화가 각 단계에 사용하는 훈련 세트와 검정 세트를 정합니다.

c = cvpartition(200,'KFold',10);

피팅 최적화하기

양호한 피팅, 즉 교차 검증 손실이 적은 피팅을 구하려면 베이즈 최적화를 사용하도록 옵션을 설정하십시오. 모든 최적화에 동일한 교차 검증 분할 c를 사용합니다.

재현이 가능하도록 'expected-improvement-plus' 획득 함수를 사용합니다.

opts = struct('Optimizer','bayesopt','ShowPlots',true,'CVPartition',c,...
    'AcquisitionFunctionName','expected-improvement-plus');
svmmod = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |       0.345 |     0.24139 |       0.345 |       0.345 |      0.00474 |       306.44 |
|    2 | Best   |       0.115 |     0.16228 |       0.115 |     0.12678 |       430.31 |       1.4864 |
|    3 | Accept |        0.52 |     0.13937 |       0.115 |      0.1152 |     0.028415 |     0.014369 |
|    4 | Accept |        0.61 |      0.2731 |       0.115 |     0.11504 |       133.94 |    0.0031427 |
|    5 | Accept |        0.34 |     0.35782 |       0.115 |     0.11504 |     0.010993 |       5.7742 |
|    6 | Best   |       0.085 |     0.12094 |       0.085 |    0.085039 |       885.63 |      0.68403 |
|    7 | Accept |       0.105 |     0.11976 |       0.085 |    0.085428 |       0.3057 |      0.58118 |
|    8 | Accept |        0.21 |     0.13908 |       0.085 |     0.09566 |      0.16044 |      0.91824 |
|    9 | Accept |       0.085 |     0.18319 |       0.085 |     0.08725 |       972.19 |      0.46259 |
|   10 | Accept |         0.1 |     0.27818 |       0.085 |    0.090952 |       990.29 |        0.491 |
|   11 | Best   |        0.08 |     0.11533 |        0.08 |    0.079362 |       2.5195 |        0.291 |
|   12 | Accept |        0.09 |     0.12483 |        0.08 |     0.08402 |       14.338 |      0.44386 |
|   13 | Accept |         0.1 |     0.12396 |        0.08 |     0.08508 |    0.0022577 |      0.23803 |
|   14 | Accept |        0.11 |     0.11322 |        0.08 |    0.087378 |       0.2115 |      0.32109 |
|   15 | Best   |        0.07 |     0.15379 |        0.07 |    0.081507 |        910.2 |      0.25218 |
|   16 | Best   |       0.065 |     0.17069 |       0.065 |    0.072457 |       953.22 |      0.26253 |
|   17 | Accept |       0.075 |     0.28338 |       0.065 |    0.072554 |       998.74 |      0.23087 |
|   18 | Accept |       0.295 |     0.12518 |       0.065 |    0.072647 |       996.18 |       44.626 |
|   19 | Accept |        0.07 |     0.14373 |       0.065 |     0.06946 |       985.37 |      0.27389 |
|   20 | Accept |       0.165 |     0.14863 |       0.065 |    0.071622 |     0.065103 |      0.13679 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |       0.345 |     0.13785 |       0.065 |    0.071764 |        971.7 |       999.01 |
|   22 | Accept |        0.61 |      0.1561 |       0.065 |    0.071967 |    0.0010168 |    0.0010005 |
|   23 | Accept |       0.345 |     0.12046 |       0.065 |    0.071959 |    0.0010674 |       999.18 |
|   24 | Accept |        0.35 |     0.11488 |       0.065 |    0.071863 |    0.0010003 |       40.628 |
|   25 | Accept |        0.24 |     0.26111 |       0.065 |    0.072124 |       996.55 |       10.423 |
|   26 | Accept |        0.61 |     0.19622 |       0.065 |    0.072068 |       958.64 |    0.0010026 |
|   27 | Accept |        0.47 |     0.14015 |       0.065 |     0.07218 |       993.69 |     0.029723 |
|   28 | Accept |         0.3 |     0.12165 |       0.065 |    0.072291 |       993.15 |       170.01 |
|   29 | Accept |        0.16 |     0.27222 |       0.065 |    0.072104 |       992.81 |       3.8594 |
|   30 | Accept |       0.365 |     0.12858 |       0.065 |    0.072112 |    0.0010017 |     0.044287 |

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 52.3583 seconds
Total objective function evaluation time: 5.1671

Best observed feasible point:
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

Observed objective function value = 0.065
Estimated objective function value = 0.073726
Function evaluation time = 0.17069

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

Estimated objective function value = 0.072112
Estimated function evaluation time = 0.17908
svmmod = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [77x1 double]
                                 Bias: -0.2352
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


  Properties, Methods

최적화된 모델의 손실을 구합니다.

lossnew = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf',...
    'BoxConstraint',svmmod.HyperparameterOptimizationResults.XAtMinObjective.BoxConstraint,...
    'KernelScale',svmmod.HyperparameterOptimizationResults.XAtMinObjective.KernelScale))
lossnew = 0.0650

이 손실은 최적화 결과에서 "관측된 목적 함수 값" 아래에 보고되는 손실과 동일합니다.

최적화된 분류기를 시각화합니다.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(svmmod,xGrid);
figure;
h = nan(3,1); % Preallocation
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(svmmod.IsSupportVector,1),...
    cdata(svmmod.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

참고 항목

|

관련 항목