How to check error/accuracy of K-means clustering on new dataset
Hi, I am using the following example.
I want to find the test error/score on data predicted using k-means clustering. How can I find that?
The example below assigns new data to existing k-means clusters. I want to check how accurately the data belong to their clusters.
rng('default') % For reproducibility
X = [randn(100,2)*0.75+ones(100,2);
randn(100,2)*0.5-ones(100,2);
randn(100,2)*0.75];
[idx,C] = kmeans(X,3);
figure
gscatter(X(:,1),X(:,2),idx,'bgm')
hold on
plot(C(:,1),C(:,2),'kx')
legend('Cluster 1','Cluster 2','Cluster 3','Cluster Centroid')
%Assign New Data to Existing Clusters
Xtest = [randn(10,2)*0.75+ones(10,2);
randn(10,2)*0.5-ones(10,2);
randn(10,2)*0.75];
[~,idx_test] = pdist2(C,Xtest,'euclidean','Smallest',1);
gscatter(Xtest(:,1),Xtest(:,2),idx_test,'bgm','ooo')
legend('Cluster 1','Cluster 2','Cluster 3','Cluster Centroid', ...
'Data classified to Cluster 1','Data classified to Cluster 2', ...
'Data classified to Cluster 3')
0 Comments
Accepted Answer
Adam Danz
27 Dec 2021
Edited: Adam Danz, 27 Dec 2021
> I want to check how accurately the data belong to their clusters.
Part 1. kmeans clustering is not a classifier...
...so the algorithm is agnostic to any a priori group identity. Instead, kmeans clustering minimizes the sum of point-to-centroid distances, summed over all k clusters (see the kmeans documentation). This confounds the notion of accuracy that is typically applied to classifiers.
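For reference, with the default squared Euclidean distance, that objective can be written as

$$\sum_{j=1}^{k} \sum_{x \in C_j} \lVert x - \mu_j \rVert^2 ,$$

where $C_j$ is the set of points assigned to cluster $j$ and $\mu_j$ is its centroid.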
If you'd like to apply a classifier instead of kmeans clustering, start by perusing MATLAB's documentation on classification.
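For example, here is a minimal sketch of a supervised alternative using fitcknn; it assumes you have labeled data, and the variable names simply mirror the demo data used later in this answer:
% Sketch of a supervised alternative: train a k-nearest-neighbor
% classifier on labeled data and measure test error directly.
rng('default') % for reproducibility of demo
Xtrain = [randn(100,2)*0.75+ones(100,2);
randn(100,2)*0.5-ones(100,2);
randn(100,2)*0.75];
trainID = repelem(1:3,1,100)'; % known labels
mdl = fitcknn(Xtrain, trainID);
Xtest = [randn(10,2)*0.75+ones(10,2);
randn(10,2)*0.5-ones(10,2);
randn(10,2)*0.75];
testID = repelem(1:3,1,10)';
testError = loss(mdl, Xtest, testID) % misclassification rate on the test set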
Part 2. measuring how well data are clustered
This question differs from the more frequently asked question of how to choose the number of clusters, which is addressed by the elbow method, the silhouette method, or the gap statistic (I'll let readers look them up; a quick sketch of the silhouette criterion follows the raw-data plot below). Instead, this question assumes that we know the true group each data point belongs to, and we want to know how well each group is separated by kmeans.
Your demo data come from 3 distributions, and each one strongly overlaps with at least one other group.
rng('default') % for reproducibility of demo
X = [randn(100,2)*0.75+ones(100,2);
randn(100,2)*0.5-ones(100,2);
randn(100,2)*0.75];
groupID = repelem(1:3,1,100)'; % known group ID for each point
figure()
gscatter(X(:,1), X(:,2), groupID, 'bgm', 'o')
title('Raw data')
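As promised above, a minimal sketch of choosing the number of clusters with evalclusters and the silhouette criterion; the candidate range 2:6 is an arbitrary choice for this demo:
% Sketch: evaluate candidate cluster counts by mean silhouette value.
eva = evalclusters(X,'kmeans','silhouette','KList',2:6);
bestK = eva.OptimalK % k with the highest mean silhouette value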
You could compute the percentage of points within each cluster that belong to the dominant group of that cluster.
k = 3;
[idx,C] = kmeans(X,k);
% Plot clusters
figure();
gscatter(X(:,1), X(:,2), idx, 'bgm', 'x')
title('Clustered data')
T = array2table(zeros(k,3),'VariableNames',{'cluster','dominantGroup','percentSame'});
for i = 1:k
counts = histcounts(groupID(idx==i),'BinMethod','integers','BinLimits',[1,k]);
[maxCount, maxIdx] = max(counts);
T.cluster(i) = i;
T.dominantGroup(i) = maxIdx;
T.percentSame(i) = maxCount/sum(counts);
end
disp(T)
For example, in the table above you can see that 89.87% of data in cluster 1 are from group 1, 70.37% of data in cluster 2 are from group 3, and 85.84% of data in cluster 3 are from group 2.
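The same accounting can also be viewed as a confusion matrix between the known groups and a relabeling of the clusters by their dominant group; a sketch using confusionmat with the T and idx variables from above:
% Sketch: relabel each point's cluster by that cluster's dominant group,
% then tabulate against the known group IDs.
predictedGroup = T.dominantGroup(idx);
CM = confusionmat(groupID, predictedGroup)
overallAgreement = sum(diag(CM))/numel(groupID) % fraction matching the dominant group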
Part 3. measuring how well new data belong to the clusters
To reiterate part 1, kmeans is not a classifier so it doesn't care about a priori group assignments. Assuming you know the groups and have already computed the cluster centroid locations, you could determine which centroid is closest to a new data point and whether that cluster is dominated by training data points from the same group.
% Create 3 groups of data
rng('default') % for reproducibility of demo
g1 = randn(150,2)*0.75+ones(150,2);
g2 = randn(150,2)*0.5-ones(150,2);
g3 = randn(150,2)*0.75;
% Separate the data into 'training' set to
% get the cluster centroids and a 'test' set.
trainset = [g1(1:100,:); g2(1:100,:); g3(1:100,:)];
trainID = repelem(1:3,1,100)';
testset = [g1(101:end,:); g2(101:end,:); g3(101:end,:)];
testID = repelem(1:3,1,50)';
% Compute centroid locations of training set
k = 3;
[idx,C] = kmeans(trainset,k);
% Determine which group dominates each cluster
% These lines do the same thing as the loop in the previous section.
% The dominant group for cluster n is dominantGroup(n)
grpCounts = arrayfun(@(i){histcounts(trainID(idx==i),'BinMethod','integers','BinLimits',[1,k])},1:k);
[~, dominantGroup] = cellfun(@(c)max(c),grpCounts)
% Assign each test coordinate to a cluster using the same
% distance metric used by kmeans (sqeuclidean by default)
[~,cluster] = pdist2(C,testset,'squaredeuclidean','Smallest',1);
% Convert cluster IDs to training group IDs
trainGroup = dominantGroup(cluster);
% Compute the percentage of test points in each group that match
% the dominant training group within the same cluster
percentSame = arrayfun(@(i)mean(trainGroup(testID==i)==i),1:k)
The output above shows that 56% of the test data from groupID 1 were closest to the centroid that was dominated by groupID 1 training data.
The figure below depicts this visually by plotting the training data grouped by the 3 clusters (green, cyan, red) together with the test data belonging to group 1 (black squares). The dominantGroup variable shows that group 1 is primarily in cluster 2 (cyan), but percentSame shows that only 56% of the group 1 test data are in cluster 2, which appears reasonable in the figure.
figure()
gscatter(trainset(:,1), trainset(:,2), idx, 'gcr')
legend(compose('Cluster %d',1:k))
hold on
testIdx = testID==1; % just plot test group 1
plot(testset(testIdx,1), testset(testIdx,2), 'ks', ...
'LineWidth',2,'MarkerSize',8,'DisplayName', 'TestGroup 1')
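As a label-free complement, silhouette values measure how snugly each test point sits in its assigned cluster without using the group IDs at all. A sketch using the testset and cluster variables from above; values near 1 indicate a good fit, values near -1 a poor one:
% Sketch: silhouette values for the test points under their assigned clusters.
figure()
s = silhouette(testset, cluster'); % transpose: silhouette expects a column vector
meanSilhouette = mean(s)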
3 Comments
Adam Danz
28 Dec 2021
@hammad younas, @yanqi liu created a confusion matrix. There are 3 groups, so with the additional row and column showing correct and incorrect classification rates the matrix is 4x4 (recall that kmeans is a clustering method, not a classifier).
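For readers who want to build that kind of summary themselves, a minimal sketch using confusionchart (R2018b or later) with the test-set variables from the accepted answer; the row and column summaries add the rate row/column mentioned above:
% Sketch: confusion chart comparing true test groups to the
% dominant-group relabeling of their assigned clusters.
figure()
cm = confusionchart(testID, trainGroup');
cm.RowSummary = 'row-normalized'; % per-true-group rates
cm.ColumnSummary = 'column-normalized'; % per-predicted-group rates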