Define Custom Pixel Classification Layer with Dice Loss

This example shows how to define and create a custom pixel classification layer that uses Dice loss.

This layer can be used to train semantic segmentation networks. To learn more about creating custom deep learning layers, see Define Custom Deep Learning Layers (Deep Learning Toolbox).

Dice Loss

The Dice loss is based on the Sørensen-Dice similarity coefficient, which measures the overlap between two segmented images. The generalized Dice loss [1,2], L, between one image Y and the corresponding ground truth T is given by

L = 1 - \frac{2 \sum_{k=1}^{K} w_k \sum_{m=1}^{M} Y_{km} T_{km}}{\sum_{k=1}^{K} w_k \sum_{m=1}^{M} \left( Y_{km}^2 + T_{km}^2 \right)},

where K is the number of classes, M is the number of elements along the first two dimensions of Y, and w_k is a class-specific weighting factor that controls the contribution each class makes to the loss. w_k is typically the inverse area of the expected region:

w_k = \frac{1}{\left( \sum_{m=1}^{M} T_{km} \right)^2}

This weighting helps counter the influence of larger regions on the Dice score, making it easier for the network to learn how to segment smaller regions.
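
For example, the following is a minimal numeric sketch of the loss for a single observation with K = 2 classes on a 2-by-2 image (the values of Y and T are hypothetical and chosen only for illustration; the small Epsilon constant introduced later is omitted for clarity):

Y = cat(3,[0.9 0.8; 0.2 0.1],[0.1 0.2; 0.8 0.9]); % predictions, H-by-W-by-K
T = cat(3,[1 1; 0 0],[0 0; 1 1]);                 % one-hot targets
w = 1 ./ sum(sum(T,1),2).^2;                      % inverse-area class weights
numer = 2*sum(w.*sum(sum(Y.*T,1),2),3);
denom = sum(w.*sum(sum(Y.^2 + T.^2,1),2),3);
L = 1 - numer/denom                               % about 0.03: good overlap, small loss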

Classification Layer Template

Copy the classification layer template into a new file in MATLAB®. This template outlines the structure of a classification layer and includes the functions that define the layer behavior. The rest of the example shows how to complete the dicePixelClassificationLayer.

classdef dicePixelClassificationLayer < nnet.layer.ClassificationLayer

   properties
      % Optional properties
   end

   methods

        function loss = forwardLoss(layer, Y, T)
            % Layer forward loss function goes here.
        end
        
        function dLdY = backwardLoss(layer, Y, T)
            % Layer backward loss function goes here.
        end
    end
end

Declare Layer Properties

By default, custom output layers have the following properties:

  • Name – Layer name, specified as a character vector or a string scalar. To include this layer in a layer graph, you must specify a nonempty unique layer name. If you train a series network with this layer and Name is set to '', then the software automatically assigns a name at training time.

  • Description – One-line description of the layer, specified as a character vector or a string scalar. This description appears when the layer is displayed in a Layer array. If you do not specify a layer description, then the software displays the layer class name.

  • Type – Type of the layer, specified as a character vector or a string scalar. The value of Type appears when the layer is displayed in a Layer array. If you do not specify a layer type, then the software displays 'Classification layer' or 'Regression layer'.

Custom classification layers also have the following property:

  • Classes – Classes of the output layer, specified as a categorical vector, string array, cell array of character vectors, or 'auto'. If Classes is 'auto', then the software automatically sets the classes at training time. If you specify a string array or cell array of character vectors str, then the software sets the classes of the output layer to categorical(str,str). The default value is 'auto'.

If the layer has no other properties, then you can omit the properties section.

The Dice loss requires a small constant value to prevent division by zero. Specify the property Epsilon to hold this value.

classdef dicePixelClassificationLayer < nnet.layer.ClassificationLayer

    properties(Constant)
       % Small constant to prevent division by zero. 
       Epsilon = 1e-8;

    end

    ...
end

Create Constructor Function

Create the function that constructs the layer and initializes the layer properties. Specify any variables required to create the layer as inputs to the constructor function.

Specify an optional input argument name to assign to the Name property at creation.

        function layer = dicePixelClassificationLayer(name)
            % layer =  dicePixelClassificationLayer(name) creates a Dice
            % pixel classification layer with the specified name.
            
            % Set layer name.          
            layer.Name = name;
            
            % Set layer description.
            layer.Description = 'Dice loss';
        end

Create Forward Loss Function

Create a function named forwardLoss that returns the generalized Dice loss between the predictions made by the network and the training targets. The syntax for forwardLoss is loss = forwardLoss(layer, Y, T), where Y is the output of the previous layer and T represents the training targets.

For semantic segmentation problems, the dimensions of T match the dimensions of Y, where Y is a 4-D array of size H-by-W-by-K-by-N. Here, K is the number of classes and N is the mini-batch size.

The size of Y depends on the output of the previous layer. To ensure that Y is the same size as T, you must include a layer that outputs the correct size before the output layer. For example, to ensure that Y is a 4-D array of prediction scores for K classes, you can include a fully connected layer of size K or a convolutional layer with K filters followed by a softmax layer before the output layer.
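
For example, the final layers before the custom output layer might look like the following sketch (assuming K = 2 classes):

K = 2;
finalLayers = [
    convolution2dLayer(1,K)   % K 1-by-1 filters produce H-by-W-by-K scores
    softmaxLayer];            % normalize the scores to per-pixel probabilities

With the sizes of Y and T matched, define forwardLoss as follows.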

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the Dice loss between
            % the predictions Y and the training targets T.   

            % Weights by inverse of region size.
            W = 1 ./ sum(sum(T,1),2).^2;
            
            intersection = sum(sum(Y.*T,1),2);
            union = sum(sum(Y.^2 + T.^2, 1),2);          
            
            numer = 2*sum(W.*intersection,3) + layer.Epsilon;
            denom = sum(W.*union,3) + layer.Epsilon;
            
            % Compute Dice score.
            dice = numer./denom;
            
            % Return average Dice loss.
            N = size(Y,4);
            loss = sum((1-dice))/N;
            
        end
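
As an informal sanity check (not part of the layer), a perfect one-hot prediction should produce a loss close to zero. The target below is hypothetical test data:

layer = dicePixelClassificationLayer('dice');
T = single(cat(3,[1 1; 0 0],[0 0; 1 1])); % 2-by-2, 2-class one-hot target
Y = T;                                    % perfect prediction
loss = forwardLoss(layer,Y,T)             % approximately 0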

Create Backward Loss Function

Create the backward loss function that returns the derivatives of the Dice loss with respect to the predictions Y. The syntax for backwardLoss is dLdY = backwardLoss(layer, Y, T), where Y is the output of the previous layer and T represents the training targets.

The dimensions of Y and T are the same as the inputs in forwardLoss.

        function dLdY = backwardLoss(layer, Y, T)
            % dLdY = backwardLoss(layer, Y, T) returns the derivatives of
            % the Dice loss with respect to the predictions Y.
            
            % Weights by inverse of region size.
            W = 1 ./ sum(sum(T,1),2).^2;
            
            intersection = sum(sum(Y.*T,1),2);
            union = sum(sum(Y.^2 + T.^2, 1),2);
     
            numer = 2*sum(W.*intersection,3) + layer.Epsilon;
            denom = sum(W.*union,3) + layer.Epsilon;
            
            N = size(Y,4);
      
            dLdY = (2*W.*Y.*numer./denom.^2 - 2*W.*T./denom)./N;
        end
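
You can informally verify the analytic gradient against central finite differences of forwardLoss. This check is not part of the layer, and the test data below is hypothetical:

layer = dicePixelClassificationLayer('dice');
T1 = rand(4,4) > 0.5;
T = double(cat(3,T1,~T1));               % one-hot targets
Y = rand(4,4,2,1);                       % random predictions
dLdY = backwardLoss(layer,Y,T);
h = 1e-6;                                % finite-difference step size
dNum = zeros(size(Y));
for i = 1:numel(Y)
    Yp = Y; Yp(i) = Yp(i) + h;
    Ym = Y; Ym(i) = Ym(i) - h;
    dNum(i) = (forwardLoss(layer,Yp,T) - forwardLoss(layer,Ym,T))/(2*h);
end
max(abs(dNum(:) - dLdY(:)))              % should be very small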

Completed Layer

The completed layer is provided in dicePixelClassificationLayer.m.

classdef dicePixelClassificationLayer < nnet.layer.ClassificationLayer
    % This layer implements the generalized Dice loss function for training
    % semantic segmentation networks.
    
    properties(Constant)
        % Small constant to prevent division by zero. 
        Epsilon = 1e-8;
    end
    
    methods
        
        function layer = dicePixelClassificationLayer(name)
            % layer =  dicePixelClassificationLayer(name) creates a Dice
            % pixel classification layer with the specified name.
            
            % Set layer name.          
            layer.Name = name;
            
            % Set layer description.
            layer.Description = 'Dice loss';
        end
        
        
        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the Dice loss between
            % the predictions Y and the training targets T.   

            % Weights by inverse of region size.
            W = 1 ./ sum(sum(T,1),2).^2;
            
            intersection = sum(sum(Y.*T,1),2);
            union = sum(sum(Y.^2 + T.^2, 1),2);          
            
            numer = 2*sum(W.*intersection,3) + layer.Epsilon;
            denom = sum(W.*union,3) + layer.Epsilon;
            
            % Compute Dice score.
            dice = numer./denom;
            
            % Return average Dice loss.
            N = size(Y,4);
            loss = sum((1-dice))/N;
            
        end
        
        function dLdY = backwardLoss(layer, Y, T)
            % dLdY = backwardLoss(layer, Y, T) returns the derivatives of
            % the Dice loss with respect to the predictions Y.
            
            % Weights by inverse of region size.
            W = 1 ./ sum(sum(T,1),2).^2;
            
            intersection = sum(sum(Y.*T,1),2);
            union = sum(sum(Y.^2 + T.^2, 1),2);
     
            numer = 2*sum(W.*intersection,3) + layer.Epsilon;
            denom = sum(W.*union,3) + layer.Epsilon;
            
            N = size(Y,4);
      
            dLdY = (2*W.*Y.*numer./denom.^2 - 2*W.*T./denom)./N;
        end
    end
end

GPU Compatibility

For GPU compatibility, the layer functions must support inputs and return outputs of type gpuArray. Any other functions used by the layer must do the same.

The MATLAB functions used in forwardLoss and backwardLoss in dicePixelClassificationLayer all support gpuArray inputs, so the layer is GPU compatible.
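
For example, a quick smoke test on the GPU might look like the following sketch (assuming a supported GPU and Parallel Computing Toolbox are available; the test data is hypothetical):

layer = dicePixelClassificationLayer('dice');
T1 = gpuArray(rand(4,4,'single')) > 0.5;
T = single(cat(3,T1,~T1));               % one-hot gpuArray targets
Y = gpuArray(rand(4,4,2,1,'single'));    % gpuArray predictions
loss = forwardLoss(layer,Y,T);           % returns a gpuArray scalar
dLdY = backwardLoss(layer,Y,T);          % returns a gpuArray the size of Y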

Check Output Layer Validity

Create an instance of the layer.

layer = dicePixelClassificationLayer('dice');

Check the validity of the layer using checkLayer. Specify the valid input size to be the size of a single observation of typical input to the layer. The layer expects H-by-W-by-K-by-N input arrays, where K is the number of classes and N is the number of observations in the mini-batch.

numClasses = 2;
validInputSize = [4 4 numClasses];
checkLayer(layer,validInputSize, 'ObservationDimension',4)
Running nnet.checklayer.OutputLayerTestCase
.......... .......
Done nnet.checklayer.OutputLayerTestCase
__________

Test Summary:
	 17 Passed, 0 Failed, 0 Incomplete, 0 Skipped.
	 Time elapsed: 1.6227 seconds.

The test summary reports the number of passed, failed, incomplete, and skipped tests.

Use Custom Layer in Semantic Segmentation Network

Create a semantic segmentation network that uses the dicePixelClassificationLayer.

layers = [
    imageInputLayer([32 32 1])
    convolution2dLayer(3,64,'Padding',1)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,64,'Padding',1)
    reluLayer
    transposedConv2dLayer(4,64,'Stride',2,'Cropping',1)
    convolution2dLayer(1,2)
    softmaxLayer
    dicePixelClassificationLayer('dice')]
layers = 
  10x1 Layer array with layers:

     1   ''       Image Input              32x32x1 images with 'zerocenter' normalization
     2   ''       Convolution              64 3x3 convolutions with stride [1  1] and padding [1  1  1  1]
     3   ''       ReLU                     ReLU
     4   ''       Max Pooling              2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     5   ''       Convolution              64 3x3 convolutions with stride [1  1] and padding [1  1  1  1]
     6   ''       ReLU                     ReLU
     7   ''       Transposed Convolution   64 4x4 transposed convolutions with stride [2  2] and output cropping [1  1]
     8   ''       Convolution              2 1x1 convolutions with stride [1  1] and padding [0  0  0  0]
     9   ''       Softmax                  softmax
    10   'dice'   Classification Output    Dice loss

Load training data for semantic segmentation using imageDatastore and pixelLabelDatastore.

dataSetDir = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
imageDir = fullfile(dataSetDir,'trainingImages');
labelDir = fullfile(dataSetDir,'trainingLabels');

imds = imageDatastore(imageDir);

classNames = ["triangle" "background"];
labelIDs = [255 0];
pxds = pixelLabelDatastore(labelDir, classNames, labelIDs);

Associate the image and pixel label data using pixelLabelImageDatastore.

ds = pixelLabelImageDatastore(imds,pxds);
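
To informally inspect one training pair, you can read from the datastore (this sketch assumes the default pixelLabelImageDatastore output columns, InputImage and PixelLabelImage):

data = read(ds);
I = data.InputImage{1};       % grayscale training image
C = data.PixelLabelImage{1};  % categorical pixel labels
figure
imshow(labeloverlay(I,C))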

Set the training options and train the network.

options = trainingOptions('sgdm', ...
    'InitialLearnRate',1e-2, ...
    'MaxEpochs',100, ...
    'LearnRateDropFactor',1e-1, ...
    'LearnRateDropPeriod',50, ...
    'LearnRateSchedule','piecewise', ...
    'MiniBatchSize',128);

net = trainNetwork(ds,layers,options);
Training on single GPU.
Initializing image normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:03 |       27.89% |       0.8346 |          0.0100 |
|      50 |          50 |       00:00:34 |       89.67% |       0.6384 |          0.0100 |
|     100 |         100 |       00:01:09 |       94.35% |       0.5024 |          0.0010 |
|========================================================================================|

Evaluate the trained network by segmenting a test image and displaying the segmentation result.

I = imread('triangleTest.jpg');

[C,scores] = semanticseg(I,net);

B = labeloverlay(I,C);
figure
imshow(imtile({I,B}))
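
To quantify the overlap with a ground truth segmentation, you can compare categorical label images with the dice function. The following is a sketch only; the ground truth file name below is hypothetical and assumes the test labels use the same label IDs as the training data:

expectedC = imread(fullfile(dataSetDir,'testLabels','triangleTest.png')); % hypothetical path
truth = categorical(expectedC,labelIDs,cellstr(classNames));
similarity = dice(C,truth)    % per-class Dice similarity coefficients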

References

  1. Crum, William R., Oscar Camara, and Derek L. G. Hill. "Generalized Overlap Measures for Evaluation and Validation in Medical Image Analysis." IEEE Transactions on Medical Imaging 25, no. 11 (2006): 1451-1461.

  2. Sudre, Carole H., et al. "Generalised Dice Overlap as a Deep Learning Loss Function for Highly Unbalanced Segmentations." Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support. Springer, Cham, 2017. 240-248.