Main Content

분류

이 예제에서는 판별분석, 나이브 베이즈 분류기, 결정 트리를 사용하여 분류를 수행하는 방법을 보여줍니다. 여러 변수(예측 변수라고 함)의 측정값과 각각의 알려진 클래스 레이블이 있는 관측값을 포함하는 데이터 세트가 있다고 가정하겠습니다. 새 관측값의 예측 변수 값을 구하면 이러한 관측값이 속해 있을 수 있는 클래스를 확인할 수 있을까요? 이것이 바로 분류의 문제입니다.

피셔(Fisher)의 붓꽃 데이터

피셔의 붓꽃 데이터는 150개 붓꽃 표본의 꽃받침 길이, 꽃받침 너비, 꽃잎 길이, 꽃잎 너비에 대한 측정값으로 구성되어 있습니다. 세 종에서 각각 50개의 표본이 추출되었습니다. 데이터를 불러오고 꽃받침 측정값이 두 종 사이에 어떻게 다른지 확인해 봅니다. 꽃받침 측정값을 포함하는 두 열을 사용할 수 있습니다.

load fisheriris
f = figure;
gscatter(meas(:,1), meas(:,2), species,'rgb','osd');
xlabel('Sepal length');
ylabel('Sepal width');

Figure contains an axes object. The axes object with xlabel Sepal length, ylabel Sepal width contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent setosa, versicolor, virginica.

N = size(meas,1);

한 붓꽃의 꽃받침과 꽃잎을 측정하고 이러한 측정값을 기반으로 하여 해당 종을 확인해야 한다고 가정하겠습니다. 이 문제를 해결할 수 있는 한 가지 방법으로 판별분석(discriminant analysis)이 알려져 있습니다.

선형 판별분석 및 2차 판별분석

fitcdiscr 함수는 여러 유형의 판별분석을 사용하여 분류를 수행할 수 있습니다. 먼저, 디폴트 선형 판별분석(LDA)을 사용하여 데이터를 분류하겠습니다.

lda = fitcdiscr(meas(:,1:2),species);
ldaClass = resubPredict(lda);

알려진 클래스 레이블을 갖는 관측값을 대개 훈련 데이터(training data)라고 합니다. 이제 재대입 오차, 즉 훈련 세트에 대한 오분류 오차(오분류된 관측값의 비율)를 계산합니다.

ldaResubErr = resubLoss(lda)
ldaResubErr = 0.2000

또한, 훈련 세트에 대한 혼동행렬을 계산할 수도 있습니다. 혼동행렬에는 알려진 클래스 레이블과 예측된 클래스 레이블에 대한 정보가 포함됩니다. 일반적으로, 혼동행렬에서 (i,j) 요소는 알려진 클래스 레이블이 클래스 i이고 예측된 클래스가 j인 표본 개수입니다. 대각선 요소는 올바르게 분류된 관측값을 나타냅니다.

figure
ldaResubCM = confusionchart(species,ldaClass);

Figure contains an object of type ConfusionMatrixChart.

150개 훈련 측정값의 20%, 즉 30개 관측값이 선형 판별분석 함수에 의해 오분류되었습니다. 오분류된 점에 X를 그려 이러한 점을 표시할 수 있습니다.

figure(f)
bad = ~strcmp(ldaClass,species);
hold on;
plot(meas(bad,1), meas(bad,2), 'kx');
hold off;

Figure contains an axes object. The axes object with xlabel Sepal length, ylabel Sepal width contains 4 objects of type line. One or more of the lines displays its values using only markers These objects represent setosa, versicolor, virginica.

이 함수는 평면을 선으로 분할된 영역으로 나누고 여러 영역을 여러 종에 할당했습니다. 이러한 영역을 시각화하는 한 가지 방법은 (x,y) 값으로 구성된 그리드를 생성하고 분류 함수를 해당 그리드에 적용하는 것입니다.

[x,y] = meshgrid(4:.1:8,2:.1:4.5);
x = x(:);
y = y(:);
j = classify([x y],meas(:,1:2),species);
gscatter(x,y,j,'grb','sod')

Figure contains an axes object. The axes object with xlabel x, ylabel y contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent versicolor, setosa, virginica.

데이터 세트에 따라 여러 클래스의 영역이 선으로 제대로 분리되지 않는 경우가 있습니다. 이런 경우, 선형 판별분석이 적합하지 않습니다. 대신, 데이터에 2차 판별분석(QDA)을 시도해 볼 수 있습니다.

