Main Content

베이즈 최적화를 사용하여 분류기 피팅 최적화하기

이 예제에서는 fitcsvm 함수 및 OptimizeHyperparameters 이름-값 인수를 사용하여 SVM 분류를 최적화하는 방법을 보여줍니다.

데이터 생성하기

이 분류는 가우스 혼합 모델의 점 위치를 사용한 것입니다. The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009)의 17페이지에 이 모델에 대한 설명이 나와 있습니다. 이 모델은 평균 (1,0)과 단위 분산의 2차원 독립 정규분포로 분산되고 "녹색" 클래스에 속하는 10개 기준점을 생성하는 것으로 시작합니다. 또한, 평균 (0,1)과 단위 분산의 2차원 독립 정규분포로 분산되고 "빨간색" 클래스에 속하는 10개 기준점도 생성합니다. 다음과 같이 각 클래스(녹색과 빨간색)에 대해 임의의 100개 점을 생성합니다.

  1. 해당 색상별로 임의로 균일하게 분포하는 기준점 m을 선택합니다.

  2. 평균이 m이고 분산이 I/5인(여기서 I는 2×2 단위 행렬임) 2차원 정규분포를 띠는 독립적인 임의의 점을 생성합니다. 이 예제에서는 분산 I/50을 사용하여 최적화의 이점을 더 확실하게 보여줍니다.

각 클래스에 대해 10개 기준점을 생성합니다.

rng('default') % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

기준점을 표시합니다.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

일부 빨간색 기준점이 녹색 기준점에 가깝기 때문에 위치만 기준으로 하여 데이터 점을 분류하는 것이 어려울 수 있습니다.

각 클래스에 대해 100개 데이터 점을 생성합니다.

redpts = zeros(100,2);
grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

데이터 점을 표시합니다.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

분류를 위한 데이터 준비하기

데이터를 한 행렬에 저장하고, 각 점의 클래스에 레이블을 지정하는 벡터 grp를 생성합니다. 1은 녹색 클래스를 나타내고, –1은 빨간색 클래스를 나타냅니다.

cdata = [grnpts;redpts];
grp = ones(200,1);
grp(101:200) = -1;

교차 검증 준비하기

교차 검증에 사용할 분할을 설정합니다.

c = cvpartition(200,'KFold',10);

이는 선택적 단계입니다. 최적화에 사용할 분할을 지정하면, 반환된 모델에 대한 실제 교차 검증 손실을 계산할 수 있습니다.

피팅 최적화하기

양호한 피팅, 즉 교차 검증 손실을 최소화하는 최적의 하이퍼파라미터를 갖는 피팅을 구하려면 베이즈 최적화를 사용하십시오. OptimizeHyperparameters 이름-값 인수를 사용하여 최적화할 하이퍼파라미터 목록을 지정하고 HyperparameterOptimizationOptions 이름-값 인수를 사용하여 최적화 옵션을 지정합니다.

'OptimizeHyperparameters''auto'로 지정합니다. 'auto' 옵션에는 최적화할 일반적인 하이퍼파라미터 세트가 포함되어 있습니다. fitcsvmBoxConstraintKernelScale의 최적 값을 구합니다. 재현이 가능하도록 교차 검증 분할 c를 사용하고 'expected-improvement-plus' 획득 함수를 선택하도록 하이퍼파라미터 최적화 옵션을 설정합니다. 디폴트 획득 함수는 실행 시간에 따라 다양한 결과를 반환할 수 있습니다.

