How to generate ROC curve for multiclass semantic segmentation of biological tissue

My code looks something like this:
inputSize = [24 32 3];
imgLayer = imageInputLayer(inputSize)
filterSize = 3;
numFilters = 32;
conv = convolution2dLayer(filterSize,numFilters,'Padding',1);
relu = reluLayer();
poolSize = 2;
maxPoolDownsample2x = maxPooling2dLayer(poolSize,'Stride',2);
downsamplingLayers = [
conv
relu
maxPoolDownsample2x
conv
relu
maxPoolDownsample2x
]
filterSize = 4;
transposedConvUpsample2x = transposedConv2dLayer(4,numFilters,'Stride',2,'Cropping',1);
upsamplingLayers = [
transposedConvUpsample2x
relu
transposedConvUpsample2x
relu
]
numClasses = 5;
conv1x1 = convolution2dLayer(1,numClasses);
finalLayers = [
conv1x1
softmaxLayer()
pixelClassificationLayer()
]
net = [
imgLayer
downsamplingLayers
upsamplingLayers
finalLayers
]
imgLayer =
ImageInputLayer with properties:
Name: ''
InputSize: [24 32 3]
Hyperparameters
DataAugmentation: 'none'
Normalization: 'zerocenter'
NormalizationDimension: 'auto'
Mean: []
downsamplingLayers =
6×1 Layer array with layers:
1 '' Convolution 32 3×3 convolutions with stride [1 1] and padding [1 1 1 1]
2 '' ReLU ReLU
3 '' Max Pooling 2×2 max pooling with stride [2 2] and padding [0 0 0 0]
4 '' Convolution 32 3×3 convolutions with stride [1 1] and padding [1 1 1 1]
5 '' ReLU ReLU
6 '' Max Pooling 2×2 max pooling with stride [2 2] and padding [0 0 0 0]
upsamplingLayers =
4×1 Layer array with layers:
1 '' Transposed Convolution 32 4×4 transposed convolutions with stride [2 2] and cropping [1 1 1 1]
2 '' ReLU ReLU
3 '' Transposed Convolution 32 4×4 transposed convolutions with stride [2 2] and cropping [1 1 1 1]
4 '' ReLU ReLU
finalLayers =
3×1 Layer array with layers:
1 '' Convolution 5 1×1 convolutions with stride [1 1] and padding [0 0 0 0]
2 '' Softmax softmax
3 '' Pixel Classification Layer Cross-entropy loss
net =
14×1 Layer array with layers:
1 '' Image Input 24×32×3 images with 'zerocenter' normalization
2 '' Convolution 32 3×3 convolutions with stride [1 1] and padding [1 1 1 1]
3 '' ReLU ReLU
4 '' Max Pooling 2×2 max pooling with stride [2 2] and padding [0 0 0 0]
5 '' Convolution 32 3×3 convolutions with stride [1 1] and padding [1 1 1 1]
6 '' ReLU ReLU
7 '' Max Pooling 2×2 max pooling with stride [2 2] and padding [0 0 0 0]
8 '' Transposed Convolution 32 4×4 transposed convolutions with stride [2 2] and cropping [1 1 1 1]
9 '' ReLU ReLU
10 '' Transposed Convolution 32 4×4 transposed convolutions with stride [2 2] and cropping [1 1 1 1]
11 '' ReLU ReLU
12 '' Convolution 5 1×1 convolutions with stride [1 1] and padding [0 0 0 0]
13 '' Softmax softmax
14 '' Pixel Classification Layer Cross-entropy loss
>> dataSetDir = fullfile('C:\Users\a0050627\Pictures\3D reconstruction stacks\Final dissertation images\final dissertation images\ground truth 13 32h24w\trainingImages');
imageDir = fullfile('C:\Users\a0050627\Pictures\3D reconstruction stacks\Final dissertation images\final dissertation images\ground truth 13 32h24w\trainingImages');
labelDir = fullfile('C:\Users\a0050627\Pictures\3D reconstruction stacks\Final dissertation images\final dissertation images\ground truth 13 32h24w\trainingLabels');
>> imds = imageDatastore(imageDir);
classNames = ["Fibrous_tissue" "Cartilage" "Blood_cells" "Smooth_muscle" "Background"];
labelIDs = [1 2 3 4 5];
pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);
I = read(imds);
C = read(pxds);
>> I = imresize(I,5);
L = imresize(uint8(C{1}),5);
imshowpair(I,L,'montage')
numFilters = 56;
filterSize = 3;
numClasses = 5;
layers = [
imageInputLayer([24 32 3])
convolution2dLayer(filterSize,numFilters,'Padding',1)
reluLayer()
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(filterSize,numFilters,'Padding',1)
reluLayer()
transposedConv2dLayer(4,numFilters,'Stride',2,'Cropping',1);
convolution2dLayer(1,numClasses);
softmaxLayer()
pixelClassificationLayer()
];
opts = trainingOptions('sgdm', ...
'InitialLearnRate',1e-3, ...
'MaxEpochs',100, ...
'MiniBatchSize',56);
trainingData = combine(imds,pxds);
net = trainNetwork(trainingData,layers,opts);
testImage = imread('trainingImages0000.tif');
imshow(testImage)
C = semanticseg(testImage,net);
B = labeloverlay(testImage,C);
imshow(B)
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Base Learning |
| | | (hh:mm:ss) | Accuracy | Loss | Rate |
|========================================================================================|
| 1 | 1 | 00:00:00 | 16.08% | 2.4888 | 0.0010 |
| 50 | 50 | 00:00:03 | 69.22% | 0.8458 | 0.0010 |
| 100 | 100 | 00:00:05 | 73.94% | 0.6949 | 0.0010 |
|========================================================================================|
Error using imread>get_full_filename (line 570)
File "trainingImages0000.tif" does not exist.
Error in imread (line 377)
fullname = get_full_filename(filename);
(The same commands, from "I = imresize(I,5);" through "imshow(B)", were then run a second time:)
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Base Learning |
| | | (hh:mm:ss) | Accuracy | Loss | Rate |
|========================================================================================|
| 1 | 1 | 00:00:00 | 25.46% | 2.5070 | 0.0010 |
| 50 | 50 | 00:00:03 | 71.97% | 0.8210 | 0.0010 |
| 100 | 100 | 00:00:05 | 75.99% | 0.6633 | 0.0010 |
|========================================================================================|
>> dataSetDir = fullfile('C:\Users\a0050627\Pictures\3D reconstruction stacks\Final dissertation images\final dissertation images\ground truth 13 32h24w\dataSetDir');
testImagesDir = fullfile('C:\Users\a0050627\Pictures\3D reconstruction stacks\Final dissertation images\final dissertation images\ground truth 13 32h24w\trainingImages');
imds = imageDatastore(testImagesDir);
testLabelsDir = fullfile('C:\Users\a0050627\Pictures\3D reconstruction stacks\Final dissertation images\final dissertation images\ground truth 13 32h24w\trainingLabels');
classNames = ["Fibrous_tissue" "Cartilage" "Blood_cells" "Smooth_muscle" "Background"];
labelIDs = [1 2 3 4 5];
pxdsTruth = pixelLabelDatastore(testLabelsDir,classNames,labelIDs);
pxdsResults = semanticseg(imds,net,"WriteLocation",tempdir);
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTruth);
metrics.ClassMetrics
metrics.ConfusionMatrix
normConfMatData = metrics.NormalizedConfusionMatrix.Variables;
figure
h = heatmap(classNames,classNames,100*normConfMatData);
h.XLabel = 'Predicted Class';
h.YLabel = 'True Class';
h.Title = 'Normalized Confusion Matrix (%)';
imageIoU = metrics.ImageMetrics.MeanIoU;
figure
histogram(imageIoU)
title('Image Mean IoU')
[minIoU, worstImageIndex] = min(imageIoU);
minIoU = minIoU(1);
worstImageIndex = worstImageIndex(1);
worstTestImage = readimage(imds,worstImageIndex);
worstTrueLabels = readimage(pxdsTruth,worstImageIndex);
worstPredictedLabels = readimage(pxdsResults,worstImageIndex);
worstTrueLabelImage = im2uint8(worstTrueLabels == classNames(1));
worstPredictedLabelImage = im2uint8(worstPredictedLabels == classNames(1));
worstMontage = cat(4,worstTestImage,worstTrueLabelImage,worstPredictedLabelImage);
worstMontage = imresize(worstMontage,4,"nearest");
figure
montage(worstMontage,'Size',[1 3])
title(['Test Image vs. Truth vs. Prediction. IoU = ' num2str(minIoU)])
[maxIoU, bestImageIndex] = max(imageIoU);
maxIoU = maxIoU(1);
bestImageIndex = bestImageIndex(1);
bestTestImage = readimage(imds,bestImageIndex);
bestTrueLabels = readimage(pxdsTruth,bestImageIndex);
bestPredictedLabels = readimage(pxdsResults,bestImageIndex);
bestTrueLabelImage = im2uint8(bestTrueLabels == classNames(1));
bestPredictedLabelImage = im2uint8(bestPredictedLabels == classNames(1));
bestMontage = cat(4,bestTestImage,bestTrueLabelImage,bestPredictedLabelImage);
bestMontage = imresize(bestMontage,4,"nearest");
figure
montage(bestMontage,'Size',[1 3])
title(['Test Image vs. Truth vs. Prediction. IoU = ' num2str(maxIoU)])
evaluationMetrics = ["accuracy" "iou"];
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTruth,"Metrics",evaluationMetrics);
metrics.ClassMetrics
Running semantic segmentation network
-------------------------------------
* Processed 8 images.
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU, BF score.
* Processed 8 images.
* Finalizing... Done.
* Data set metrics:
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________
        0.7609          0.42784       0.34059      0.6219         0.49539
ans =
5×3 table
                      Accuracy      IoU      MeanBFScore
                      ________    _______    ___________
    Fibrous_tissue           0          0            NaN
    Cartilage          0.41071    0.31293        0.44522
    Blood_cells        0.20241    0.16634        0.27843
    Smooth_muscle      0.61982    0.45857        0.57353
    Background         0.90628     0.7651        0.68441
ans =
5×5 table
                      Fibrous_tissue    Cartilage    Blood_cells    Smooth_muscle    Background
                      ______________    _________    ___________    _____________    __________
    Fibrous_tissue                 0            3              6                6             0
    Cartilage                      0           46             37               28             1
    Blood_cells                    0           20             84              273            38
    Smooth_muscle                  0           12             42             1151           652
    Background                     0            0              5              346          3394
Looking forward to a response,
Zandile

Answers (1)

Avadhoot, 24 January 2024
Hi Zandile,
From your question, I understand that you want to plot ROC curves for your segmentation model. ROC curves are not commonly used for semantic segmentation, because it is a multiclass, pixel-by-pixel classification problem. You can, however, generate a separate ROC curve for each class with a one-vs-rest approach: for each class, the pixels of that class are treated as positives and the pixels of all other classes as negatives. For this you need the predicted class probabilities for every pixel, that is, the output of the softmax layer. Because the thresholds in the sample code run from 0 to 1, the scores must be probabilities, not the raw values before the softmax. From these probabilities you can calculate the true positive rate (TPR) and the false positive rate (FPR) of each class over a range of thresholds.
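The per-pixel probabilities do not need to be computed by hand: "semanticseg" can return them as its third output, "allScores", an H-by-W-by-numClasses array of softmax scores for an H-by-W image. Below is a minimal sketch of a hypothetical helper, "pixelScoresToVectors" (not a MathWorks function), that flattens one image's scores and ground truth into one row per pixel so that pixels from several test images can be pooled. It assumes the score channels and the label categories follow the same class order, which should be the case when the network was trained on a "pixelLabelDatastore" built with the same "classNames".

```matlab
function [scores, labels] = pixelScoresToVectors(allScores, truth)
% PIXELSCORESTOVECTORS  Hypothetical helper: flatten one image's per-pixel
% class scores and ground truth into one row per pixel.
%   allScores : H-by-W-by-numClasses softmax scores, e.g. the third output
%               of [~,~,allScores] = semanticseg(I,net)
%   truth     : H-by-W categorical ground truth, e.g. readimage(pxdsTruth,k)
%   scores    : numPixels-by-numClasses matrix, one row per pixel
%   labels    : numPixels-by-1 class indices (position in categories(truth))
numClasses = size(allScores, 3);
% Column-major reshape keeps the pixel order identical to truth(:)
scores = reshape(allScores, [], numClasses);
% double() of a categorical gives the category index; <undefined> becomes NaN
labels = double(truth(:));
% Drop pixels whose label was not one of the listed label IDs (undefined)
valid = ~isnan(labels);
scores = scores(valid, :);
labels = labels(valid);
end
```

To pool the whole test set, you could call it inside "for k = 1:numel(imds.Files)" with "[~,~,allScores] = semanticseg(readimage(imds,k),net);" and "truth = readimage(pxdsTruth,k);", then stack the per-image results with "vertcat". This assumes "imds" and "pxdsTruth" list the images and label files in matching order.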
Here is some sample code:
% Assumes, for one test image I with index k in the test datastores:
%   [~,~,predictedProbabilities] = semanticseg(I,net);  % H-by-W-by-numClasses softmax scores
%   trueLabels = double(readimage(pxdsTruth,k));         % H-by-W class indices 1..numClasses
thresholds = 0:0.01:1;   % threshold levels to evaluate
for i = 1:numClasses
    % Get predicted probabilities for class 'i'
    predictedProbabilitiesForClass = predictedProbabilities(:,:,i);
    % Preallocate TPR and FPR (one value per threshold)
    TPR = zeros(size(thresholds));
    FPR = zeros(size(thresholds));
    for t = 1:numel(thresholds)
        % Binarize predictions based on the current threshold
        predictedLabels = predictedProbabilitiesForClass >= thresholds(t);
        % Calculate confusion matrix elements (class 'i' vs. all other classes)
        TP = sum(predictedLabels & (trueLabels == i), 'all');
        FP = sum(predictedLabels & (trueLabels ~= i), 'all');
        TN = sum(~predictedLabels & (trueLabels ~= i), 'all');
        FN = sum(~predictedLabels & (trueLabels == i), 'all');
        % Calculate True Positive Rate (TPR) and False Positive Rate (FPR)
        TPR(t) = TP / (TP + FN);
        FPR(t) = FP / (FP + TN);
    end
    % Both rates fall monotonically as the threshold rises, so the points are
    % already in curve order; sorting by FPR alone can scramble TPR where FPR ties
    % Plot ROC curve for class 'i'
    figure;
    plot(FPR, TPR);
    xlabel('False Positive Rate');
    ylabel('True Positive Rate');
    title(sprintf('ROC Curve for %s', classNames(i)));
end
A few notes on the code:
  • "trueLabels" holds the ground-truth class index (1 to 5) of every pixel, in the order of "classNames"; applying "double" to the categorical returned by "readimage(pxdsTruth,k)" gives these indices.
  • TPR and FPR are calculated for a set of thresholds to get the ROC curve.
  • ROC curve for each class is plotted separately.
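As an alternative to the manual threshold loop, "perfcurve" from Statistics and Machine Learning Toolbox computes one-vs-rest ROC points and the area under the curve (AUC) directly. Below is a minimal sketch of a hypothetical helper, "plotPerClassROC", assuming "scores" is a numPixels-by-numClasses matrix of softmax scores and "labels" a column of class indices 1 to numClasses for the same pixels (for one image, "reshape(predictedProbabilities,[],numClasses)" and "trueLabels(:)"). Class "i" is passed as the positive class, so every other class counts as negative.

```matlab
function auc = plotPerClassROC(scores, labels, classNames)
% PLOTPERCLASSROC  Hypothetical helper: one-vs-rest ROC curves and AUCs
% computed with perfcurve (Statistics and Machine Learning Toolbox).
%   scores     : numPixels-by-numClasses softmax scores
%   labels     : numPixels-by-1 class indices 1..numClasses
%   classNames : string array of class names, in score-column order
%   auc        : 1-by-numClasses AUC values (NaN where ROC is undefined)
numClasses = size(scores, 2);
auc = nan(1, numClasses);
figure
hold on
for i = 1:numClasses
    isPositive = (labels == i);
    % ROC is undefined without both positive and negative pixels
    if ~any(isPositive) || all(isPositive)
        continue
    end
    % Class i is the positive class; all other labels count as negatives
    [fpr, tpr, ~, auc(i)] = perfcurve(labels, scores(:, i), i);
    plot(fpr, tpr, 'DisplayName', sprintf('%s (AUC = %.3f)', classNames(i), auc(i)))
end
% Diagonal reference line for a random classifier, hidden from the legend
plot([0 1], [0 1], 'k--', 'HandleVisibility', 'off')
hold off
xlabel('False Positive Rate')
ylabel('True Positive Rate')
title('One-vs-rest ROC per class')
legend('Location', 'southeast')
end
```

Because "perfcurve" uses every distinct score as a threshold, the curve is not limited to a 0.01 grid, and drawing all classes on one axes makes them easier to compare. Classes with no positive or no negative pixels are skipped and keep an AUC of NaN, since their ROC curve is undefined.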
I hope it helps.

Release

R2020b
