Main Content

게 분류

이 예제에서는 게의 형태를 측정한 크기를 기준으로 게 성별을 식별하기 위해 신경망을 분류기로 사용하는 방법을 보여줍니다.

문제: 게의 분류

이 예제에서는 게 형태에 대한 측정값으로부터 게 성별을 식별하는 분류기를 만들어 봅니다. 게 형태를 종, 앞쪽 입, 뒤 너비, 길이, 너비, 깊이의 6개 특성을 기준으로 고려합니다. 이 예제의 문제는 이들 6개 특성에 대한 관측값이 주어졌을 때 게의 성별을 식별하는 것입니다.

신경망을 사용해야 하는 이유

신경망은 능숙한 분류기로 기능한다는 것이 입증되었으며, 특히 비선형 문제를 해결하는 데 적합합니다. 게의 분류와 같은 실세계의 현상은 비선형적 특성을 가지므로 신경망은 이러한 문제를 해결하기에 적합합니다.

6개의 형태 특성이 신경망의 입력값으로 사용되고 게 성별이 목표값이 됩니다. 이 신경망은 게 형태에 대한 6개 관측값이 입력값으로 주어졌을 때 해당 게가 수컷인지 암컷인지 식별할 수 있어야 합니다.

이 목표를 달성하기 위해 이전에 기록된 입력값을 신경망에 제공한 다음 원하는 목표 출력값이 생성되도록 신경망을 조정합니다. 이 과정을 신경망 훈련이라고 합니다.

데이터 준비하기

분류 문제를 위한 데이터를 입력 행렬 X와 목표 행렬 T의 두 개 행렬로 구성하여 신경망에 사용할 데이터를 준비합니다.

입력 행렬의 각 i번째 열은 게의 종, 앞쪽 입, 뒤 너비, 길이, 너비, 깊이를 나타내는 6개의 요소를 갖습니다.

이에 대응하는 목표 행렬의 각 열은 2개의 요소를 갖습니다. 게가 암컷이면 첫 번째 요소의 값이 1이고 수컷이면 두 번째 요소의 값이 1입니다. (다른 모든 요소는 0입니다.)

다음과 같이 데이터셋을 불러옵니다.

[x,t] = crab_dataset;
size(x)
ans = 1×2

     6   200

size(t)
ans = 1×2

     2   200

신경망 분류기 구축하기

다음 단계는 게 성별에 대한 식별을 학습하는 신경망을 만드는 것입니다.

신경망은 임의의 초기 가중치로 시작하므로 이 예제의 결과는 매 실행마다 약간씩 달라집니다. 이러한 임의성을 방지하기 위해 난수 시드값이 설정됩니다. 그러나 자신만의 고유한 애플리케이션인 경우에는 난수 시드값을 설정할 필요가 없습니다.

setdemorandstream(491218382)

2계층(즉, 은닉 계층 1개) 피드포워드 신경망은 은닉 계층에 뉴런이 충분히 주어진다면 어떠한 입력-출력 관계도 학습할 수 있습니다. 출력 계층이 아닌 계층은 은닉 계층이라고 부릅니다.

이 예제에서는 뉴런 10개로 구성된 하나의 은닉 계층을 사용해 보겠습니다. 일반적으로 문제가 어려울수록 더 많은 뉴런과, 그리고 경우에 따라 더 많은 계층이 필요합니다. 문제가 단순할수록 필요한 뉴런의 개수도 적어집니다.

신경망이 아직 입력 및 목표 데이터에 맞춰 구성되지 않았으므로 입력값과 출력값의 크기는 0입니다. 이 구성 작업은 신경망을 훈련시킬 때 수행됩니다.

net = patternnet(10);
view(net)

이제 신경망을 훈련시킬 준비가 되었습니다. 샘플은 자동으로 훈련 세트, 검증 세트, 테스트 세트로 나뉩니다. 훈련 세트는 신경망을 가르치는 데 사용됩니다. 훈련은 신경망이 검증 세트에 대해 계속해서 향상되는 한 지속됩니다. 테스트 세트는 신경망의 정확도를 가늠하는 완전히 독립적인 척도를 제공합니다.

[net,tr] = train(net,x,t);

Figure Neural Network Training (25-Jan-2024 15:37:31) contains an object of type uigridlayout.