opts = struct('CVPartition',c,'AcquisitionFunctionName','expected-improvement-plus');
Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |       0.345 |     0.26612 |       0.345 |       0.345 |      0.00474 |       306.44 |
|    2 | Best   |       0.115 |     0.16757 |       0.115 |     0.12678 |       430.31 |       1.4864 |
|    3 | Accept |        0.52 |     0.21336 |       0.115 |      0.1152 |     0.028415 |     0.014369 |
|    4 | Accept |        0.61 |     0.41833 |       0.115 |     0.11504 |       133.94 |    0.0031427 |
|    5 | Accept |        0.34 |     0.46056 |       0.115 |     0.11504 |     0.010993 |       5.7742 |
|    6 | Best   |       0.085 |     0.25465 |       0.085 |    0.085039 |       885.63 |      0.68403 |
|    7 | Accept |       0.105 |     0.25751 |       0.085 |    0.085428 |       0.3057 |      0.58118 |
|    8 | Accept |        0.21 |     0.28915 |       0.085 |     0.09566 |      0.16044 |      0.91824 |
|    9 | Accept |       0.085 |     0.30816 |       0.085 |     0.08725 |       972.19 |      0.46259 |
|   10 | Accept |         0.1 |     0.34457 |       0.085 |    0.090952 |       990.29 |        0.491 |
|   11 | Best   |        0.08 |     0.21805 |        0.08 |    0.079362 |       2.5195 |        0.291 |
|   12 | Accept |        0.09 |     0.24212 |        0.08 |     0.08402 |       14.338 |      0.44386 |
|   13 | Accept |         0.1 |     0.23766 |        0.08 |     0.08508 |    0.0022577 |      0.23803 |
|   14 | Accept |        0.11 |     0.24347 |        0.08 |    0.087378 |       0.2115 |      0.32109 |
|   15 | Best   |        0.07 |     0.30411 |        0.07 |    0.081507 |        910.2 |      0.25218 |
|   16 | Best   |       0.065 |     0.24431 |       0.065 |    0.072457 |       953.22 |      0.26253 |
|   17 | Accept |       0.075 |     0.33287 |       0.065 |    0.072554 |       998.74 |      0.23087 |
|   18 | Accept |       0.295 |     0.21231 |       0.065 |    0.072647 |       996.18 |       44.626 |
|   19 | Accept |        0.07 |     0.26876 |       0.065 |     0.06946 |       985.37 |      0.27389 |
|   20 | Accept |       0.165 |     0.24669 |       0.065 |    0.071622 |     0.065103 |      0.13679 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |       0.345 |     0.20097 |       0.065 |    0.071764 |        971.7 |       999.01 |
|   22 | Accept |        0.61 |      0.2416 |       0.065 |    0.071967 |    0.0010168 |    0.0010005 |
|   23 | Accept |       0.345 |     0.26803 |       0.065 |    0.071959 |    0.0011459 |       995.89 |
|   24 | Accept |        0.35 |     0.23608 |       0.065 |    0.071863 |    0.0010003 |       40.628 |
|   25 | Accept |        0.24 |     0.39188 |       0.065 |    0.072124 |       996.55 |       10.423 |
|   26 | Accept |        0.61 |     0.46697 |       0.065 |    0.072067 |       994.71 |    0.0010063 |
|   27 | Accept |        0.47 |     0.28997 |       0.065 |     0.07218 |       993.69 |     0.029723 |
|   28 | Accept |         0.3 |     0.24924 |       0.065 |    0.072291 |       993.15 |       170.01 |
|   29 | Accept |        0.16 |     0.37085 |       0.065 |    0.072103 |       992.81 |       3.8594 |
|   30 | Accept |       0.365 |     0.19017 |       0.065 |    0.072112 |    0.0010017 |     0.044287 |

Figure contains an axes object. The axes object with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

Figure contains an axes object. The axes object with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 42.2011 seconds
Total objective function evaluation time: 8.4361

Best observed feasible point:
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

Observed objective function value = 0.065
Estimated objective function value = 0.073726
Function evaluation time = 0.24431

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

Estimated objective function value = 0.072112
Estimated function evaluation time = 0.28413
Mdl = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [77x1 double]
                                 Bias: -0.2352
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


  Properties, Methods

fitcsvm은 최상의 추정된 실현가능점을 사용하는 ClassificationSVM 모델 객체를 반환합니다. 최상의 추정된 실현가능점은 베이즈 최적화 과정의 기본 가우스 과정 모델을 기반으로 하여 교차 검증 손실의 신뢰 상한을 최소화하는 하이퍼파라미터 세트입니다.

