trainSSDObjectDetector

Train SSD deep learning object detector

Since R2020a

Description

Train a Detector

trainedDetector = trainSSDObjectDetector(trainingData,detector,options) trains a single shot multibox detector (SSD) using deep learning. You can train an SSD detector to detect multiple object classes. Use this syntax to train either an untrained or pretrained SSD object detection network. You can also use this syntax to fine-tune a network with additional training data or to perform more training iterations to improve detector accuracy.

This function requires that you have Deep Learning Toolbox™. It is recommended that you also have Parallel Computing Toolbox™ to use with a CUDA®-enabled NVIDIA® GPU. For information about the supported compute capabilities, see GPU Computing Requirements (Parallel Computing Toolbox).

Resume Training a Detector

trainedDetector = trainSSDObjectDetector(trainingData,checkpoint,options) resumes training from a detector checkpoint.

Additional Properties

trainedDetector = trainSSDObjectDetector(___,Name=Value) uses additional options specified by one or more name-value arguments and any of the previous inputs.

[trainedDetector,info] = trainSSDObjectDetector(___) also returns information on the training progress, such as training loss and accuracy, for each iteration.

Examples

This example shows how to train an SSD object detector on a vehicle data set. The example then uses the trained detector for detecting vehicles in an image.

Load the training data into the workspace.

data = load("vehicleTrainingData.mat");
trainingData = data.vehicleTrainingData;

Specify the directory in which the training samples are stored. Add the full path to the file names in the training data.

dataDir = fullfile(toolboxdir("vision"),"visiondata");
trainingData.imageFilename = fullfile(dataDir,trainingData.imageFilename);

Create an image datastore using the files from the table.

imds = imageDatastore(trainingData.imageFilename);

Create a box label datastore using the label columns from the table.

blds = boxLabelDatastore(trainingData(:,2:end));

Combine the datastores.

ds = combine(imds,blds);

Specify a base network.

baseNetwork = imagePretrainedNetwork("resnet50");

Specify the names of the classes to detect.

classNames = "vehicle";

Specify the anchor boxes to use for training the network.

anchorBoxes = { ...
    [30 60; 60 30; 50 50; 100 100], ...
    [40 70; 70 40; 60 60; 120 120]};
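
If you do not have anchor box sizes in mind, one option is to estimate them from the ground truth boxes. The following is a minimal sketch, assuming the estimateAnchorBoxes function from Computer Vision Toolbox. The choice of eight anchors, the even split across two detection layers, and the variable name anchorBoxesEstimated are hypothetical and not part of this example. The estimated sizes are in the scale of the training images, so they may need rescaling if the images are resized for the network.

```matlab
% Load the ground truth boxes from the vehicle data set.
data = load("vehicleTrainingData.mat");
blds = boxLabelDatastore(data.vehicleTrainingData(:,2:end));

% Estimate anchor boxes as [height width] rows (numAnchors is a hypothetical choice).
numAnchors = 8;
[estimated,meanIoU] = estimateAnchorBoxes(blds,numAnchors);

% Sort by area so that the smaller boxes go to the first detection layer.
[~,order] = sort(prod(estimated,2));
estimated = estimated(order,:);

% Split the sorted boxes into one set per detection layer (illustrative split).
anchorBoxesEstimated = {estimated(1:4,:), estimated(5:8,:)};
```

The meanIoU output gives a rough indication of how well the estimated boxes fit the ground truth; trying a few values of numAnchors and comparing meanIoU is one way to choose the count.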

Specify the names of the feature extraction layers to connect to the detection subnetwork.

layersToConnect = ["activation_22_relu" "activation_40_relu"];

Create an SSD object detector by using the ssdObjectDetector function.

detector = ssdObjectDetector(baseNetwork,classNames,anchorBoxes, ...
    DetectionNetworkSource=layersToConnect);

Specify the training options.

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MiniBatchSize=16, ...
    Verbose=true, ...
    MaxEpochs=30, ...
    Shuffle="never", ...
    VerboseFrequency=10);

Train the SSD object detector.

[detector,info] = trainSSDObjectDetector(ds,detector,options);
*************************************************************************
Training an SSD Object Detector for the following object classes:

* vehicle

Training on single GPU.
Initializing input data normalization.
|=======================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |     Loss     |   Accuracy   |     RMSE     |      Rate       |
|=======================================================================================================|
|       1 |           1 |       00:00:00 |      54.8573 |       43.62% |         2.56 |          0.0010 |
|       1 |          10 |       00:00:04 |       3.5709 |       98.90% |         1.65 |          0.0010 |
|       2 |          20 |       00:00:07 |       3.2212 |       99.93% |         0.99 |          0.0010 |
|       2 |          30 |       00:00:11 |       5.3296 |       98.89% |         1.51 |          0.0010 |
|       3 |          40 |       00:00:15 |       4.5304 |       99.76% |         1.15 |          0.0010 |
|       3 |          50 |       00:00:18 |       6.2441 |       98.47% |         1.28 |          0.0010 |
|       4 |          60 |       00:00:22 |       3.6098 |       98.95% |         1.26 |          0.0010 |
|       4 |          70 |       00:00:25 |       5.0838 |       98.81% |         1.23 |          0.0010 |
|       5 |          80 |       00:00:29 |       4.9031 |       98.84% |         1.47 |          0.0010 |
|       5 |          90 |       00:00:32 |       4.5612 |       99.91% |         0.99 |          0.0010 |
|       6 |         100 |       00:00:36 |       3.0531 |       98.90% |         1.13 |          0.0010 |
|       7 |         110 |       00:00:40 |       3.4591 |       99.95% |         0.74 |          0.0010 |
|       7 |         120 |       00:00:43 |       3.1390 |       99.10% |         1.02 |          0.0010 |
|       8 |         130 |       00:00:47 |       3.7369 |       99.77% |         1.14 |          0.0010 |
|       8 |         140 |       00:00:50 |       3.3740 |       99.07% |         0.99 |          0.0010 |
|       9 |         150 |       00:00:54 |       3.9861 |       98.99% |         1.35 |          0.0010 |
|       9 |         160 |       00:00:57 |       3.1168 |       99.24% |         0.90 |          0.0010 |
|      10 |         170 |       00:01:01 |       4.1509 |       98.92% |         1.40 |          0.0010 |
|      10 |         180 |       00:01:04 |       1.8926 |       99.97% |         0.55 |          0.0010 |
|      11 |         190 |       00:01:08 |       1.9993 |       99.13% |         0.88 |          0.0010 |
|      12 |         200 |       00:01:11 |       1.0826 |       99.98% |         0.41 |          0.0010 |
|      12 |         210 |       00:01:15 |       2.2213 |       99.37% |         0.78 |          0.0010 |
|      13 |         220 |       00:01:18 |       2.4177 |       99.84% |         0.99 |          0.0010 |
|      13 |         230 |       00:01:22 |       2.7282 |       99.32% |         0.81 |          0.0010 |
|      14 |         240 |       00:01:26 |       3.7748 |       99.20% |         1.20 |          0.0010 |
|      14 |         250 |       00:01:29 |       2.5727 |       99.33% |         0.77 |          0.0010 |
|      15 |         260 |       00:01:33 |       3.8397 |       99.15% |         1.23 |          0.0010 |
|      15 |         270 |       00:01:36 |       0.7277 |       99.99% |         0.43 |          0.0010 |
|      16 |         280 |       00:01:40 |       2.9342 |       99.22% |         0.78 |          0.0010 |
|      17 |         290 |       00:01:43 |       0.8218 |       99.99% |         0.39 |          0.0010 |
|      17 |         300 |       00:01:47 |       1.9056 |       99.53% |         0.70 |          0.0010 |
|      18 |         310 |       00:01:50 |       1.3405 |       99.93% |         0.71 |          0.0010 |
|      18 |         320 |       00:01:54 |       1.5981 |       99.42% |         0.79 |          0.0010 |
|      19 |         330 |       00:01:57 |       2.9586 |       99.30% |         1.02 |          0.0010 |
|      19 |         340 |       00:02:01 |       1.2590 |       99.43% |         0.66 |          0.0010 |
|      20 |         350 |       00:02:05 |       2.8346 |       99.27% |         1.04 |          0.0010 |
|      20 |         360 |       00:02:08 |       0.9523 |       99.98% |         0.49 |          0.0010 |
|      21 |         370 |       00:02:12 |       1.5545 |       99.26% |         0.74 |          0.0010 |
|      22 |         380 |       00:02:16 |       0.7113 |       99.99% |         0.50 |          0.0010 |
|      22 |         390 |       00:02:19 |       2.0585 |       99.58% |         0.61 |          0.0010 |
|      23 |         400 |       00:02:23 |       1.0087 |       99.93% |         0.55 |          0.0010 |
|      23 |         410 |       00:02:26 |       2.1091 |       99.48% |         0.71 |          0.0010 |
|      24 |         420 |       00:02:30 |       2.9050 |       99.34% |         0.95 |          0.0010 |
|      24 |         430 |       00:02:33 |       2.1874 |       99.44% |         0.64 |          0.0010 |
|      25 |         440 |       00:02:37 |       2.9181 |       99.30% |         0.91 |          0.0010 |
|      25 |         450 |       00:02:40 |       1.9445 |       99.94% |         0.39 |          0.0010 |
|      26 |         460 |       00:02:44 |       3.2013 |       99.20% |         0.71 |          0.0010 |
|      27 |         470 |       00:02:48 |       0.3986 |       99.98% |         0.37 |          0.0010 |
|      27 |         480 |       00:02:51 |       1.3009 |       99.57% |         0.58 |          0.0010 |
|      28 |         490 |       00:02:55 |       1.0920 |       99.96% |         0.60 |          0.0010 |
|      28 |         500 |       00:02:58 |       1.7258 |       99.58% |         0.65 |          0.0010 |
|      29 |         510 |       00:03:02 |       2.7426 |       99.42% |         0.86 |          0.0010 |
|      29 |         520 |       00:03:05 |       1.4956 |       99.63% |         0.59 |          0.0010 |
|      30 |         530 |       00:03:09 |       2.0561 |       99.39% |         0.85 |          0.0010 |
|      30 |         540 |       00:03:12 |       1.0817 |       99.98% |         0.78 |          0.0010 |
|=======================================================================================================|
Training finished: Max epochs completed.
Detector training complete.
*************************************************************************

Verify the training accuracy by inspecting the training loss for each iteration.

figure
plot(info.TrainingLoss)
grid on
xlabel("Number of Iterations")
ylabel("Training Loss for Each Iteration")

[Figure: line plot of the training loss for each iteration against the number of iterations.]
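
The mini-batch loss in the training log varies noticeably from one iteration to the next, so a smoothed curve can make the overall trend easier to read. This is a small sketch, assuming info is the second output of the trainSSDObjectDetector call in this example; the window size of 20 iterations is a hypothetical choice.

```matlab
% Assumes info is the second output of trainSSDObjectDetector from this example.
windowSize = 20;   % hypothetical smoothing window, in iterations
smoothedLoss = movmean(info.TrainingLoss,windowSize);

% Overlay the smoothed curve on the per-iteration loss.
figure
plot(info.TrainingLoss)
hold on
plot(smoothedLoss,LineWidth=2)
hold off
grid on
xlabel("Number of Iterations")
ylabel("Training Loss")
legend("Per iteration","Moving average")
```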

Read a test image.

img = imread("detectcars.png");

Detect vehicles in the test image by using the trained SSD object detector.

[bboxes,scores] = detect(detector,img);

Display the detection results.

if(~isempty(bboxes))
    img = insertObjectAnnotation(img,"rectangle",bboxes,scores);
end
figure
imshow(img)

[Figure: the test image with the detection results.]
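
To show only the more confident detections, you can filter the results by score before annotating the image. This is a minimal sketch, assuming detector is the trained detector from this example; the cutoff value minScore is hypothetical and should be tuned for your data.

```matlab
% Assumes detector is the trained detector from this example.
% Reread the image, because img was overwritten with annotations above.
img = imread("detectcars.png");
[bboxes,scores] = detect(detector,img);

% Keep only detections at or above a score cutoff (hypothetical value).
minScore = 0.6;
keep = scores >= minScore;
strongBoxes = bboxes(keep,:);
strongScores = scores(keep);

% Annotate and display the remaining detections, if any.
if ~isempty(strongBoxes)
    img = insertObjectAnnotation(img,"rectangle",strongBoxes,strongScores);
end
figure
imshow(img)
```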

Input Arguments

trainingData: Labeled ground truth images, specified as a datastore or a table.

  • If you use a datastore, your data must be set up so that calling the datastore with the read and readall functions returns a cell array or table with two or three columns. When the output contains two columns, the first column must contain bounding boxes, and the second column must contain labels, {boxes,labels}. When the output contains three columns, the second column must contain the bounding boxes, and the third column must contain the labels. In this case, the first column can contain any type of data. For example, the first column can contain images or point cloud data.

    The three columns are data, boxes, and labels:

      • data: The first column must be images.

      • boxes: The second column must contain M-by-4 matrices of bounding boxes of the form [x, y, width, height], where [x,y] represent the top-left coordinates of the bounding box.

      • labels: The third column must be a cell array that contains M-by-1 categorical vectors containing object class names. All categorical data returned by the datastore must contain the same categories.

    For more information, see Datastores for Deep Learning (Deep Learning Toolbox).
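
Before training, it can help to read one sample from the datastore and confirm that it has the expected layout. The following sketch assumes the vehicle data set and the combined datastore used in the example on this page; the variable names sample, boxes, and labels are illustrative.

```matlab
% Build the combined datastore from the vehicle data set
% (same steps as in the example on this page).
data = load("vehicleTrainingData.mat");
trainingData = data.vehicleTrainingData;
dataDir = fullfile(toolboxdir("vision"),"visiondata");
trainingData.imageFilename = fullfile(dataDir,trainingData.imageFilename);
imds = imageDatastore(trainingData.imageFilename);
blds = boxLabelDatastore(trainingData(:,2:end));
ds = combine(imds,blds);

% Read one sample, then rewind the datastore so that later reads
% start from the first sample again.
sample = read(ds);
reset(ds)

img = sample{1};      % image data
boxes = sample{2};    % M-by-4 matrix of [x y width height] boxes
labels = sample{3};   % M-by-1 categorical vector of class names

% Basic layout checks.
assert(size(boxes,2) == 4)
assert(iscategorical(labels))
assert(size(boxes,1) == numel(labels))
```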

detector: Untrained or pretrained SSD object detector, specified as an ssdObjectDetector object.

options: Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions (Deep Learning Toolbox) function. To specify the solver name and other options for network training, use the trainingOptions (Deep Learning Toolbox) function.

Note

The trainSSDObjectDetector function does not support these combinations of datastore inputs and training options:

  • Datastore inputs are not supported when you set the DispatchInBackground training option to true.

  • Datastore inputs are not supported if the Shuffle training option is set to "never" when the ExecutionEnvironment training option is "multi-gpu". For more information about using a datastore for parallel training, see Preprocess Data in the Background or in Parallel (Deep Learning Toolbox).

checkpoint: Saved detector checkpoint, specified as an ssdObjectDetector object. To periodically save a detector checkpoint during training, specify the CheckpointPath training option. To control how frequently checkpoints are saved, see the CheckpointFrequency and CheckpointFrequencyUnit training options.

To load a checkpoint for a previously trained detector, load the MAT file from the checkpoint path. For example, if the CheckpointPath property of the object specified by options is "/checkpath", you can load a checkpoint MAT file by using this code.

data = load("/checkpath/ssd_checkpoint__216__2018_11_16__13_34_30.mat");
checkpoint = data.detector;

The name of the MAT file includes the iteration number and the timestamp of when the detector checkpoint was saved. The detector is saved in the detector variable of the file. Pass the loaded detector back into the trainSSDObjectDetector function:

ssdDetector = trainSSDObjectDetector(trainingData,checkpoint,options);
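
When a checkpoint folder holds several files, you may want to pick the most recent one programmatically. The following sketch assumes a hypothetical checkpoint folder under tempdir and assumes that checkpoint files follow the ssd_checkpoint__ naming pattern shown above; the remaining training option values are copied from the example on this page.

```matlab
% Hypothetical checkpoint folder; create it if it does not exist yet.
checkpointDir = fullfile(tempdir,"ssdCheckpoints");
if ~isfolder(checkpointDir)
    mkdir(checkpointDir);
end

% Training options that save checkpoints into the folder.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MiniBatchSize=16, ...
    MaxEpochs=30, ...
    CheckpointPath=checkpointDir);

% Find the most recently modified checkpoint file, if any.
files = dir(fullfile(checkpointDir,"ssd_checkpoint__*.mat"));
if isempty(files)
    checkpoint = [];   % nothing saved yet, so train from the detector instead
else
    [~,newest] = max([files.datenum]);
    data = load(fullfile(files(newest).folder,files(newest).name));
    checkpoint = data.detector;
end
```

If checkpoint is not empty, you can pass it to trainSSDObjectDetector in place of the detector argument to resume training.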

Name-Value Arguments

Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

Before R2021a, use commas to separate each name and value, and enclose Name in quotes.

Example: PositiveOverlapRange=[0.5 1] uses anchor boxes that have an overlap ratio between 0.5 and 1 with a ground truth bounding box as positive training samples.

PositiveOverlapRange: Range of bounding box overlap ratios between 0 and 1, specified as a two-element vector. Anchor boxes that overlap with ground truth bounding boxes within the specified range are used as positive training samples. The function computes the overlap ratio using the intersection-over-union between two bounding boxes.

NegativeOverlapRange: Range of bounding box overlap ratios between 0 and 1, specified as a two-element vector. Anchor boxes that overlap with ground truth bounding boxes within the specified range are used as negative training samples. The function computes the overlap ratio using the intersection-over-union between two bounding boxes.
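
To get a feel for how the overlap ranges partition anchor boxes, you can compute intersection-over-union values yourself with the bboxOverlapRatio function. This is an illustrative sketch only: the boxes and the negative range are hypothetical values, and the handling of values that fall exactly on a range boundary is an assumption rather than the documented behavior of trainSSDObjectDetector.

```matlab
% One ground truth box and two candidate anchor boxes, as [x y width height].
% All values are hypothetical.
groundTruth = [30 30 60 60];
anchors = [20 20 60 60;      % overlaps the ground truth substantially
           100 100 60 60];   % does not overlap the ground truth

% Intersection-over-union of each anchor with the ground truth (2-by-1).
iou = bboxOverlapRatio(anchors,groundTruth);

% Overlap ranges (the positive range matches the example value above;
% the negative range is a hypothetical choice).
positiveRange = [0.5 1];
negativeRange = [0 0.5];

% Classify each anchor. Boundary handling here is an assumption.
isPositive = iou >= positiveRange(1) & iou <= positiveRange(2);
isNegative = iou >= negativeRange(1) & iou <  negativeRange(2);
```

With these values, the first anchor falls in the positive range and the second in the negative range. To use custom ranges in training, pass them as name-value arguments, for example trainSSDObjectDetector(trainingData,detector,options,PositiveOverlapRange=positiveRange,NegativeOverlapRange=negativeRange).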

ExperimentMonitor: Detector training experiment monitoring, specified as an experiments.Monitor (Deep Learning Toolbox) object for use with the Experiment Manager (Deep Learning Toolbox) app. You can use this object to track the progress of training, update information fields in the training results table, record values of the metrics used by the training, and produce training plots. For an example using this app, see Train Object Detectors in Experiment Manager.

Information monitored during training:

  • Training loss at each iteration.

  • Training accuracy at each iteration.

  • Training root mean square error (RMSE) for the box regression layer.

  • Learning rate at each iteration.

Validation information when the training options input contains validation data:

  • Validation loss at each iteration.

  • Validation accuracy at each iteration.

  • Validation RMSE at each iteration.
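
Validation metrics are only available when the training options contain validation data. The following sketch shows one way to hold out part of the vehicle data set for validation. The 80/20 split, the random seed, and the variable names are hypothetical choices; the other training option values are copied from the example on this page.

```matlab
% Load the vehicle data set and add the full path to the image file names.
data = load("vehicleTrainingData.mat");
tbl = data.vehicleTrainingData;
dataDir = fullfile(toolboxdir("vision"),"visiondata");
tbl.imageFilename = fullfile(dataDir,tbl.imageFilename);

% Randomly split the rows into training and validation sets
% (hypothetical 80/20 split with a fixed seed for repeatability).
rng(0)
idx = randperm(height(tbl));
numTrain = floor(0.8*height(tbl));
trainTbl = tbl(idx(1:numTrain),:);
valTbl = tbl(idx(numTrain+1:end),:);

% Create a combined datastore for each set.
dsTrain = combine(imageDatastore(trainTbl.imageFilename), ...
    boxLabelDatastore(trainTbl(:,2:end)));
dsVal = combine(imageDatastore(valTbl.imageFilename), ...
    boxLabelDatastore(valTbl(:,2:end)));

% Pass the validation datastore through the training options.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MiniBatchSize=16, ...
    MaxEpochs=30, ...
    ValidationData=dsVal);
```

You can then pass dsTrain and options to trainSSDObjectDetector together with a detector.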

Output Arguments

trainedDetector: Trained SSD object detector, returned as an ssdObjectDetector object. You can train an SSD object detector to detect multiple object classes.

info: Training progress information, returned as a structure array with the fields listed below. Each field corresponds to a stage of training.

  • TrainingLoss — Training loss at each iteration, calculated as the sum of the regression loss and the classification loss. To compute the regression loss, the trainSSDObjectDetector function uses the smooth L1 loss function. To compute the classification loss, the trainSSDObjectDetector function uses the softmax and binary cross-entropy loss functions.

  • TrainingAccuracy — Training set accuracy at each iteration.

  • TrainingRMSE — Training root mean squared error (RMSE) is the RMSE calculated from the training loss at each iteration.

  • BaseLearnRate — Learning rate at each iteration.

  • ValidationLoss — Validation loss at each iteration.

  • ValidationAccuracy — Validation accuracy at each iteration.

  • ValidationRMSE — Validation RMSE at each iteration.

  • FinalValidationLoss — Final validation loss at the end of training.

  • FinalValidationRMSE — Final validation RMSE at the end of training.

Each field is a numeric vector with one element per training iteration. Values that have not been calculated at a specific iteration are assigned as NaN. The struct contains ValidationLoss, ValidationAccuracy, ValidationRMSE, FinalValidationLoss, and FinalValidationRMSE fields only when options specifies validation data.
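
Because the validation fields are optional and can contain NaN values, code that reads them should check for both. This is a minimal sketch, assuming info is the second output of a completed trainSSDObjectDetector call; the plotting choices are illustrative.

```matlab
% Assumes info is the second output of trainSSDObjectDetector.
if isfield(info,"ValidationLoss")
    % Keep only the iterations at which validation loss was computed.
    computed = ~isnan(info.ValidationLoss);
    iterations = find(computed);

    figure
    plot(info.TrainingLoss)
    hold on
    plot(iterations,info.ValidationLoss(computed),"o-")
    hold off
    grid on
    xlabel("Number of Iterations")
    ylabel("Loss")
    legend("Training","Validation")
else
    disp("The training options did not specify validation data.")
end
```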

References

[1] W. Liu, D. Anguelov, D. Erhan, C. Szegedy, S. Reed, C.-Y. Fu, and A. C. Berg. "SSD: Single Shot MultiBox Detector." European Conference on Computer Vision (ECCV), Springer Verlag, 2016.

Version History

Introduced in R2020a