2차 판별분석에 대한 재대입 오차를 계산합니다.

qda = fitcdiscr(meas(:,1:2),species,'DiscrimType','quadratic');
qdaResubErr = resubLoss(qda)
qdaResubErr = 0.2000

재대입 오차를 계산했습니다. 일반적으로 많은 사람들이 독립적인 세트에 대해 예상되는 예측 오차인 검정 오차(일반화 오차라고도 함)에 관심이 많습니다. 실제로, 재대입 오차가 검정 오차보다 작게 추정될 가능성이 높습니다.

이 경우, 레이블이 지정된 다른 데이터 세트가 없지만 교차 검증을 수행하여 이를 시뮬레이션할 수 있습니다. 층화된 10겹 교차 검증이 분류 알고리즘에서 검정 오차를 추정하는 데 흔히 사용되는 방법입니다. 이 방법은 훈련 세트를 10개의 서로소 부분 집합으로 임의로 나눕니다. 각 부분 집합은 크기가 대략 같으며 그 클래스 비율도 훈련 세트에서의 클래스 비율과 대략 같습니다. 한 부분 집합을 제외하고 나머지 9개 부분 집합을 사용하여 분류 모델을 훈련시킨 후, 제외한 부분 집합을 훈련된 모델을 사용하여 분류합니다. 10개 부분 집합을 한 번에 하나씩 제외시키며 이 작업을 반복할 수 있습니다.

교차 검증이 데이터를 임의로 나누기 때문에 검증 결과는 초기 난수 시드값에 따라 달라집니다. 이 예제의 결과를 정확히 재현하기 위해 다음 명령을 실행합니다.

rng(0,'twister');

먼저, cvpartition을 사용하여 10개의 층화된 서로소 부분 집합을 생성합니다.

cp = cvpartition(species,'KFold',10)
cp = 
K-fold cross validation partition
   NumObservations: 150
       NumTestSets: 10
         TrainSize: 135  135  135  135  135  135  135  135  135  135
          TestSize: 15  15  15  15  15  15  15  15  15  15
          IsCustom: 0

crossval 메서드와 kfoldLoss 메서드는 주어진 데이터 분할 cp를 사용하여 LDA와 QDA 모두에 대한 오분류 오차를 추정할 수 있습니다.

층화된 10겹 교차 검증을 사용하여 LDA에 대한 실제 검정 오차를 추정합니다.

cvlda = crossval(lda,'CVPartition',cp);
ldaCVErr = kfoldLoss(cvlda)
ldaCVErr = 0.2000

이 데이터에서 LDA 교차 검증 오차는 LDA 재대입 오차와 동일한 값을 가집니다.

층화된 10겹 교차 검증을 사용하여 QDA에 대한 실제 검정 오차를 추정합니다.

cvqda = crossval(qda,'CVPartition',cp);
qdaCVErr = kfoldLoss(cvqda)
qdaCVErr = 0.2200

QDA가 LDA보다 교차 검증 오차가 약간 더 큽니다. 이는 복잡한 모델보다 단순한 모델에서 더 견줄 만한, 즉 더 나은 성능을 얻을 수 있음을 보여줍니다.

나이브 베이즈 분류기

fitcdiscr 함수에는 두 가지 다른 유형 'DiagLinear''DiagQuadratic'이 있습니다. 이는 'linear''quadratic'과 유사하지만 대각 공분산 행렬 추정값을 사용합니다. 이러한 대각 옵션은 클래스 레이블이 주어진 경우 변수가 조건부로 독립적임을 가정하므로 나이브 베이즈 분류기를 보여주는 구체적인 예가 됩니다. 나이브 베이즈 분류기는 가장 많이 사용되는 분류기 중 하나입니다. 일반적으로 두 변수가 클래스에 조건부로 독립적이라는 가정이 맞지는 않지만, 실제로 많은 데이터 세트에서 나이브 베이즈 분류기가 제대로 역할을 하는 것으로 확인되었습니다.

fitcnb 함수를 사용하면 더 일반적인 나이브 베이즈 분류기 유형을 생성할 수 있습니다.

먼저, 가우스 분포를 사용하여 각 클래스에 포함된 각각의 변수를 모델링합니다. 재대입 오차와 교차 검증 오차를 계산할 수 있습니다.