베이즈 최적화 과정은 내부적으로 목적 함수의 가우스 과정 모델을 유지합니다. 분류의 경우 목적 함수는 교차 검증된 오분류율입니다. 각 반복에 대해 최적화 과정은 가우스 과정 모델을 업데이트하고 이 모델을 사용하여 새 하이퍼파라미터 세트를 구합니다. 반복 표시의 각 라인은 새 하이퍼파라미터 세트와 다음 열 값을 표시합니다.

  • Objective — 새 하이퍼파라미터 세트에서 계산되는 목적 함수 값.

  • Objective runtime — 목적 함수 계산 시간.

  • Eval result — 결과 리포트로, Accept, Best 또는 Error로 지정됩니다. Accept는 목적 함수가 유한 값을 반환함을 나타내고, Error는 목적 함수가 유한 실수형 스칼라가 아닌 값을 반환함을 나타냅니다. Best는 목적 함수가 이전에 계산된 목적 함수 값보다 작은 유한 값을 반환함을 나타냅니다.

  • BestSoFar(observed) — 지금까지 계산된 최소 목적 함수 값. 이 값은 현재 반복의 목적 함수 값(현재 반복의 Eval result 값이 Best인 경우) 또는 이전 Best 반복의 값입니다.

  • BestSoFar(estim.) — 각 반복마다, 업데이트된 가우스 과정 모델을 사용하여 지금까지 시도된 모든 하이퍼파라미터 세트에서 목적 함수 값의 신뢰 상한이 추정됩니다. 그런 다음 최소 신뢰 상한을 갖는 점이 선택됩니다. BestSoFar(estim.) 값은 최소 신뢰 상한을 갖는 점에서predictObjective 함수에 의해 반환되는 목적 함수 값입니다.

반복 표시 아래의 플롯은 BestSoFar(observed) 값 및 BestSoFar(estim.) 값을 각각 파란색 및 녹색으로 표시합니다.

반환된 객체 Mdl은 최상의 추정된 실현가능점, 즉 최종 가우스 과정 모델을 기반으로 하여 최종 반복에서 BestSoFar(estim.) 값을 생성하는 하이퍼파라미터 세트를 사용합니다.

HyperparameterOptimizationResults 속성에서 또는 bestPoint 함수를 사용하여 최상의 점을 가져올 수 있습니다.

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

CriterionValue = 0.0888
iteration = 19

기본적으로, bestPoint 함수는 'min-visited-upper-confidence-interval' 기준을 사용합니다. 이 기준은 19번째 반복에서 얻은 하이퍼파라미터를 최상의 점으로 선택합니다. CriterionValue는 최종 가우스 과정 모델에 의해 계산된 교차 검증 손실의 상한입니다. 분할 c를 사용하여 실제 교차 검증 손실을 계산합니다.

L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf', ...
    'BoxConstraint',x.BoxConstraint,'KernelScale',x.KernelScale))
L_MinEstimated = 0.0700

실제 교차 검증 손실은 추정된 값에 가깝습니다. 최적화 결과의 플롯 아래에 Estimated objective function value가 표시됩니다.

HyperparameterOptimizationResults 속성에서 또는 Criterion'min-observed'로 지정하여 최상의 관측된 실현가능점(즉, 반복 표시의 마지막 Best 점)을 추출할 수도 있습니다.

Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

[x_observed,CriterionValue_observed,iteration_observed] = bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

CriterionValue_observed = 0.0650
iteration_observed = 16

'min-observed' 기준은 16번째 반복에서 얻은 하이퍼파라미터를 최상의 점으로 선택합니다. CriterionValue_observed는 선택한 하이퍼파라미터를 사용하여 계산된 실제 교차 검증 손실입니다. 자세한 내용은 bestPointCriterion 이름-값 인수를 참조하십시오.

최적화된 분류기를 시각화합니다.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(Mdl,xGrid);

figure
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(Mdl.IsSupportVector,1), ...
    cdata(Mdl.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');

Figure contains an axes object. The axes object contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

새 데이터에 대한 정확도 평가하기

새 검정 데이터 점을 생성하고 분류합니다.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1); % green = 1
grpData(11:20) = -1; % red = -1

v = predict(Mdl,newData);

검정 데이터 세트에 대해 오분류율을 계산합니다.

L_Test = loss(Mdl,newData,grpData)
L_Test = 0.3500

어떤 새 데이터 점이 정확하게 분류되었는지 결정합니다. 올바르게 분류된 점에는 빨간색 정사각형을, 잘못 분류된 점에는 검은색 정사각형을 지정합니다.

h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**');

mydiff = (v == grpData); % Classified correctly

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','Support Vectors', ...
    '-1 (classified)','+1 (classified)', ...
    'Correctly Classified','Misclassified'}, ...
    'Location','Southeast');
hold off

Figure contains an axes object. The axes object contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), Support Vectors, -1 (classified), +1 (classified), Correctly Classified, Misclassified.

참고 항목

|

관련 항목