Create Custom Deep Learning Training Plot
This example shows how to create a custom training plot that updates at each iteration during training of deep learning neural networks using trainnet.
You can specify neural network training options using trainingOptions. You can create training plots using the "Plots" and "Metrics" name-value pair arguments. To create a custom training plot and further customize training beyond the options available in trainingOptions, specify output functions by using the "OutputFcn" name-value pair argument of trainingOptions. trainnet calls these functions once before the start of training, after each training iteration, and once after training has finished. Each time the output functions are called, trainnet passes a structure containing information such as the current iteration number, loss, and accuracy.
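For example, a minimal output function (the name myOutputFcn is hypothetical) accepts the info structure and returns a logical scalar that tells trainnet whether to stop training:

function stop = myOutputFcn(info)
% Display the current iteration number and continue training.
disp("Iteration: " + info.Iteration)
stop = false;
end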
The network trained in this example classifies the gear tooth condition of a transmission system into two categories, "Tooth Fault" and "No Tooth Fault", based on a mixture of numeric sensor readings, statistics, and categorical labels. For more information, see Train Neural Network with Tabular Data.
The custom output function defined in this example plots the natural logarithms of the gradient norm, step norm, training loss, and validation loss during training, and stops training early once the training loss drops below a desired loss threshold.
Load and Preprocess Training Data
Read the transmission casing data from the CSV file "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="String");
Convert the labels for prediction and the categorical predictors to categorical using the convertvars function. In this data set, there are two categorical features, "SensorCondition" and "ShaftCondition".

labelName = "GearToothCondition";
categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,[labelName categoricalPredictorNames],"categorical");
To train a network using categorical features, you must convert the categorical features to numeric. You can do this using the onehotencode function.

for i = 1:numel(categoricalPredictorNames)
    name = categoricalPredictorNames(i);
    tbl.(name) = onehotencode(tbl.(name),2);
end
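For example, one-hot encoding a categorical vector replaces each entry with a numeric row vector that has a 1 in the column of its category (a small illustration with made-up values):

c = categorical(["Low";"High";"Low"]);
onehotencode(c,2)
ans = 3×2
     0     1
     1     0
     0     1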
Set aside data for testing. Partition the data into a training set containing 80% of the data, a validation set containing 10% of the data, and a test set containing the remaining 10% of the data. To partition the data, use the trainingPartitions function, attached to this example as a supporting file. To access this file, open the example as a live script.

numObservations = size(tbl,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.80 0.1 0.1]);

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);
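If you do not have the supporting file, a minimal sketch of such a helper could randomly permute the observation indices and split them in the requested proportions. This is an assumption about its behavior, not the shipped implementation:

function varargout = trainingPartitions(numObservations,splits)
% Randomly partition the indices 1:numObservations into numel(splits)
% index vectors with sizes proportional to splits. (Hypothetical sketch.)
idx = randperm(numObservations);
edges = [0 round(cumsum(splits)*numObservations)];
varargout = cell(1,numel(splits));
for k = 1:numel(splits)
    varargout{k} = idx(edges(k)+1:edges(k+1));
end
end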
Convert the data to a format that the trainnet function supports. Convert the predictors and targets to numeric and categorical arrays, respectively, using the table2array function.

predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
    "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
    "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
    "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];

XTrain = table2array(tblTrain(:,predictorNames));
TTrain = tblTrain.(labelName);

XValidation = table2array(tblValidation(:,predictorNames));
TValidation = tblValidation.(labelName);

XTest = table2array(tblTest(:,predictorNames));
TTest = tblTest.(labelName);
Network Architecture
Define the neural network architecture.
For feature input, specify a feature input layer with the number of features. Normalize the input using Z-score normalization.
Specify a fully connected layer with a size of 16, followed by a layer normalization and ReLU layer.
For classification output, specify a fully connected layer with a size that matches the number of classes, followed by a softmax layer.
numFeatures = size(XTrain,2);
hiddenSize = 16;
classNames = categories(tbl{:,labelName});
numClasses = numel(classNames);
layers = [
featureInputLayer(numFeatures,Normalization="zscore")
fullyConnectedLayer(hiddenSize)
layerNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
Initialize Custom Plot
The next step in training your neural network is to define the training options, including your custom output function. To create an animated plot that updates during training, first set up a figure with animatedline objects and then pass the line handles to your output function.
Create a 3-by-1 tiled chart layout. Define animatedline objects to plot gradientNorm in the top tile and stepNorm in the middle tile. Define animatedline objects to plot trainingLoss and validationLoss in the bottom tile. Save the animatedline handles to a struct called lines.

tiledlayout(3,1);
C = colororder;

nexttile
lines.gradientNormLine = animatedline(Color=C(1,:));
ylabel("log(gradientNorm)")
ylim padded

nexttile
lines.stepNormLine = animatedline(Color=C(1,:));
ylabel("log(stepNorm)")
ylim padded

nexttile
lines.trainingLossLine = animatedline(Color=C(1,:));
lines.validationLossLine = animatedline(Color=C(2,:));
xlabel("Iterations")
ylabel("log(loss)")
ylim padded
Define Training Options
Use the function updatePlotAndStopTraining, defined at the bottom of this page, to update the animatedline objects and to stop training early when the training loss is smaller than a desired loss threshold. Use the "OutputFcn" name-value pair argument of trainingOptions to pass this function to trainnet.
Specify the training options:
Train using the L-BFGS solver. This solver is well suited to small networks and to data sets that fit in memory.
Train using the CPU. Because the network and data are small, the CPU is better suited.
Validate the network every 5 iterations using the validation data.
Suppress the verbose output.
Include the custom output function updatePlotAndStopTraining.
Define the loss threshold.
lossThreshold = 0.4;

options = trainingOptions("lbfgs", ...
    ExecutionEnvironment="cpu", ...
    ValidationData={XValidation,TValidation}, ...
    ValidationFrequency=5, ...
    Verbose=false, ...
    OutputFcn=@(info)updatePlotAndStopTraining(info,lines,lossThreshold));
Add the gradient tolerance to the top plot, the step tolerance to the middle plot, and the loss threshold to the bottom plot.
nexttile(1)
yline(log(options.GradientTolerance),"--","log(gradientTolerance)")

nexttile(2)
yline(log(options.StepTolerance),"--","log(stepTolerance)")

nexttile(3)
yline(log(lossThreshold),"--","log(lossThreshold)")
legend(["log(trainingLoss)","log(validationLoss)",""],"Location","eastoutside")
Train Neural Network
Train the network. To display the reason why training stops, use two output arguments. If training stops because the training loss is smaller than the loss threshold, then the StopReason field of the info output argument is "Stopped by OutputFcn".
[net,info] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
disp(info.StopReason)
Stopped by OutputFcn
Test Network
Predict the labels of the test data using the trained network. Predict the classification scores, then convert the predictions to labels using the onehotdecode function.

scoresTest = predict(net,XTest);
YTest = onehotdecode(scoresTest,classNames,2);

accuracy = mean(YTest==TTest)
accuracy = 0.8636
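To see how the predictions break down by class, you can also visualize the test results in a confusion chart. This step is an optional addition, not part of the original example:

figure
confusionchart(TTest,YTest)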
Custom Output Function
Define the output function updatePlotAndStopTraining(info,lines,lossThreshold), which plots the logarithm of the gradient norm, step norm, training loss, and validation loss. It also stops training when the training loss is smaller than the loss threshold. Training stops when the output function returns true.
function stop = updatePlotAndStopTraining(info,lines,lossThreshold)

% Extract the training state from the info structure.
iteration = info.Iteration;
gradientNorm = info.GradientNorm;
stepNorm = info.StepNorm;
trainingLoss = info.TrainingLoss;
validationLoss = info.ValidationLoss;

% Update the training curves. The norms and loss are empty before the
% first iteration, so only plot when values are available.
if ~isempty(trainingLoss)
    addpoints(lines.gradientNormLine,iteration,log(gradientNorm))
    addpoints(lines.stepNormLine,iteration,log(stepNorm))
    addpoints(lines.trainingLossLine,iteration,log(trainingLoss))
end

% The validation loss is nonempty only on validation iterations.
if ~isempty(validationLoss)
    addpoints(lines.validationLossLine,iteration,log(validationLoss))
end

% Stop training early once the training loss drops below the threshold.
stop = ~isempty(trainingLoss) && trainingLoss < lossThreshold;
end
See Also
trainnet | trainingOptions | fullyConnectedLayer | Deep Network Designer | featureInputLayer