Group K-fold partitioning a dataset

Ivan Abraham, 31 July 2018
Answered: Jaimin, 9 January 2025
The scikit-learn package in Python provides GroupKFold, which splits a dataset into train/test folds while ensuring that the same "group" is not present in different folds. This is useful, for example, in studies where the same subject generates multiple data points and we want to make sure that the samples belonging to one subject don't appear in both the training and testing folds.
I was wondering whether MATLAB has a way to do this, either through an option in the cvpartition function or in some other way. The default options only seem to preserve relative class sizes.

Answers (1)

Jaimin, 9 January 2025
While MATLAB does not offer a built-in function exactly like scikit-learn's GroupKFold, you can achieve similar results by manually creating your own group-based cross-validation partitions.
Here is how you can do it:
  1. Determine the unique groups in your dataset.
  2. Randomly shuffle these groups and then split them into k folds.
  3. Assign each data point to a fold based on its group.
% Sample data
data = rand(100, 5);              % 100 samples, 5 features
labels = randi([0, 1], 100, 1);   % Binary labels
groups = randi([1, 20], 100, 1);  % Group ID per sample (up to 20 distinct groups)

% Number of folds
k = 5;

% Get unique groups
uniqueGroups = unique(groups);
numGroups = numel(uniqueGroups);
assert(numGroups >= k, 'Need at least as many groups as folds.');

% Shuffle groups
shuffledGroups = uniqueGroups(randperm(numGroups));

% Split groups into k folds. Dealing the shuffled groups out round-robin
% keeps every fold non-empty; fixed-size chunks of ceil(numGroups/k) can
% leave the last folds empty (e.g. 6 groups with k = 4).
foldOfGroup = mod((0:numGroups-1)', k) + 1;
folds = cell(k, 1);
for i = 1:k
    folds{i} = shuffledGroups(foldOfGroup == i);
end

% Create cross-validation partitions: label each sample with its group's fold
cvIndices = zeros(size(groups));
for i = 1:k
    testGroups = folds{i};
    testIdx = ismember(groups, testGroups);
    cvIndices(testIdx) = i;
end

% Use each fold once as the test set and the remaining folds for training
for i = 1:k
    testIdx = (cvIndices == i);
    trainIdx = ~testIdx;
    trainData = data(trainIdx, :);
    trainLabels = labels(trainIdx);
    testData = data(testIdx, :);
    testLabels = labels(testIdx);
    fprintf('Fold %d: Train on %d samples, Test on %d samples\n', i, sum(trainIdx), sum(testIdx));
end
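Because groups can differ a lot in size, a random split of group IDs may produce folds with quite uneven sample counts. One possible variation, sketched below, places the largest groups first and always adds the next group to the fold that currently holds the fewest samples, which is similar in spirit to the balancing that scikit-learn's GroupKFold performs. The example data, the fixed seed, and the variable names (groupCounts, foldLoad, groupFold) are assumptions made for illustration, and unlike the code above this assignment is deterministic for a given dataset. The loop at the end checks that no group ends up on both sides of a split.

```matlab
% Hypothetical example data: 100 samples drawn from up to 20 groups
rng(0);                              % fix the seed so the data is reproducible
groups = randi([1, 20], 100, 1);
k = 5;

% Count the samples in each group
[uniqueGroups, ~, groupIdx] = unique(groups);
groupCounts = accumarray(groupIdx, 1);

% Visit groups from largest to smallest
[sortedCounts, order] = sort(groupCounts, 'descend');

% Greedily place each group in the fold that currently has the fewest samples
foldLoad = zeros(k, 1);
groupFold = zeros(numel(uniqueGroups), 1);
for g = 1:numel(order)
    [~, lightest] = min(foldLoad);
    groupFold(order(g)) = lightest;
    foldLoad(lightest) = foldLoad(lightest) + sortedCounts(g);
end

% Map the per-group fold assignment back to individual samples
cvIndices = groupFold(groupIdx);

% Sanity check: no group may appear on both sides of any split
for i = 1:k
    testIdx = (cvIndices == i);
    assert(isempty(intersect(groups(testIdx), groups(~testIdx))), ...
        'Group leakage detected in fold %d', i);
    fprintf('Fold %d: %d test samples from %d groups\n', ...
        i, sum(testIdx), numel(unique(groups(testIdx))));
end
```

The resulting cvIndices vector has the same meaning as in the code above, so the same train/test loop can be reused with it unchanged.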
For more information, kindly refer to the following MathWorks documentation.