훈련하는 동안 신경망의 성능이 얼마나 향상되는지 보려면 "Performance" 버튼을 클릭하거나 PLOTPERFORM을 호출하십시오.

성능은 평균제곱오차로 측정되어 로그 스케일로 표시됩니다. 신경망이 훈련됨에 따라 오차가 빠르게 감소했습니다.

훈련 세트, 검증 세트, 테스트 세트 각각에 대해 성능이 표시됩니다.

plotperform(tr)

Figure Performance (plotperform) contains an axes object. The axes object with title Best Validation Performance is 0.023041 at epoch 21, xlabel 27 Epochs, ylabel Cross Entropy (crossentropy) contains 6 objects of type line. One or more of the lines displays its values using only markers These objects represent Train, Validation, Test, Best.

분류기 테스트하기

이제 훈련된 신경망을 테스트 샘플을 사용하여 테스트할 수 있습니다. 이렇게 하면 실세계의 데이터에 적용했을 때 신경망이 얼마나 잘 동작할지 가늠할 수 있습니다.

신경망 출력값은 0과 1 사이의 범위에 있을 것이므로 vec2ind 함수를 사용하여 각 출력 벡터에서 가장 큰 요소의 위치를 클래스 인덱스로 가져올 수 있습니다.

testX = x(:,tr.testInd);
testT = t(:,tr.testInd);

testY = net(testX);
testIndices = vec2ind(testY)
testIndices = 1×30

     2     2     2     1     2     2     2     1     2     2     2     2     1     1     2     2     2     1     2     2     1     2     1     1     1     1     1     2     2     1

신경망이 데이터를 얼마나 잘 피팅했는지 알 수 있는 한 척도는 정오 플롯입니다. 여기서는 모든 샘플에 대해 혼동행렬이 플로팅되었습니다.

혼동행렬은 올바른 분류와 잘못된 분류의 비율을 보여줍니다. 올바른 분류는 행렬 대각선 위에 녹색 정사각형으로 표시됩니다. 잘못된 분류는 빨간색 정사각형으로 표시됩니다.

신경망이 올바르게 분류하도록 학습했으면 빨간색 정사각형의 비율은 매우 작아서 오분류가 적었음을 나타낼 것입니다.

그렇지 않으면 훈련을 추가로 실시하거나 더 많은 은닉 뉴런을 사용하여 신경망을 훈련시키는 것이 좋습니다.

plotconfusion(testT,testY)

Figure Confusion (plotconfusion) contains an axes object. The axes object with title Confusion Matrix, xlabel Target Class, ylabel Output Class contains 29 objects of type patch, text, line.

올바른 분류와 잘못된 분류의 비율은 다음과 같습니다.

[c,cm] = confusion(testT,testY)
c = 0.0333
cm = 2×2

    12     1
     0    17

fprintf('Percentage Correct Classification   : %f%%\n', 100*(1-c));
Percentage Correct Classification   : 96.666667%
fprintf('Percentage Incorrect Classification : %f%%\n', 100*c);
Percentage Incorrect Classification : 3.333333%

신경망이 데이터를 얼마나 잘 피팅했는지 알 수 있는 또 다른 척도는 ROC(수신자 조작 특성) 플롯입니다. ROC 플롯은 출력값의 임계값이 0에서 1까지 변함에 따라 거짓양성률과 참양성률의 관계가 어떻게 되는지 보여줍니다.

선이 왼쪽 위에 가까울수록 높은 참양성률을 얻기 위해 허용해야 하는 거짓양성의 개수가 줄어듭니다. 가장 좋은 분류기는 이 선이 왼쪽 아래 코너에서 왼쪽 위 코너 또는 오른쪽 위 코너로 향하거나 그에 가까운 형태로 나타납니다.

plotroc(testT,testY)

Figure Receiver Operating Characteristic (plotroc) contains an axes object. The axes object with title ROC, xlabel False Positive Rate, ylabel True Positive Rate contains 4 objects of type line. These objects represent Class 1, Class 2.

이 예제에서는 신경망을 사용하여 게를 분류하는 방법을 살펴보았습니다.

신경망과 그 응용 분야에 대해 더 알아보려면 다른 예제와 문서를 살펴보십시오.