5-fold cross-validation code for a dataset

Subhasmita Ghosh, 24 May 2018
Answered: Shubham, 4 September 2024
I want to split my dataset into a training set and a test set using 5-fold cross-validation.

Answers (1)

Shubham, 4 September 2024
Hi Subhasmita,
In MATLAB, you can perform k-fold cross-validation to split your dataset into training and test sets. In k-fold cross-validation, the dataset is divided into k subsets (folds). The model is trained k times, each time using a different fold as the test set and the remaining folds as the training set.
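The fold idea itself can be sketched in base MATLAB without any toolbox: shuffle the row indices once, deal them into k groups, and let each group serve as the test set exactly once. This is a minimal illustration under assumptions, not how cvpartition works internally; the dataset size n = 150 is an assumed value (it matches fisheriris), and foldId is a hypothetical name.

```matlab
% Minimal k-fold index assignment in base MATLAB (illustrative sketch).
n = 150;                               % number of observations (assumed for illustration)
k = 5;                                 % number of folds
rng(0);                                % fix the seed so the split is reproducible
perm = randperm(n);                    % shuffle the observation indices
foldId = zeros(n, 1);                  % hypothetical fold label for each observation
foldId(perm) = mod(0:n-1, k) + 1;      % deal shuffled indices round-robin into folds 1..k

for i = 1:k
    testIdx  = (foldId == i);          % fold i is the test set
    trainIdx = ~testIdx;               % the other k-1 folds form the training set
    fprintf('Fold %d: %d train, %d test\n', i, nnz(trainIdx), nnz(testIdx));
end
```

With n = 150 and k = 5, every line reports 120 training and 30 test observations, and each observation appears in exactly one test fold.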
Here's how you can perform 5-fold cross-validation in MATLAB:
% Load your dataset
load fisheriris % Example dataset
% fitcsvm supports only one- or two-class problems, so keep two of the three species
isTwoClass = ~strcmp(species, 'setosa');
X = meas(isTwoClass, :); % Features
y = species(isTwoClass); % Labels (cell array of character vectors)
% Define the number of folds
k = 5;
% Create a cross-validation partition (passing the labels makes the folds stratified)
cv = cvpartition(y, 'KFold', k);
% Initialize variable to store accuracy for each fold
accuracy = zeros(k, 1);
for i = 1:k
    % Get the training and test indices for the current fold
    trainIdx = training(cv, i);
    testIdx = test(cv, i);
    % Split the data into training and test sets for this fold
    XTrain = X(trainIdx, :);
    yTrain = y(trainIdx);
    XTest = X(testIdx, :);
    yTest = y(testIdx);
    % Train the model on the training set
    model = fitcsvm(XTrain, yTrain);
    % Test the model on the test set
    predictions = predict(model, XTest);
    % Calculate accuracy for the current fold (strcmp, because == is not defined for cell arrays)
    accuracy(i) = mean(strcmp(predictions, yTest));
    % Display accuracy for the current fold
    fprintf('Fold %d Accuracy: %.2f%%\n', i, accuracy(i) * 100);
end
% Calculate and display the average accuracy across all folds
averageAccuracy = mean(accuracy);
fprintf('Average Accuracy: %.2f%%\n', averageAccuracy * 100);
Explanation
  1. Data Loading: We use the fisheriris dataset for demonstration, where X contains the features and y contains the labels.
  2. Cross-Validation Partition: We create a 5-fold partition using cvpartition with the option 'KFold', k.
  3. Loop Through Folds: For each fold, we:
  • Extract training and test indices.
  • Split the data into training and test sets.
  • Train a support vector machine (SVM) model using fitcsvm.
  • Predict on the test set and calculate accuracy.
  4. Accuracy Calculation: We calculate and print the accuracy for each fold and the average accuracy across all folds.
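Passing the label vector to cvpartition, as in step 2, requests stratified folds. One way to check this is to count how many observations of each class land in each test fold. The sketch below does that for the full three-species fisheriris data; the seed and the names counts, classNames and classIdx are arbitrary choices for illustration. Because every observation is in exactly one test fold, each column sums to that species' size (50).

```matlab
% Count class membership per test fold to inspect stratification (illustrative sketch).
load fisheriris
k = 5;
rng(1);                                         % reproducible partition
cv = cvpartition(species, 'KFold', k);          % stratified by species
[classNames, ~, classIdx] = unique(species);    % map labels to integers 1..3
counts = zeros(k, numel(classNames));           % rows = folds, columns = classes
for i = 1:k
    % Tally the class indices of the observations in test fold i
    counts(i, :) = accumarray(classIdx(test(cv, i)), 1, [numel(classNames) 1])';
end
disp(array2table(counts, 'VariableNames', classNames'));
```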
Additional Notes
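  • You can replace fitcsvm with any other classifier that suits your needs; for more than two classes, fitcecoc combines binary SVM learners into a multiclass model.
  • Check class balance: passing the labels to cvpartition gives stratified folds, but very small classes can still leave some folds with few examples of a class.
  • You might also want to explore MATLAB's crossval function, which automates the partition, train, and evaluate loop.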
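As a sketch of the crossval route: fitting a model once and calling its crossval method with a partition trains one model per fold, and kfoldLoss returns the misclassification rate. Here fitcecoc (whose default learners are binary SVMs) is swapped in so that all three species can be used; the seed and variable names are arbitrary, and this is one possible setup rather than the only one.

```matlab
% Cross-validate a multiclass SVM model with crossval and kfoldLoss (illustrative sketch).
load fisheriris
rng(2);                                           % reproducible partition
cv = cvpartition(species, 'KFold', 5);            % stratified 5-fold partition
mdl = fitcecoc(meas, species);                    % multiclass model from binary SVM learners
cvModel = crossval(mdl, 'CVPartition', cv);       % retrains one model per fold
foldErr = kfoldLoss(cvModel, 'Mode', 'individual'); % misclassification rate per fold
for i = 1:numel(foldErr)
    fprintf('Fold %d Accuracy: %.2f%%\n', i, (1 - foldErr(i)) * 100);
end
avgAcc = 1 - kfoldLoss(cvModel);                  % overall out-of-fold accuracy
fprintf('Average Accuracy: %.2f%%\n', avgAcc * 100);
```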
