베이즈 최적화를 사용하여 분류기 피팅 최적화하기
이 예제에서는 fitcsvm
함수 및 OptimizeHyperparameters
이름-값 인수를 사용하여 SVM 분류를 최적화하는 방법을 보여줍니다.
데이터 생성하기
이 분류는 가우스 혼합 모델의 점 위치를 사용한 것입니다. The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009)의 17페이지에 이 모델에 대한 설명이 나와 있습니다. 이 모델은 평균 (1,0)과 단위 분산의 2차원 독립 정규분포로 분산되고 "녹색" 클래스에 속하는 10개 기준점을 생성하는 것으로 시작합니다. 또한, 평균 (0,1)과 단위 분산의 2차원 독립 정규분포로 분산되고 "빨간색" 클래스에 속하는 10개 기준점도 생성합니다. 다음과 같이 각 클래스(녹색과 빨간색)에 대해 임의의 100개 점을 생성합니다.
해당 색상별로 임의로 균일하게 분포하는 기준점 m을 선택합니다.
평균이 m이고 분산이 I/5인(여기서 I는 2×2 단위 행렬임) 2차원 정규분포를 띠는 독립적인 임의의 점을 생성합니다. 이 예제에서는 분산 I/50을 사용하여 최적화의 이점을 더 확실하게 보여줍니다.
각 클래스에 대해 10개 기준점을 생성합니다.
rng('default') % For reproducibility grnpop = mvnrnd([1,0],eye(2),10); redpop = mvnrnd([0,1],eye(2),10);
기준점을 표시합니다.
plot(grnpop(:,1),grnpop(:,2),'go') hold on plot(redpop(:,1),redpop(:,2),'ro') hold off
일부 빨간색 기준점이 녹색 기준점에 가깝기 때문에 위치만 기준으로 하여 데이터 점을 분류하는 것이 어려울 수 있습니다.
각 클래스에 대해 100개 데이터 점을 생성합니다.
redpts = zeros(100,2); grnpts = redpts; for i = 1:100 grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02); redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02); end
데이터 점을 표시합니다.
figure plot(grnpts(:,1),grnpts(:,2),'go') hold on plot(redpts(:,1),redpts(:,2),'ro') hold off
분류를 위한 데이터 준비하기
데이터를 한 행렬에 저장하고, 각 점의 클래스에 레이블을 지정하는 벡터 grp
를 생성합니다. 1은 녹색 클래스를 나타내고, –1은 빨간색 클래스를 나타냅니다.
cdata = [grnpts;redpts]; grp = ones(200,1); grp(101:200) = -1;
교차 검증 준비하기
교차 검증에 사용할 분할을 설정합니다.
c = cvpartition(200,'KFold',10);
이는 선택적 단계입니다. 최적화에 사용할 분할을 지정하면, 반환된 모델에 대한 실제 교차 검증 손실을 계산할 수 있습니다.
피팅 최적화하기
양호한 피팅, 즉 교차 검증 손실을 최소화하는 최적의 하이퍼파라미터를 갖는 피팅을 구하려면 베이즈 최적화를 사용하십시오. OptimizeHyperparameters
이름-값 인수를 사용하여 최적화할 하이퍼파라미터 목록을 지정하고 HyperparameterOptimizationOptions
이름-값 인수를 사용하여 최적화 옵션을 지정합니다.
'OptimizeHyperparameters'
를 'auto'
로 지정합니다. 'auto'
옵션에는 최적화할 일반적인 하이퍼파라미터 세트가 포함되어 있습니다. fitcsvm
은 BoxConstraint
, KernelScale
, Standardize
의 최적 값을 구합니다. 재현이 가능하도록 교차 검증 분할 c
를 사용하고 'expected-improvement-plus'
획득 함수를 선택하도록 하이퍼파라미터 최적화 옵션을 설정합니다. 디폴트 획득 함수는 실행 시간에 따라 다양한 결과를 반환할 수 있습니다.
opts = struct('CVPartition',c,'AcquisitionFunctionName', ... 'expected-improvement-plus'); Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|====================================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | Standardize | | | result | | runtime | (observed) | (estim.) | | | | |====================================================================================================================| | 1 | Best | 0.195 | 0.21831 | 0.195 | 0.195 | 193.54 | 0.069073 | false | | 2 | Accept | 0.345 | 0.10288 | 0.195 | 0.20398 | 43.991 | 277.86 | false | | 3 | Accept | 0.365 | 0.085025 | 0.195 | 0.20784 | 0.0056595 | 0.042141 | false | | 4 | Accept | 0.61 | 0.17863 | 0.195 | 0.31714 | 49.333 | 0.0010514 | true | | 5 | Best | 0.1 | 0.30419 | 0.1 | 0.10005 | 996.27 | 1.3081 | false | | 6 | Accept | 0.13 | 0.069174 | 0.1 | 0.10003 | 25.398 | 1.7076 | false | | 7 | Best | 0.085 | 0.1168 | 0.085 | 0.08521 | 930.3 | 0.66262 | false | | 8 | Accept | 0.35 | 0.066595 | 0.085 | 0.085172 | 0.012972 | 983.4 | true | | 9 | Best | 0.075 | 0.091629 | 0.075 | 0.077959 | 871.26 | 0.40617 | false | | 10 | Accept | 0.08 | 0.12545 | 0.075 | 0.077975 | 974.28 | 0.45314 | false | | 11 | Accept | 0.235 | 0.30216 | 0.075 | 0.077907 | 920.57 | 6.482 | true | | 12 | Accept | 0.305 | 0.070665 | 0.075 | 0.077922 | 0.0010077 | 1.0212 | true | | 13 | Best | 0.07 | 0.080775 | 0.07 | 0.073603 | 991.16 | 0.37801 | false | | 14 | Accept | 0.075 | 0.078256 | 0.07 | 0.073191 | 989.88 | 0.24951 | false | | 15 | Accept | 0.245 | 0.09407 | 0.07 | 0.073276 | 988.76 | 9.1309 | false | | 16 | Accept | 0.07 | 0.0795 | 0.07 | 0.071416 | 957.65 | 0.31271 | false | | 17 | Accept | 0.35 | 0.11798 | 0.07 | 0.071421 | 0.0010579 | 33.692 | true | | 18 | Accept | 0.085 | 0.05857 | 0.07 | 0.071274 | 48.536 | 0.32107 | false | | 19 | Accept | 0.07 | 0.082979 | 0.07 | 0.070587 | 742.56 | 0.30798 | false | | 20 | Accept | 0.61 | 0.19356 | 0.07 | 0.070796 | 865.48 | 0.0010165 | false | |====================================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | Standardize | | | result | | runtime | (observed) | (estim.) | | | | |====================================================================================================================| | 21 | Accept | 0.1 | 0.085428 | 0.07 | 0.070715 | 970.87 | 0.14635 | true | | 22 | Accept | 0.095 | 0.12121 | 0.07 | 0.07087 | 914.88 | 0.46353 | true | | 23 | Accept | 0.07 | 0.14119 | 0.07 | 0.070473 | 982.01 | 0.2792 | false | | 24 | Accept | 0.51 | 0.51006 | 0.07 | 0.070515 | 0.0010005 | 0.014749 | true | | 25 | Accept | 0.345 | 0.16526 | 0.07 | 0.070533 | 0.0010063 | 972.18 | false | | 26 | Accept | 0.315 | 0.17117 | 0.07 | 0.07057 | 947.71 | 152.95 | true | | 27 | Accept | 0.35 | 0.36783 | 0.07 | 0.070605 | 0.0010028 | 43.62 | false | | 28 | Accept | 0.61 | 0.10346 | 0.07 | 0.070598 | 0.0010405 | 0.0010258 | false | | 29 | Accept | 0.555 | 0.07333 | 0.07 | 0.070173 | 993.56 | 0.010502 | true | | 30 | Accept | 0.07 | 0.099019 | 0.07 | 0.070158 | 965.73 | 0.25363 | true | __________________________________________________________ Optimization completed. MaxObjectiveEvaluations of 30 reached. Total function evaluations: 30 Total elapsed time: 16.8267 seconds Total objective function evaluation time: 4.3552 Best observed feasible point: BoxConstraint KernelScale Standardize _____________ ___________ ___________ 991.16 0.37801 false Observed objective function value = 0.07 Estimated objective function value = 0.072292 Function evaluation time = 0.080775 Best estimated feasible point (according to models): BoxConstraint KernelScale Standardize _____________ ___________ ___________ 957.65 0.31271 false Estimated objective function value = 0.070158 Estimated function evaluation time = 0.092681
Mdl = ClassificationSVM ResponseName: 'Y' CategoricalPredictors: [] ClassNames: [-1 1] ScoreTransform: 'none' NumObservations: 200 HyperparameterOptimizationResults: [1x1 BayesianOptimization] Alpha: [66x1 double] Bias: -0.0910 KernelParameters: [1x1 struct] BoxConstraints: [200x1 double] ConvergenceInfo: [1x1 struct] IsSupportVector: [200x1 logical] Solver: 'SMO'
fitcsvm
은 최상의 추정된 실현가능점을 사용하는 ClassificationSVM
모델 객체를 반환합니다. 최상의 추정된 실현가능점은 베이즈 최적화 과정의 기본 가우스 과정 모델을 기반으로 하여 교차 검증 손실의 신뢰 상한을 최소화하는 하이퍼파라미터 세트입니다.
베이즈 최적화 과정은 내부적으로 목적 함수의 가우스 과정 모델을 유지합니다. 분류의 경우 목적 함수는 교차 검증된 오분류율입니다. 각 반복에 대해 최적화 과정은 가우스 과정 모델을 업데이트하고 이 모델을 사용하여 새 하이퍼파라미터 세트를 구합니다. 반복 표시의 각 라인은 새 하이퍼파라미터 세트와 다음 열 값을 표시합니다.
Objective
— 새 하이퍼파라미터 세트에서 계산되는 목적 함수 값.Objective runtime
— 목적 함수 계산 시간.Eval result
— 결과 리포트로,Accept
,Best
또는Error
로 지정됩니다.Accept
는 목적 함수가 유한 값을 반환함을 나타내고,Error
는 목적 함수가 유한 실수형 스칼라가 아닌 값을 반환함을 나타냅니다.Best
는 목적 함수가 이전에 계산된 목적 함수 값보다 작은 유한 값을 반환함을 나타냅니다.BestSoFar(observed)
— 지금까지 계산된 최소 목적 함수 값. 이 값은 현재 반복의 목적 함수 값(현재 반복의Eval result
값이Best
인 경우) 또는 이전Best
반복의 값입니다.BestSoFar(estim.)
— 각 반복마다, 업데이트된 가우스 과정 모델을 사용하여 지금까지 시도된 모든 하이퍼파라미터 세트에서 목적 함수 값의 신뢰 상한이 추정됩니다. 그런 다음 최소 신뢰 상한을 갖는 점이 선택됩니다.BestSoFar(estim.)
값은 최소 신뢰 상한을 갖는 점에서predictObjective
함수에 의해 반환되는 목적 함수 값입니다.
반복 표시 아래의 플롯은 BestSoFar(observed)
값 및 BestSoFar(estim.)
값을 각각 파란색 및 녹색으로 표시합니다.
반환된 객체 Mdl
은 최상의 추정된 실현가능점, 즉 최종 가우스 과정 모델을 기반으로 하여 최종 반복에서 BestSoFar(estim.)
값을 생성하는 하이퍼파라미터 세트를 사용합니다.
HyperparameterOptimizationResults
속성에서 또는 bestPoint
함수를 사용하여 최상의 점을 가져올 수 있습니다.
Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
957.65 0.31271 false
[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
957.65 0.31271 false
CriterionValue = 0.0724
iteration = 16
기본적으로, bestPoint
함수는 'min-visited-upper-confidence-interval'
기준을 사용합니다. 이 기준은 16번째 반복에서 얻은 하이퍼파라미터를 최상의 점으로 선택합니다. CriterionValue
는 최종 가우스 과정 모델에 의해 계산된 교차 검증 손실의 상한입니다. 분할 c
를 사용하여 실제 교차 검증 손실을 계산합니다.
L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c, ... 'KernelFunction','rbf','BoxConstraint',x.BoxConstraint, ... 'KernelScale',x.KernelScale,'Standardize',x.Standardize=='true'))
L_MinEstimated = 0.0700
실제 교차 검증 손실은 추정된 값에 가깝습니다. 최적화 결과의 플롯 아래에 Estimated objective function value
가 표시됩니다.
HyperparameterOptimizationResults
속성에서 또는 Criterion
을 'min-observed'
로 지정하여 최상의 관측된 실현가능점(즉, 반복 표시의 마지막 Best
점)을 추출할 수도 있습니다.
Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
991.16 0.37801 false
[x_observed,CriterionValue_observed,iteration_observed] = ... bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
991.16 0.37801 false
CriterionValue_observed = 0.0700
iteration_observed = 13
'min-observed'
기준은 13번째 반복에서 얻은 하이퍼파라미터를 최상의 점으로 선택합니다. CriterionValue_observed
는 선택한 하이퍼파라미터를 사용하여 계산된 실제 교차 검증 손실입니다. 자세한 내용은 bestPoint
의 Criterion 이름-값 인수를 참조하십시오.
최적화된 분류기를 시각화합니다.
d = 0.02; [x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ... min(cdata(:,2)):d:max(cdata(:,2))); xGrid = [x1Grid(:),x2Grid(:)]; [~,scores] = predict(Mdl,xGrid); figure h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3) = plot(cdata(Mdl.IsSupportVector,1), ... cdata(Mdl.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
새 데이터에 대한 정확도 평가하기
새 테스트 데이터 점을 생성하고 분류합니다.
grnobj = gmdistribution(grnpop,.2*eye(2)); redobj = gmdistribution(redpop,.2*eye(2)); newData = random(grnobj,10); newData = [newData;random(redobj,10)]; grpData = ones(20,1); % green = 1 grpData(11:20) = -1; % red = -1 v = predict(Mdl,newData);
테스트 데이터 세트에 대해 오분류율을 계산합니다.
L_Test = loss(Mdl,newData,grpData)
L_Test = 0.2000
어떤 새 데이터 점이 정확하게 분류되었는지 결정합니다. 올바르게 분류된 점에는 빨간색 정사각형을, 잘못 분류된 점에는 검은색 정사각형을 지정합니다.
h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); mydiff = (v == grpData); % Classified correctly for ii = mydiff % Plot red squares around correct pts h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12); end for ii = not(mydiff) % Plot black squares around incorrect pts h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12); end legend(h,{'-1 (training)','+1 (training)','Support Vectors', ... '-1 (classified)','+1 (classified)', ... 'Correctly Classified','Misclassified'}, ... 'Location','Southeast'); hold off