Main Content

bayesopt를 사용하여 교차 검증된 분류기 최적화하기

이 예제에서는 bayesopt 함수를 사용하여 SVM 분류를 최적화하는 방법을 보여줍니다.

또는 OptimizeHyperparameters 이름-값 인수를 사용하여 분류기를 최적화할 수 있습니다. 예제는 베이즈 최적화를 사용하여 분류기 피팅 최적화하기 항목을 참조하십시오.

데이터 생성하기

이 분류는 가우스 혼합 모델의 점 위치를 사용한 것입니다. The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009)의 17페이지에 이 모델에 대한 설명이 나와 있습니다. 이 모델은 평균 (1,0)과 단위 분산의 2차원 독립 정규분포로 분산되고 "녹색" 클래스에 속하는 10개 기준점을 생성하는 것으로 시작합니다. 또한, 평균 (0,1)과 단위 분산의 2차원 독립 정규분포로 분산되고 "빨간색" 클래스에 속하는 10개 기준점도 생성합니다. 다음과 같이 각 클래스(녹색과 빨간색)에 대해 임의의 100개 점을 생성합니다.

  1. 해당 색상별로 임의로 균일하게 분포하는 기준점 m을 선택합니다.

  2. 평균이 m이고 분산이 I/5인(여기서 I는 2×2 단위 행렬임) 2차원 정규분포를 띠는 독립적인 임의의 점을 생성합니다. 이 예제에서는 분산 I/50을 사용하여 최적화의 이점을 더 확실하게 보여줍니다.

각 클래스에 대해 10개 기준점을 생성합니다.

rng('default') % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

기준점을 표시합니다.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line. One or more of the lines displays its values using only markers

일부 빨간색 기준점이 녹색 기준점에 가깝기 때문에 위치만 기준으로 하여 데이터 점을 분류하는 것이 어려울 수 있습니다.

각 클래스에 대해 100개 데이터 점을 생성합니다.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

데이터 점을 표시합니다.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line. One or more of the lines displays its values using only markers

분류를 위한 데이터 준비하기

데이터를 한 행렬에 저장하고, 각 점의 클래스에 레이블을 지정하는 벡터 grp를 생성합니다. 1은 녹색 클래스를 나타내고, -1은 빨간색 클래스를 나타냅니다.

cdata = [grnpts;redpts];
grp = ones(200,1);
grp(101:200) = -1;

교차 검증 준비하기

교차 검증에 사용할 분할을 설정합니다. 이 단계에서는 최적화가 각 단계에 사용하는 훈련 세트와 테스트 세트를 정합니다.

c = cvpartition(200,'KFold',10);

베이즈 최적화를 위한 변수 준비하기

입력값 z = [rbf_sigma,boxconstraint]를 받아서 z의 교차 검증 손실 값을 반환하는 함수를 설정합니다. z의 성분을 1e-51e5 사이의 양의 로그 변환 변수로 취하십시오. 어떤 값이 좋을지 알 수 없으므로 넓은 범위를 선택하십시오.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

목적 함수

이 함수 핸들은 모수 [sigma,box]에서 교차 검증 손실을 계산합니다. 자세한 내용은 kfoldLoss 항목을 참조하십시오.

bayesopt는 변수 z를 목적 함수에 행이 하나인 테이블로 전달합니다.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

분류기 최적화하기

bayesopt를 사용하여 최상의 모수 [sigma,box]를 탐색합니다. 재현이 가능하도록 'expected-improvement-plus' 획득 함수를 선택합니다. 디폴트 획득 함수는 실행 시간에 따라 다양한 결과를 반환할 수 있습니다.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.16814 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.19002 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.14965 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.14495 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |    0.098095 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.14574 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.14731 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |     0.98874 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |       1.216 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |    0.061622 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |     0.76664 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.13937 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.07809 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.07018 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |    0.056946 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |     0.20583 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |     0.14375 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.09529 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.12555 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |    0.069276 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |    0.064193 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |    0.057582 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |       0.131 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |    0.067127 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |    0.086768 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |    0.072945 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.10464 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |    0.095945 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |     0.13501 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |    0.067905 |       0.075 |       0.075 |       3.5051 |       93.492 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 22.7968 seconds
Total objective function evaluation time: 5.9443

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.20583

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.11491

Figure contains an axes object. The axes object with title Objective function model, xlabel sigma, ylabel box contains 5 objects of type line, surface, contour. One or more of the lines displays its values using only markers These objects represent Observed points, Model mean, Next point, Model minimum feasible.

Figure contains an axes object. The axes object with title Min objective vs. Number of function evaluations, xlabel Function evaluations, ylabel Min objective contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf','BoxConstraint',z.box,'KernelScale',z.sigma))
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 22.7968
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

XAtMinEstimatedObjective 속성에서 또는 bestPoint 함수를 사용하여 최상의 추정된 실현가능점을 가져옵니다. 기본적으로, bestPoint 함수는 'min-visited-upper-confidence-interval' 기준을 사용합니다. 자세한 내용은 bestPointCriterion 이름-값 인수를 참조하십시오.

results.XAtMinEstimatedObjective
ans=1×2 table
     sigma      box  
    _______    ______

    0.32869    18.076

z = bestPoint(results)
z=1×2 table
     sigma      box  
    _______    ______

    0.32869    18.076

최상의 점을 사용하여 최적화된 새 SVM 분류기를 훈련시킵니다.

SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf', ...
    'KernelScale',z.sigma,'BoxConstraint',z.box);

서포트 벡터 분류기를 시각화하기 위해 그리드에 대해 점수를 예측합니다.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

분류 경계를 플로팅합니다.

figure
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1) ,...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');

Figure contains an axes object. The axes object contains 4 objects of type line, contour. One or more of the lines displays its values using only markers These objects represent -1, +1, Support Vectors.

새 데이터에 대한 정확도 평가하기

새 테스트 데이터 점을 생성하고 분류합니다.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);  % green = 1
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

테스트 데이터 세트에 대해 오분류율을 계산합니다.

L = loss(SVMModel,newData,grpData)
L = 
0.3500

어떤 새 데이터 점이 정확하게 분류되었는지 확인합니다. 올바르게 분류된 점에는 빨간색 원을, 잘못 분류된 점에는 검은색 원을 표시합니다.

h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**');

mydiff = (v == grpData); % Classified correctly

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','Support Vectors', ...
    '-1 (classified)','+1 (classified)', ...
    'Correctly Classified','Misclassified'}, ...
    'Location','Southeast');
hold off

Figure contains an axes object. The axes object contains 8 objects of type line, contour. One or more of the lines displays its values using only markers These objects represent -1 (training), +1 (training), Support Vectors, -1 (classified), +1 (classified), Correctly Classified, Misclassified.

참고 항목

|

관련 항목