SD from k-fold validation classification learner
I am dealing with a reviewer who is asking for the standard deviation of the k-fold cross-validation results for an LDA model trained in the Classification Learner app.
How can I obtain it?
Answers (1)
Drew
on 14 Jun 2024
Edited: Drew on 17 Jun 2024
The Classification Learner app does not currently provide access to the per-fold validation results, so the best way to get the standard deviation of accuracy or error rate across folds is to perform cross-validation again. This can be done by generating code from the Classification Learner app and then adjusting that code to calculate the standard deviation across folds. In simple cases that don't use PCA or feature selection, it can also be done easily by exporting the final model and then performing cross-validation with a few lines of code.
Without PCA or Feature Selection:
(1) In the simple case of a model trained in Classification Learner that does not use PCA or feature selection, here is an example of redoing cross-validation in code. After training the Linear Discriminant model and exporting it from the Classification Learner app as "trainedModel" (with the training data included in the model), you can train cross-validation models, get the error rate of each fold, and calculate the standard deviation of the per-fold error with a few lines of code as follows:
% Train cross-validation models based on the exported model
CVMdl = crossval(trainedModel.ClassificationDiscriminant, 'KFold', 5);
% Get error rate per fold. Mode='individual' provides per-fold error rate
errorRatePerFold = kfoldLoss(CVMdl, Mode='individual', LossFun='classiferror')
% Get standard deviation of error rate per fold
standardDeviationOfErrorPerFold = std(errorRatePerFold)
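If the reviewer wants the standard deviation of the per-fold accuracy rather than the error rate, note that accuracy is just one minus the error rate, so the spread (and therefore the standard deviation) is identical for both. For example:
% Per-fold accuracy is the complement of per-fold error rate; subtracting
% from 1 only shifts and flips the values, leaving the spread unchanged,
% so std(accuracyPerFold) equals std(errorRatePerFold)
accuracyPerFold = 1 - errorRatePerFold;
standardDeviationOfAccuracyPerFold = std(accuracyPerFold)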
Running this code multiple times will generally give different answers, due to the use of a different cross-validation partition of the data on each run. To get repeatable answers, set the random number seed first, for example:
rng("default")
With PCA and/or Feature Selection:
(2) In the case of a model trained in Classification Learner that uses PCA and/or feature selection, the PCA matrix and/or the selected features can be different for each cross-validation fold. The cross-validation code generated by the Classification Learner app takes all of this into account. So, after training the Linear Discriminant model in the Classification Learner app, choose the "Generate Function" option. This generates code that retrains the final model on all of the training data and also trains the cross-validation models (one per fold), where each fold can have its own feature selection and PCA. That code can be slightly modified to obtain the error rate from each fold.
To illustrate, the following is slightly modified code generated from Classification Learner. For this illustration, an LDA model was trained with feature selection set to choose the top 3 features according to the MRMR criterion, followed by PCA keeping as many components as needed to explain at least 95 percent of the variance. These are not recommended settings for best accuracy; they were chosen purely to illustrate the resulting cross-validation code.
The modifications to the generated code consist of just the following small changes:
(1) On the first line, added another output, "errorPerFold", and changed the function name to make it more specific.
(2) Added the initialization "errorPerFold = NaN(KFolds,1);". This is not strictly necessary, but it avoids the warning about the array growing on each iteration of the loop.
(3) At the bottom of the loop over the folds, added lines to calculate the errorPerFold:
% Store per-fold error
temp = strcmp( strtrim(validationPredictions(cvp.test(fold),:)), strtrim(response(cvp.test(fold))));
errorPerFold(fold) = 1 - sum(temp)/length(temp);
Here is the generated code with those few modifications:
function [trainedClassifier, validationAccuracy, errorPerFold] = trainClassifier_LDA_withFSandPCA(trainingData)
% [trainedClassifier, validationAccuracy, errorPerFold] = trainClassifier_LDA_withFSandPCA(trainingData)
% Returns a trained classifier and its accuracy. This code recreates the
% classification model trained in Classification Learner app. Use the
% generated code to automate training the same model with new data, or to
% learn how to programmatically train models.
%
% Input:
% trainingData: A table containing the same predictor and response
% columns as those imported into the app.
%
%
% Output:
% trainedClassifier: A struct containing the trained classifier. The
% struct contains various fields with information about the trained
% classifier.
%
% trainedClassifier.predictFcn: A function to make predictions on new
% data.
%
% validationAccuracy: A double representing the validation accuracy as
% a fraction between 0 and 1. In the app, the Models pane displays the
% validation accuracy for each model as a percentage.
%
% errorPerFold: A KFolds-by-1 vector of per-fold validation error rates
% (the output added by the modifications described above).
%
% Use the code to train the model with new data. To retrain your
% classifier, call the function from the command line with your original
% data or new data as the input argument trainingData.
%
% For example, to retrain a classifier trained with the original data set
% T, enter:
% [trainedClassifier, validationAccuracy] = trainClassifier(T)
%
% To make predictions with the returned 'trainedClassifier' on new data T2,
% use
% [yfit,scores] = trainedClassifier.predictFcn(T2)
%
% T2 must be a table containing at least the same predictor columns as used
% during training. For details, enter:
% trainedClassifier.HowToPredict
% Auto-generated by MATLAB on 17-Jun-2024 11:10:35
% Extract predictors and response
% This code processes the data into the right shape for training the
% model.
inputTable = trainingData;
predictorNames = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};
predictors = inputTable(:, predictorNames);
response = inputTable.Species;
isCategoricalPredictor = [false, false, false, false];
classNames = {'setosa'; 'versicolor'; 'virginica'};
% Feature Ranking and Selection
% Replace Inf/-Inf values with NaN to prepare data for normalization
predictors = standardizeMissing(predictors, {Inf, -Inf});
% Normalize data for feature ranking
predictorMatrix = normalize(predictors, "DataVariables", ~isCategoricalPredictor);
% Rank features using MRMR algorithm
featureIndex = fscmrmr(...
    predictorMatrix, ...
    response);
numFeaturesToKeep = 3;
includedPredictorNames = predictors.Properties.VariableNames(featureIndex(1:numFeaturesToKeep));
predictors = predictors(:,includedPredictorNames);
isCategoricalPredictor = isCategoricalPredictor(featureIndex(1:numFeaturesToKeep));
% Apply a PCA to the predictor matrix.
% Run PCA on numeric predictors only. Categorical predictors are passed through PCA untouched.
isCategoricalPredictorBeforePCA = isCategoricalPredictor;
numericPredictors = predictors(:, ~isCategoricalPredictor);
numericPredictors = table2array(varfun(@double, numericPredictors));
% 'inf' values have to be treated as missing data for PCA.
numericPredictors(isinf(numericPredictors)) = NaN;
[pcaCoefficients, pcaScores, ~, ~, explained, pcaCenters] = pca(...
    numericPredictors);
% Keep enough components to explain the desired amount of variance.
explainedVarianceToKeepAsFraction = 95/100;
numComponentsToKeep = find(cumsum(explained)/sum(explained) >= explainedVarianceToKeepAsFraction, 1);
pcaCoefficients = pcaCoefficients(:,1:numComponentsToKeep);
predictors = [array2table(pcaScores(:,1:numComponentsToKeep)), predictors(:, isCategoricalPredictor)];
isCategoricalPredictor = [false(1,numComponentsToKeep), true(1,sum(isCategoricalPredictor))];
% Train a classifier
% This code specifies all the classifier options and trains the classifier.
classificationDiscriminant = fitcdiscr(...
    predictors, ...
    response, ...
    'DiscrimType', 'linear', ...
    'Gamma', 0, ...
    'FillCoeffs', 'off', ...
    'ClassNames', classNames);
% Create the result struct with predict function
predictorExtractionFcn = @(t) t(:, predictorNames);
featureSelectionFcn = @(x) x(:,includedPredictorNames);
pcaTransformationFcn = @(x) [ array2table((table2array(varfun(@double, x(:, ~isCategoricalPredictorBeforePCA))) - pcaCenters) * pcaCoefficients), x(:,isCategoricalPredictorBeforePCA) ];
discriminantPredictFcn = @(x) predict(classificationDiscriminant, x);
trainedClassifier.predictFcn = @(x) discriminantPredictFcn(pcaTransformationFcn(featureSelectionFcn(predictorExtractionFcn(x))));
% Add additional fields to the result struct
trainedClassifier.RequiredVariables = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};
trainedClassifier.PCACenters = pcaCenters;
trainedClassifier.PCACoefficients = pcaCoefficients;
trainedClassifier.ClassificationDiscriminant = classificationDiscriminant;
trainedClassifier.About = 'This struct is a trained model exported from Classification Learner R2024b.';
trainedClassifier.HowToPredict = sprintf('To make predictions on a new table, T, use: \n [yfit,scores] = c.predictFcn(T) \nreplacing ''c'' with the name of the variable that is this struct, e.g. ''trainedModel''. \n \nThe table, T, must contain the variables returned by: \n c.RequiredVariables \nVariable formats (e.g. matrix/vector, datatype) must match the original training data. \nAdditional variables are ignored. \n \nFor more information, see <a href="matlab:helpview(fullfile(docroot, ''stats'', ''stats.map''), ''appclassification_exportmodeltoworkspace'')">How to predict using an exported model</a>.');
% Extract predictors and response
% This code processes the data into the right shape for training the
% model.
inputTable = trainingData;
predictorNames = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};
predictors = inputTable(:, predictorNames);
response = inputTable.Species;
isCategoricalPredictor = [false, false, false, false];
classNames = {'setosa'; 'versicolor'; 'virginica'};
% Perform cross-validation
KFolds = 5;
cvp = cvpartition(response, 'KFold', KFolds);
% Initialize the predictions to the proper sizes
validationPredictions = response;
numObservations = size(predictors, 1);
numClasses = 3;
validationScores = NaN(numObservations, numClasses);
errorPerFold = NaN(KFolds,1);
for fold = 1:KFolds
    trainingPredictors = predictors(cvp.training(fold), :);
    trainingResponse = response(cvp.training(fold), :);
    foldIsCategoricalPredictor = isCategoricalPredictor;
    % Feature Ranking and Selection
    % Replace Inf/-Inf values with NaN to prepare data for normalization
    trainingPredictors = standardizeMissing(trainingPredictors, {Inf, -Inf});
    % Normalize data for feature ranking
    predictorMatrix = normalize(trainingPredictors, "DataVariables", ~foldIsCategoricalPredictor);
    % Rank features using MRMR algorithm
    featureIndex = fscmrmr(...
        predictorMatrix, ...
        trainingResponse);
    numFeaturesToKeep = 3;
    includedPredictorNames = trainingPredictors.Properties.VariableNames(featureIndex(1:numFeaturesToKeep));
    trainingPredictors = trainingPredictors(:,includedPredictorNames);
    foldIsCategoricalPredictor = foldIsCategoricalPredictor(featureIndex(1:numFeaturesToKeep));
    % Apply a PCA to the predictor matrix.
    % Run PCA on numeric predictors only. Categorical predictors are passed through PCA untouched.
    isCategoricalPredictorBeforePCA = foldIsCategoricalPredictor;
    numericPredictors = trainingPredictors(:, ~foldIsCategoricalPredictor);
    numericPredictors = table2array(varfun(@double, numericPredictors));
    % 'inf' values have to be treated as missing data for PCA.
    numericPredictors(isinf(numericPredictors)) = NaN;
    [pcaCoefficients, pcaScores, ~, ~, explained, pcaCenters] = pca(...
        numericPredictors);
    % Keep enough components to explain the desired amount of variance.
    explainedVarianceToKeepAsFraction = 95/100;
    numComponentsToKeep = find(cumsum(explained)/sum(explained) >= explainedVarianceToKeepAsFraction, 1);
    pcaCoefficients = pcaCoefficients(:,1:numComponentsToKeep);
    trainingPredictors = [array2table(pcaScores(:,1:numComponentsToKeep)), trainingPredictors(:, foldIsCategoricalPredictor)];
    foldIsCategoricalPredictor = [false(1,numComponentsToKeep), true(1,sum(foldIsCategoricalPredictor))];
    % Train a classifier
    % This code specifies all the classifier options and trains the classifier.
    classificationDiscriminant = fitcdiscr(...
        trainingPredictors, ...
        trainingResponse, ...
        'DiscrimType', 'linear', ...
        'Gamma', 0, ...
        'FillCoeffs', 'off', ...
        'ClassNames', classNames);
    % Create the validation predict function
    featureSelectionFcn = @(x) x(:,includedPredictorNames);
    pcaTransformationFcn = @(x) [ array2table((table2array(varfun(@double, x(:, ~isCategoricalPredictorBeforePCA))) - pcaCenters) * pcaCoefficients), x(:,isCategoricalPredictorBeforePCA) ];
    discriminantPredictFcn = @(x) predict(classificationDiscriminant, x);
    validationPredictFcn = @(x) discriminantPredictFcn(pcaTransformationFcn(featureSelectionFcn(x)));
    % Compute validation predictions
    validationPredictors = predictors(cvp.test(fold), :);
    [foldPredictions, foldScores] = validationPredictFcn(validationPredictors);
    % Store predictions in the original order
    validationPredictions(cvp.test(fold), :) = foldPredictions;
    validationScores(cvp.test(fold), :) = foldScores;
    % Store per-fold error
    temp = strcmp( strtrim(validationPredictions(cvp.test(fold),:)), strtrim(response(cvp.test(fold))));
    errorPerFold(fold) = 1 - sum(temp)/length(temp);
end
% Compute validation accuracy
correctPredictions = strcmp( strtrim(validationPredictions), strtrim(response));
isMissing = cellfun(@(x) all(isspace(x)), response, 'UniformOutput', true);
correctPredictions = correctPredictions(~isMissing);
validationAccuracy = sum(correctPredictions)/length(correctPredictions);
After running this code, simply call "std(errorPerFold)" to get the standard deviation of the error rate across folds.
Just like in the simpler case, running this code multiple times will generally give different answers, due to the use of a different cross-validation partition of the data on each run. To get repeatable answers, set the random number seed before calling the function, for example:
rng("default")
If this answer helps you, please remember to accept the answer.