Main Content

이 번역 페이지는 최신 내용을 담고 있지 않습니다. 최신 내용을 영문으로 보려면 여기를 클릭하십시오.

bayesopt를 사용하여 교차 검증 SVM 분류기 최적화하기

이 예제에서는 bayesopt 함수를 사용하여 SVM 분류를 최적화하는 방법을 보여줍니다. 이 분류는 가우스 혼합 모델의 점 위치를 사용한 것입니다. The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009)의 17페이지에 이 모델에 대한 설명이 나와 있습니다. 이 모델은 평균 (1,0)과 단위 분산의 2차원 독립 정규분포로 분산되고 "녹색" 클래스에 속하는 10개 기준점을 생성하는 것으로 시작합니다. 또한, 평균 (0,1)과 단위 분산의 2차원 독립 정규분포로 분산되고 "빨간색" 클래스에 속하는 10개 기준점도 생성합니다. 다음과 같이 각 클래스(녹색과 빨간색)에 대해 임의의 100개 점을 생성합니다.

  1. 해당 색상별로 임의로 균일하게 분포하는 기준점 m을 선택합니다.

  2. 평균이 m이고 분산이 I/5인(여기서 I는 2×2 단위 행렬임) 2차원 정규분포를 띠는 독립적인 임의의 점을 생성합니다. 이 예제에서는 분산 I/50을 사용하여 최적화의 이점을 더 확실하게 보여줍니다.

100개의 녹색 점과 100개의 빨간색 점을 생성한 후 fitcsvm을 사용하여 이들을 분류합니다. 그런 다음 교차 검증을 위해 결과로 생성된 SVM 모델의 모수를 bayesopt를 사용하여 최적화합니다.

점과 분류기 생성하기

각 클래스에 대해 10개 기준점을 생성합니다.

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

기준점을 표시합니다.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

일부 빨간색 기준점이 녹색 기준점에 가깝기 때문에 위치만 기준으로 하여 데이터 점을 분류하는 것이 어려울 수 있습니다.

각 클래스에 대해 100개 데이터 점을 생성합니다.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

데이터 점을 표시합니다.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

분류를 위한 데이터 준비하기

데이터를 한 행렬에 저장하고, 각 점의 클래스에 레이블을 지정하는 벡터 grp를 생성합니다.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

교차 검증 준비하기

교차 검증에 사용할 분할을 설정합니다. 이 단계에서는 최적화가 각 단계에 사용하는 훈련 세트와 검정 세트를 정합니다.

c = cvpartition(200,'KFold',10);

베이즈 최적화를 위한 변수 준비하기

입력값 z = [rbf_sigma,boxconstraint]를 받아서 z의 교차 검증 손실 값을 반환하는 함수를 설정합니다. z의 성분을 1e-51e5 사이의 양의 로그 변환 변수로 취하십시오. 어떤 값이 좋을지 알 수 없으므로 넓은 범위를 선택하십시오.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

목적 함수

이 함수 핸들은 모수 [sigma,box]에서 교차 검증 손실을 계산합니다. 자세한 내용은 kfoldLoss 항목을 참조하십시오.

bayesopt는 변수 z를 목적 함수에 행이 하나인 테이블로 전달합니다.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

분류기 최적화하기

bayesopt를 사용하여 최상의 모수 [sigma,box]를 탐색합니다. 재현이 가능하도록 'expected-improvement-plus' 획득 함수를 선택합니다. 디폴트 획득 함수는 실행 시간에 따라 다양한 결과를 반환할 수 있습니다.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.49723 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.23801 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.27614 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.37522 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.23259 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.18231 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.21768 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |       1.493 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      1.8306 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.29761 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      1.1237 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.28838 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.15901 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.25403 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.16714 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |     0.21326 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |     0.23925 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.31168 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.26502 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.20048 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.13136 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.23338 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.24124 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.18186 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.17713 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.13144 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |      0.1642 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |     0.18907 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |     0.18856 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.19236 |       0.075 |       0.075 |       3.5051 |       93.492 |

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 62.3158 seconds
Total objective function evaluation time: 10.693

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.21326

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.23015
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 62.3158
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

이 결과를 사용하여 최적화된 새 SVM 분류기를 훈련시킵니다.

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

분류 경계를 플로팅합니다. 서포트 벡터 분류기를 시각화하기 위해 그리드에 대해 점수를 예측합니다.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

새 데이터에 대한 정확도 평가하기

새 데이터 점을 여러 개 생성하고 분류합니다.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 6 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors.

어떤 새 데이터 점이 정확하게 분류되었는지 확인합니다. 올바르게 분류된 점에는 빨간색 원을, 잘못 분류된 점에는 검은색 원을 표시합니다.

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

Figure contains an axes. The axes contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors, Correctly Classified, Misclassified.

참고 항목

|

관련 항목