nbGau = fitcnb(meas(:,1:2), species);
nbGauResubErr = resubLoss(nbGau)
nbGauResubErr = 0.2200
nbGauCV = crossval(nbGau, 'CVPartition',cp);
nbGauCVErr = kfoldLoss(nbGauCV)
nbGauCVErr = 0.2200
labels = predict(nbGau, [x y]);
gscatter(x,y,labels,'grb','sod')

Figure contains an axes object. The axes object with xlabel x, ylabel y contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent versicolor, setosa, virginica.

지금까지는 각 클래스의 변수가 다변량 정규분포를 가진다고 가정했습니다. 이 가정이 타당한 경우가 많지만, 이렇게 가정하고 싶지 않거나 이러한 가정이 유효하지 않다는 것을 확실히 알 수 있는 경우가 있습니다. 이제, 더 유연한 비모수적 기법인 커널 밀도 추정을 사용하여 각 클래스에 포함된 각 변수를 모델링하겠습니다. 여기서는 커널을 box로 설정합니다.

nbKD = fitcnb(meas(:,1:2), species, 'DistributionNames','kernel', 'Kernel','box');
nbKDResubErr = resubLoss(nbKD)
nbKDResubErr = 0.2067
nbKDCV = crossval(nbKD, 'CVPartition',cp);
nbKDCVErr = kfoldLoss(nbKDCV)
nbKDCVErr = 0.2133
labels = predict(nbKD, [x y]);
gscatter(x,y,labels,'rgb','osd')

Figure contains an axes object. The axes object with xlabel x, ylabel y contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent setosa, versicolor, virginica.

이 데이터 세트에 대해 커널 밀도 추정을 사용하는 나이브 베이즈 분류기는 가우스 분포를 사용하는 나이브 베이즈 분류기보다 더 작은 재대입 오차와 교차 검증 오차를 가집니다.

결정 트리

결정 트리를 기반으로 하는 또 다른 분류 알고리즘이 있습니다. 결정 트리는 "꽃받침 길이가 5.45보다 작은 경우 이 표본을 setosa로 분류한다"와 같은 단순한 규칙 집합입니다. 또한, 결정 트리는 각 클래스의 변수 분포에 대한 어떠한 가정도 필요로 하지 않으므로 비모수적 기법입니다.

fitctree 함수는 결정 트리를 생성합니다. 붓꽃 데이터에 대한 결정 트리를 생성하고 결정 트리가 붓꽃을 종으로 얼마나 잘 분류하는지를 확인해 봅니다.

t = fitctree(meas(:,1:2), species,'PredictorNames',{'SL' 'SW' });

결정 트리 방법이 평면을 나누는 과정을 보면 흥미롭습니다. 위에 나와 있는 동일한 기법을 사용하여 각 종에 할당된 영역을 시각화합니다.

[grpname,node] = predict(t,[x y]);
gscatter(x,y,grpname,'grb','sod')

Figure contains an axes object. The axes object with xlabel x, ylabel y contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent versicolor, setosa, virginica.

결정 트리를 시각화하는 또 다른 방법은 결정 규칙과 클래스 할당을 보여주는 도식을 그리는 것입니다.

view(t,'Mode','graph');

Figure Classification tree viewer contains an axes object and other objects of type uimenu, uicontrol. The axes object contains 60 objects of type line, text. One or more of the lines displays its values using only markers

이 복잡해 보이는 트리는 "SL < 5.45" 형식의 일련의 규칙을 사용하여 각 표본을 19개 종점 노드 중 하나로 분류합니다. 하나의 관측값에 대한 종 할당을 확인하려면 최상위 노드에서 시작하여 규칙을 적용하십시오. 점이 규칙을 충족하면 왼쪽 경로를 취하고, 충족하지 않으면 오른쪽 경로를 취합니다. 최종적으로, 관측값을 3개 종 중 하나에 할당하는 종점 노드에 도달하게 됩니다.

결정 트리에 대한 재대입 오차와 교차 검증 오차를 계산합니다.

dtResubErr = resubLoss(t)
dtResubErr = 0.1333
cvt = crossval(t,'CVPartition',cp);
dtCVErr = kfoldLoss(cvt)
dtCVErr = 0.3000

결정 트리 알고리즘의 경우 교차 검증 오차 추정값이 재대입 오차 추정값보다 훨씬 더 큽니다. 이는 생성된 트리가 훈련 세트에 과적합되었다는 것을 보여줍니다. 다시 말해서 이는 원래 훈련 세트를 올바르게 분류하는 트리이지만, 트리의 구조가 이 특정 훈련 세트에 민감하므로, 새 데이터에 대한 성능이 저하될 가능성이 높습니다. 대개의 경우, 새 데이터에 대해 복잡한 트리보다 성능이 더 높은 단순한 트리를 찾을 수 있습니다.

