Semantic Segmentation Using Dilated Convolutions
Train a semantic segmentation network using dilated convolutions.
A semantic segmentation network classifies every pixel in an image, resulting in an image that is segmented by class. Applications for semantic segmentation include road segmentation for autonomous driving and cancer cell segmentation for medical diagnosis. To learn more, see Getting Started with Semantic Segmentation Using Deep Learning (Computer Vision Toolbox).
Semantic segmentation networks like DeepLab [1] make extensive use of dilated convolutions (also known as atrous convolutions) because they increase the receptive field of a layer (the area of the input that the layer can see) without increasing the number of parameters or computations.
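The effect of dilation on a filter's extent can be checked with a few lines of base MATLAB. This sketch builds a dilated version of a placeholder 3-by-3 filter (magic(3) stands in for learned weights) by inserting zeros between its taps. It then adds up the receptive field of three stacked stride-1 convolutions with dilation factors 1, 2, and 4, the factors this example uses later. It is an arithmetic illustration under those assumptions only; it does not claim that convolution2dLayer builds the zero-filled kernel internally.

```matlab
% Placeholder 3-by-3 filter; magic(3) only stands in for learned weights.
k = magic(3);
filterSize = size(k,1);

% Dilate by factor d: place the taps d pixels apart, zeros in between.
d = 2;
effSize = (filterSize - 1)*d + 1;         % effective extent: 5
kDilated = zeros(effSize);
kDilated(1:d:end, 1:d:end) = k;

% The dilated filter spans 5-by-5 pixels but still holds only 9 weights.
numWeights = nnz(kDilated);

% Receptive field of a stack of stride-1 convolutions (assumption):
% each layer widens it by (effective filter size - 1) pixels.
dilations = [1 2 4];
effectiveSizes = (filterSize - 1).*dilations + 1;    % [3 5 9]
receptiveField = 1 + sum(effectiveSizes - 1);        % 15

% The same three layers without dilation, for comparison.
receptiveFieldUndilated = 1 + numel(dilations)*(filterSize - 1);   % 7

fprintf('Dilated 3x3 filter (d = %d) spans %dx%d with %d weights\n', ...
    d, effSize, effSize, numWeights);
fprintf('Receptive field with dilations [1 2 4]: %d pixels\n', receptiveField);
fprintf('Receptive field without dilation:       %d pixels\n', receptiveFieldUndilated);
```

Both stacks use the same number of weights per layer; only the spacing of the taps differs, which is why dilation grows the receptive field at no extra parameter cost.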
Load Training Data
The example uses a simple dataset of 32-by-32 triangle images for illustration purposes. The dataset includes accompanying pixel label ground truth data. Load the training data using an imageDatastore and a pixelLabelDatastore.
dataFolder = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
imageFolderTrain = fullfile(dataFolder,'trainingImages');
labelFolderTrain = fullfile(dataFolder,'trainingLabels');
Create an imageDatastore for the images.
imdsTrain = imageDatastore(imageFolderTrain);
Create a pixelLabelDatastore for the ground truth pixel labels.
classNames = ["triangle" "background"];
labels = [255 0];
pxdsTrain = pixelLabelDatastore(labelFolderTrain,classNames,labels)
pxdsTrain = 
  PixelLabelDatastore with properties:
                       Files: {200x1 cell}
                  ClassNames: {2x1 cell}
                    ReadSize: 1
                     ReadFcn: @readDatastoreImage
    AlternateFileSystemRoots: {}
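The labels vector tells pixelLabelDatastore how to map stored pixel values to classes: 255 becomes "triangle" and 0 becomes "background". The datastore itself returns categorical label images. The sketch below mimics only the lookup step on a hypothetical 4-by-4 mask using base MATLAB's ismember, so it runs without Computer Vision Toolbox. It represents classes as indices into a cell array (classNamesCell and pixelValues are illustrative names, not part of the toolbox).

```matlab
% Same value-to-class mapping as the example, stored as plain arrays.
classNamesCell = {'triangle', 'background'};
pixelValues = [255 0];

% Hypothetical 4-by-4 ground-truth mask (uint8, as read from an image file).
labelImg = uint8([  0   0   0   0
                    0 255   0   0
                    0 255 255   0
                    0   0   0   0]);

% For each pixel, find which entry of pixelValues it matches.
% classIdx is 1 for triangle, 2 for background, and 0 for unlisted values.
[isMapped, classIdx] = ismember(double(labelImg), pixelValues);

% Pixel values missing from pixelValues would have no class.
assert(all(isMapped(:)), 'Found pixel values that are not listed in pixelValues.');

% Class name of the pixel at row 3, column 3.
disp(classNamesCell{classIdx(3,3)})
```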
Create Semantic Segmentation Network
This example uses a simple semantic segmentation network based on dilated convolutions.
Combine the image and pixel label datastores into a single data source for training, and get the pixel count for each class.
ds = combine(imdsTrain,pxdsTrain);
tbl = countEachLabel(pxdsTrain)
tbl=2×3 table
Name PixelCount ImagePixelCount
______________ __________ _______________
{'triangle' } 10326 2.048e+05
{'background'} 1.9447e+05 2.048e+05
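In this table, PixelCount is the number of pixels labeled with each class, and ImagePixelCount is the total number of pixels in the images that contain that class. Here, 200 images of 32-by-32 pixels give 204,800 pixels, and both classes appear in every image. The sketch below reproduces both counts for a hypothetical stack of three small label images stored as class indices (1 = triangle, 2 = background). It is a base-MATLAB illustration of the definitions, not a call to countEachLabel.

```matlab
% Hypothetical stack of three 4-by-4 label images (class indices along dim 3).
labelStack = 2*ones(4,4,3);        % start with all background
labelStack(2:3,2:3,1) = 1;         % image 1: a 2-by-2 triangle
labelStack(2,2,2) = 1;             % image 2: a single triangle pixel
                                   % image 3: background only

numClasses = 2;
pixelsPerImage = size(labelStack,1)*size(labelStack,2);

pixelCount = zeros(numClasses,1);
imagePixelCount = zeros(numClasses,1);
for c = 1:numClasses
    isClass = (labelStack == c);
    % Pixels labeled with class c, summed over all images.
    pixelCount(c) = nnz(isClass);
    % Images that contain at least one pixel of class c.
    hasClass = squeeze(any(any(isClass,1),2));
    % All pixels of those images count toward ImagePixelCount.
    imagePixelCount(c) = nnz(hasClass)*pixelsPerImage;
end

fprintf('triangle:   PixelCount %d, ImagePixelCount %d\n', pixelCount(1), imagePixelCount(1));
fprintf('background: PixelCount %d, ImagePixelCount %d\n', pixelCount(2), imagePixelCount(2));
```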
The majority of pixel labels are for background. This class imbalance biases the learning process in favor of the dominant class. To counter this bias, use class weighting to balance the classes. You can compute class weights in several ways. One common method is inverse frequency weighting, where each class weight is the inverse of that class's frequency. This method increases the weight given to underrepresented classes. Calculate the class weights using inverse frequency weighting.
numberPixels = sum(tbl.PixelCount);
frequency = tbl.PixelCount / numberPixels;
classWeights = 1 ./ frequency;
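The resulting weights for this dataset can be worked out from the table above. The sketch hard-codes the two pixel counts and assumes that every pixel carries a label, so the background count is 204800 - 10326 = 194474, which matches the displayed 1.9447e+05. It then shows how the weights scale a per-pixel cross-entropy term -w_k log(y_k). The exact normalization that pixelClassificationLayer applies across pixels and observations is not reproduced; only the relative scaling between classes is illustrated, and the probability 0.7 is an arbitrary assumed value.

```matlab
% Pixel counts from the countEachLabel output, [triangle; background].
% Assumes all pixels are labeled, so background = 204800 - 10326.
pixelCount = [10326; 194474];

numberPixels = sum(pixelCount);            % 204800 = 200 images * 32 * 32
frequency = pixelCount / numberPixels;     % about [0.0504; 0.9496]
classWeights = 1 ./ frequency;             % about [19.83; 1.053]

% Weighted cross-entropy term for one pixel whose true class is k:
% -w_k * log(y_k). Assume the network gives the true class probability 0.7.
p = 0.7;
lossTriangle   = -classWeights(1)*log(p);
lossBackground = -classWeights(2)*log(p);

% The same prediction error costs about 18.8 times more on a triangle pixel.
ratio = lossTriangle / lossBackground;
fprintf('Class weights: %.4f (triangle), %.4f (background)\n', classWeights);
fprintf('Loss ratio triangle/background: %.4f\n', ratio);
```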
Create a network for pixel classification by using an image input layer with an input size corresponding to the size of the input images. Next, specify three blocks of convolution, batch normalization, and ReLU layers. For each convolutional layer, specify 32 3-by-3 filters with increasing dilation factors, and pad the inputs so that the outputs are the same size as the inputs by setting the 'Padding' option to 'same'. To classify the pixels, include a convolutional layer with K 1-by-1 convolutions, where K is the number of classes, followed by a softmax layer and a pixelClassificationLayer with the inverse frequency class weights.
inputSize = [32 32 1];
filterSize = 3;
numFilters = 32;
numClasses = numel(classNames);
layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(filterSize,numFilters,'DilationFactor',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(filterSize,numFilters,'DilationFactor',2,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(filterSize,numFilters,'DilationFactor',4,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(1,numClasses)
    softmaxLayer
    pixelClassificationLayer('Classes',classNames,'ClassWeights',classWeights)];
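A quick size and parameter check for this architecture can be done in base MATLAB. The sketch confirms that 'same' padding keeps every feature map at 32-by-32, and it counts the learnable parameters: convolution weights and biases, plus batch normalization scale and offset. The dilation factors never enter the parameter count. The counts assume the layer settings shown above; running analyzeNetwork(layers) is the authoritative check.

```matlab
% Architecture settings from the example.
inputSize = [32 32 1];
filterSize = 3;
numFilters = 32;
numClasses = 2;
dilations = [1 2 4];

% 'same' padding for a stride-1 dilated convolution: the total padding is
% (effective filter size - 1), split evenly between the two sides.
effSize = (filterSize - 1).*dilations + 1;               % [3 5 9]
padPerSide = (effSize - 1)/2;                            % [1 2 4]
outSize = inputSize(1) + 2*padPerSide - effSize + 1;     % [32 32 32]

% Learnable parameters of the three dilated convolutions (weights + biases).
% Dilation does not appear here: it spreads the taps but adds no weights.
inChannels = [inputSize(3) numFilters numFilters];
convParams = filterSize^2 .* inChannels .* numFilters + numFilters;   % [320 9248 9248]

% Batch normalization: one scale and one offset per channel, three layers.
bnParams = 2*numFilters*numel(dilations);                % 192

% Final 1-by-1 convolution mapping 32 channels to numClasses scores.
classifierParams = 1*1*numFilters*numClasses + numClasses;            % 66

totalParams = sum(convParams) + bnParams + classifierParams;          % 19074
fprintf('Output sizes: %s\n', mat2str(outSize));
fprintf('Total learnable parameters: %d\n', totalParams);
```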
Train Network
Specify the training options.
options = trainingOptions('sgdm', ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 1e-3);
Train the network using trainNetwork.
net = trainNetwork(ds,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:00 |       91.62% |       1.6825 |          0.0010 |
|      17 |          50 |       00:00:13 |       88.56% |       0.2393 |          0.0010 |
|      34 |         100 |       00:00:26 |       92.08% |       0.1672 |          0.0010 |
|      50 |         150 |       00:00:37 |       93.17% |       0.1472 |          0.0010 |
|      67 |         200 |       00:00:49 |       94.15% |       0.1313 |          0.0010 |
|      84 |         250 |       00:00:59 |       94.47% |       0.1167 |          0.0010 |
|     100 |         300 |       00:01:12 |       95.04% |       0.1100 |          0.0010 |
|========================================================================================|
Training finished: Max epochs completed.
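The Epoch and Iteration columns in this log follow from the dataset size and the training options. With 200 training images and a mini-batch size of 64, each epoch runs 3 iterations if the incomplete final mini-batch of each epoch is discarded. That is an assumption here, but the log's 300 iterations for 100 epochs are consistent with it. The sketch reproduces the epoch numbers printed next to each logged iteration.

```matlab
% Values from the example: 200 training images and the options above.
numObservations = 200;
miniBatchSize = 64;
maxEpochs = 100;

% Assumes the last incomplete mini-batch (200 - 3*64 = 8 images) is dropped.
iterationsPerEpoch = floor(numObservations/miniBatchSize);   % 3
totalIterations = iterationsPerEpoch*maxEpochs;              % 300

% Epoch during which each logged iteration runs.
loggedIterations = [1 50 100 150 200 250 300];
epochs = ceil(loggedIterations/iterationsPerEpoch);          % [1 17 34 50 67 84 100]

fprintf('Iterations per epoch: %d, total: %d\n', iterationsPerEpoch, totalIterations);
disp([loggedIterations; epochs])
```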
Test Network
Load the test data. Create an imageDatastore for the images. Create a pixelLabelDatastore for the ground truth pixel labels.
imageFolderTest = fullfile(dataFolder,'testImages');
imdsTest = imageDatastore(imageFolderTest);
labelFolderTest = fullfile(dataFolder,'testLabels');
pxdsTest = pixelLabelDatastore(labelFolderTest,classNames,labels);
Make predictions using the test data and trained network.
pxdsPred = semanticseg(imdsTest,net,'MiniBatchSize',32,'WriteLocation',tempdir);
Running semantic segmentation network
-------------------------------------
* Processed 100 images.
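For each pixel, semanticseg assigns the class with the highest score in the network's softmax output. The base-MATLAB sketch below applies that argmax rule to a hypothetical 2-by-2-by-2 score array (height by width by classes). It does not call the trained network; the scores are made-up values chosen so that each pixel's scores sum to 1.

```matlab
classNamesCell = {'triangle', 'background'};

% Hypothetical softmax scores, H-by-W-by-K (K = 2 classes).
scores = cat(3, [0.9 0.2; 0.6 0.1], ...   % channel 1: triangle
                [0.1 0.8; 0.4 0.9]);      % channel 2: background

% Highest-scoring class per pixel, taken along the class dimension.
[maxScore, classIdx] = max(scores, [], 3);

% Convert class indices to a cell array of class names for display.
labelsOut = classNamesCell(classIdx);
disp(labelsOut)
```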
Evaluate the prediction accuracy using evaluateSemanticSegmentation.
metrics = evaluateSemanticSegmentation(pxdsPred,pxdsTest);
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU, BF score.
* Processed 100 images.
* Finalizing... Done.
* Data set metrics:
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________
       0.95237          0.97352       0.72081      0.92889        0.46416
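Most of these data set metrics are ratios of confusion-matrix counts. The sketch below computes global accuracy, mean class accuracy, mean IoU, and weighted IoU from a hypothetical 2-by-2 confusion matrix, where rows are true classes and columns are predicted classes. The counts are made up and are not this example's results. The BF score, which measures agreement along class boundaries, is not covered. The evaluateSemanticSegmentation documentation describes the exact aggregation it uses.

```matlab
% Hypothetical confusion matrix over all test pixels, order [triangle background].
% Rows: true class. Columns: predicted class.
C = [ 30  10
      20 940];

TP = diag(C);               % correctly classified pixels per class
FN = sum(C,2) - TP;         % pixels of class k predicted as another class
FP = sum(C,1)' - TP;        % pixels predicted as class k that belong elsewhere

globalAccuracy = sum(TP) / sum(C(:));      % fraction of all pixels correct
classAccuracy  = TP ./ (TP + FN);          % per-class recall
meanAccuracy   = mean(classAccuracy);

iou = TP ./ (TP + FP + FN);                % intersection over union per class
meanIoU = mean(iou);                       % classes weighted equally

% Weight each class IoU by its number of ground-truth pixels.
classPixels = sum(C,2);
weightedIoU = sum(classPixels .* iou) / sum(classPixels);

fprintf('Global %.4f  MeanAcc %.4f  MeanIoU %.4f  WeightedIoU %.4f\n', ...
    globalAccuracy, meanAccuracy, meanIoU, weightedIoU);
```

WeightedIoU is dominated by the large background class, whereas MeanIoU weights both classes equally. A gap like the one in the results above (0.72081 versus 0.92889) is what you would expect when the small triangle class has a lower IoU than the background.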
For more information on evaluating semantic segmentation networks, see evaluateSemanticSegmentation (Computer Vision Toolbox).
Segment New Image
Read and display the test image triangleTest.jpg.
imgTest = imread('triangleTest.jpg');
figure
imshow(imgTest)
Segment the test image using semanticseg and display the results using labeloverlay.
C = semanticseg(imgTest,net);
B = labeloverlay(imgTest,C);
figure
imshow(B)
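labeloverlay fuses the image with a color for each class, and its 'Transparency' name-value argument (default 0.5) controls how strongly the label colors show. The sketch below is a simplified linear blend in base MATLAB. It uses a hypothetical 2-by-2 grayscale image, hypothetical class indices, and hand-picked colors, and it does not reproduce labeloverlay's exact implementation, such as its default colormap or its handling of image data types.

```matlab
% Hypothetical 2-by-2 grayscale image scaled to [0,1] and its class indices.
img = [0.2 0.4; 0.6 0.8];
classIdx = [1 2; 2 2];               % 1 = triangle, 2 = background

% Hand-picked RGB color for each class, one row per class.
classColors = [1 0 0                 % triangle: red
               0 0 1];               % background: blue

% Assumed blend weight: fraction of the original image kept at each pixel.
transparency = 0.5;

% Replicate the grayscale image to RGB and look up each pixel's class color.
imgRGB = repmat(img, [1 1 3]);
colorRGB = reshape(classColors(classIdx(:), :), [size(img) 3]);

% Linear blend of image and label color.
overlay = transparency*imgRGB + (1 - transparency)*colorRGB;

disp(squeeze(overlay(1,1,:))')       % blended RGB value of the top-left pixel
```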
See Also
pixelLabelDatastore (Computer Vision Toolbox) | pixelLabelImageDatastore (Computer Vision Toolbox) | semanticseg (Computer Vision Toolbox) | labeloverlay (Image Processing Toolbox) | countEachLabel (Computer Vision Toolbox) | pixelClassificationLayer (Computer Vision Toolbox) | trainingOptions | trainNetwork | evaluateSemanticSegmentation (Computer Vision Toolbox) | convolution2dLayer