트리를 가지치기해 봅니다. 먼저, 원래 트리의 다양한 부분 집합에 대한 재대입 오차를 계산합니다. 그런 다음, 이 하위 트리에 대한 교차 검증 오차를 계산합니다. 그래프를 보면 재대입 오차가 지나치게 낙관적임을 알 수 있습니다. 트리 크기가 증가하면 항상 재대입 오차가 감소하지만, 특정 점을 넘어가면 트리 크기가 증가함에 따라 교차 검증 오차율이 증가합니다.

resubcost = resubLoss(t,'Subtrees','all');
[cost,secost,ntermnodes,bestlevel] = cvloss(t,'Subtrees','all');
plot(ntermnodes,cost,'b-', ntermnodes,resubcost,'r--')
figure(gcf);
xlabel('Number of terminal nodes');
ylabel('Cost (misclassification error)')
legend('Cross-validation','Resubstitution')

Figure contains an axes object. The axes object with xlabel Number of terminal nodes, ylabel Cost (misclassification error) contains 2 objects of type line. These objects represent Cross-validation, Resubstitution.

어떤 트리를 선택해야 할까요? 간단한 규칙은 교차 검증 오차가 가장 작은 트리를 선택하는 것입니다. 이 규칙에 만족할 수도 있지만, 더 복잡한 트리와 거의 동일하게 양호한 더 단순한 트리를 사용하고자 할 수 있습니다. 이 예제에서는 최솟값의 1 표준 오차 내에서 가장 단순한 트리를 선택하겠습니다. 이것이 ClassificationTreecvloss 메서드에서 사용하는 디폴트 규칙입니다.

최소 비용과 1 표준 오차를 더한 값에 해당하는 절단 값을 계산하여 그래프에 이를 표시할 수 있습니다. cvloss 메서드에서 계산되는 "최상의" 수준은 이 절단 값 아래에 있는 최솟값 트리입니다. (참고로, bestlevel=0은 가지치기 안 된 트리에 해당하므로 cvloss에서 계산되는 벡터 출력값에 대한 인덱스로 이를 사용하려면 1을 더해야 합니다.)

[mincost,minloc] = min(cost);
cutoff = mincost + secost(minloc);
hold on
plot([0 20], [cutoff cutoff], 'k:')
plot(ntermnodes(bestlevel+1), cost(bestlevel+1), 'mo')
legend('Cross-validation','Resubstitution','Min + 1 std. err.','Best choice')
hold off

Figure contains an axes object. The axes object with xlabel Number of terminal nodes, ylabel Cost (misclassification error) contains 4 objects of type line. One or more of the lines displays its values using only markers These objects represent Cross-validation, Resubstitution, Min + 1 std. err., Best choice.

마지막으로, 가지치기된 트리를 살펴보고 이 트리에 대해 추정된 오분류 오차를 계산할 수 있습니다.

pt = prune(t,'Level',bestlevel);
view(pt,'Mode','graph')

Figure Classification tree viewer contains an axes object and other objects of type uimenu, uicontrol. The axes object contains 18 objects of type line, text. One or more of the lines displays its values using only markers

cost(bestlevel+1)
ans = 0.2467

결론

이 예제에서는 MATLAB®에서 Statistics and Machine Learning Toolbox™ 함수를 사용하여 분류를 수행하는 방법을 보여줍니다.

이 예제의 의도는 피셔의 붓꽃 데이터에 대해 최적의 분석을 수행하는 것이 아닙니다. 사실 꽃받침 측정값 대신에 꽃잎 측정값을 사용하거나 꽃받침 측정값을 꽃잎 측정값과 함께 사용하면 더 나은 분류 결과를 얻을 수 있습니다. 또한, 이 예제는 여러 분류 알고리즘의 장점과 단점을 비교하려는 것이 아닙니다. 다른 데이터 세트에 대한 분석을 수행하고 여러 알고리즘을 비교하는 것이 유용할 수 있습니다. 또한, 다른 분류 알고리즘을 구현하는 Toolbox 함수도 있습니다. 예를 들어, Bootstrap Aggregation (Bagging) of Classification Trees Using TreeBagger 예제에 설명된 것처럼 TreeBagger를 사용하여 결정 트리의 앙상블에 대해 부트스트랩 집계를 수행할 수 있습니다.