updateScore
Description
prunableNet_new = updateScore(prunableNet,pruningActivations,pruningGradients) computes and accumulates Taylor-based importance scores of convolution filters in prunable layers. This function returns another TaylorPrunableNetwork object whose state contains these updated scores.
To get robust estimates of the importance scores of the convolution filters in your network, execute updateScore several times on the same prunable network for different mini-batches of data.
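The idea of accumulating scores over several mini-batches can be illustrated with a plain-MATLAB sketch that needs no toolbox. It assumes the common first-order Taylor criterion (the absolute value of the activation-gradient product, averaged over spatial positions and observations); the exact computation that updateScore performs internally is not described on this page. The array sizes, the random stand-in data, and all variable names are hypothetical.

```matlab
% Sketch only: stand-in data, not the internals of updateScore.
rng(0);                          % reproducible random stand-in data
numFilters = 4;                  % hypothetical number of filters in one layer
numBatches = 3;                  % hypothetical number of mini-batches
accumulatedScore = zeros(numFilters,1);

for k = 1:numBatches
    % Stand-in activations and gradients, sized H-by-W-by-C-by-B.
    A = randn(6,6,numFilters,8);
    G = randn(6,6,numFilters,8);

    % Assumed criterion: spatial mean of A.*G, absolute value, then batch mean.
    perObservation = abs(mean(A.*G,[1 2]));        % 1-by-1-by-C-by-B
    batchScore = squeeze(mean(perObservation,4));  % C-by-1

    % Accumulate across mini-batches, as repeated updateScore calls would.
    accumulatedScore = accumulatedScore + batchScore;
end

averageScore = accumulatedScore/numBatches;
[~,leastImportant] = min(averageScore);  % candidate filter to remove first
```

Averaging over more mini-batches reduces the influence of any single batch on which filter ranks lowest, which is the motivation for calling updateScore repeatedly before pruning.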
To prune a deep neural network, you require the Deep Learning Toolbox™ Model Quantization Library support package. This support package is a free add-on that you can download using the Add-On Explorer. Alternatively, see Deep Learning Toolbox Model Quantization Library.
Examples
Prune dlnetwork Object to Compress the Model
This example shows how to prune a dlnetwork object by using a custom pruning loop.
Load dlnetwork Object
Load a trained dlnetwork object and the corresponding classes.
s = load("digitsCustom.mat");
dlnet_1 = s.dlnet;
classes = s.classes;
Inspect the layers of the dlnetwork object. The network has three convolution layers at locations 2, 5, and 8 of the Layer array.
layers_1 = dlnet_1.Layers
layers_1 = 
  12x1 Layer array with layers:

     1   'input'     Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv1'     2-D Convolution       20 5x5x1 convolutions with stride [1 1] and padding [0 0 0 0]
     3   'bn1'       Batch Normalization   Batch normalization with 20 channels
     4   'relu1'     ReLU                  ReLU
     5   'conv2'     2-D Convolution       20 3x3x20 convolutions with stride [1 1] and padding [1 1 1 1]
     6   'bn2'       Batch Normalization   Batch normalization with 20 channels
     7   'relu2'     ReLU                  ReLU
     8   'conv3'     2-D Convolution       20 3x3x20 convolutions with stride [1 1] and padding [1 1 1 1]
     9   'bn3'       Batch Normalization   Batch normalization with 20 channels
    10   'relu3'     ReLU                  ReLU
    11   'fc'        Fully Connected       10 fully connected layer
    12   'softmax'   Softmax               softmax
Load Data for Prediction
Load the digits data for prediction.
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
Partition the data into pruning and validation sets. Set aside 10% of the data for validation using the splitEachLabel function.
[imdsPrune,imdsValidation] = splitEachLabel(imds,0.9,"randomize");
The network used in this example requires input images of size 28-by-28-by-1. To automatically resize the images, use augmented image datastores.
inputSize = [28 28 1];
augimdsPrune = augmentedImageDatastore(inputSize(1:2),imdsPrune);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Prune dlnetwork Object
Convert the dlnetwork object to a representation that is suitable for pruning by using the taylorPrunableNetwork function. This function returns a TaylorPrunableNetwork object that has the NumPrunables property set to 48. This indicates that 48 filters in the original model are suitable for pruning by using the Taylor pruning algorithm.
prunableNet_1 = taylorPrunableNetwork(dlnet_1)
prunableNet_1 = 
  TaylorPrunableNetwork with properties:

      Learnables: [14x3 table]
           State: [6x3 table]
      InputNames: {'input'}
     OutputNames: {'softmax'}
    NumPrunables: 48
Create a minibatchqueue object that processes and manages mini-batches of images during pruning. For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to convert the labels to one-hot encoded variables.
- Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not format the class labels.
- Use a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;
mbq = minibatchqueue(augimdsPrune, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);
Calculate Taylor-based importance scores of the prunable filters in the network by looping over the mini-batches of data. For each mini-batch:
- Calculate pruning activations and pruning gradients by using the modelLoss function defined at the end of this example.
- Update importance scores of the prunable filters by using the updateScore function.
while hasdata(mbq)
    [X,T] = next(mbq);
    [~,pruningActivations,pruningGradients] = dlfeval(@modelLoss,prunableNet_1,X,T);
    prunableNet_1 = updateScore(prunableNet_1,pruningActivations,pruningGradients);
end
Finally, remove filters with the lowest importance scores to create a new TaylorPrunableNetwork object by using the updatePrunables function. By default, a single call to this function removes 8 filters. Observe that the new network prunableNet_2 has 40 prunable filters remaining.
prunableNet_2 = updatePrunables(prunableNet_1)
prunableNet_2 = 
  TaylorPrunableNetwork with properties:

      Learnables: [14x3 table]
           State: [6x3 table]
      InputNames: {'input'}
     OutputNames: {'softmax'}
    NumPrunables: 40
To further compress the model, run the custom pruning loop and update prunables again.
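Repeating the scoring loop and the update step might look like the following sketch. It reuses mbq, modelLoss, and prunableNet_2 from this example; the iteration count numPruningIterations and the working variable prunableNet are hypothetical, and the sketch omits the fine-tuning that a practical pruning workflow would typically run between pruning steps.

```matlab
% Sketch only: repeat "score over all mini-batches, then prune" several times.
numPruningIterations = 3;      % hypothetical value chosen for illustration
prunableNet = prunableNet_2;   % continue from the network pruned above

for iter = 1:numPruningIterations
    % Return the queue to the start of the data so every pass sees all mini-batches.
    reset(mbq);

    % Accumulate importance scores over the mini-batches.
    while hasdata(mbq)
        [X,T] = next(mbq);
        [~,pruningActivations,pruningGradients] = dlfeval(@modelLoss,prunableNet,X,T);
        prunableNet = updateScore(prunableNet,pruningActivations,pruningGradients);
    end

    % Remove the filters with the lowest accumulated scores.
    prunableNet = updatePrunables(prunableNet);
end
```

After the loop, the NumPrunables property of prunableNet reports how many prunable filters remain.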
Extract Pruned dlnetwork Object
Use the dlnetwork function to extract the pruned dlnetwork object from the pruned TaylorPrunableNetwork object. You can now use this compressed dlnetwork object to perform inference.
dlnet_2 = dlnetwork(prunableNet_2);
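As one way to use the pruned network for inference, the following sketch estimates classification accuracy on the validation set that was set aside earlier. It is an illustration rather than part of the original example: the names mbqValidation, classNames, numCorrect, numObservations, and accuracy are hypothetical, and converting classes with string assumes that classes is a categorical, string, or cell array of class names.

```matlab
% Sketch only: validation accuracy of the pruned network dlnet_2.
mbqValidation = minibatchqueue(augimdsValidation, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);

classNames = string(classes);   % assumption: classes converts cleanly to strings
numCorrect = 0;
numObservations = 0;

while hasdata(mbqValidation)
    [X,T] = next(mbqValidation);

    % Predict class scores and decode them along the class dimension (dimension 1).
    Y = predict(dlnet_2,X);
    predictedLabels = onehotdecode(Y,classNames,1);
    trueLabels = onehotdecode(extractdata(T),classNames,1);

    numCorrect = numCorrect + sum(predictedLabels == trueLabels);
    numObservations = numObservations + numel(trueLabels);
end

accuracy = numCorrect/numObservations
```

Running the same loop with dlnet_1 in place of dlnet_2 gives a baseline for judging how much accuracy the pruning step costs.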
Compare the convolution layers of the original and the pruned dlnetwork objects. Observe that the three convolution layers in the pruned network have fewer filters: conv1, conv2, and conv3 lose 3, 2, and 3 filters, respectively. These counts agree with the fact that, by default, a single call to the updatePrunables function removes 8 filters from the network.
conv_layers_1 = dlnet_1.Layers([2 5 8])
conv_layers_1 = 
  3x1 Convolution2DLayer array with layers:

     1   'conv1'   2-D Convolution   20 5x5x1 convolutions with stride [1 1] and padding [0 0 0 0]
     2   'conv2'   2-D Convolution   20 3x3x20 convolutions with stride [1 1] and padding [1 1 1 1]
     3   'conv3'   2-D Convolution   20 3x3x20 convolutions with stride [1 1] and padding [1 1 1 1]
conv_layers_2 = dlnet_2.Layers([2 5 8])
conv_layers_2 = 
  3x1 Convolution2DLayer array with layers:

     1   'conv1'   2-D Convolution   17 5x5x1 convolutions with stride [1 1] and padding [0 0 0 0]
     2   'conv2'   2-D Convolution   18 3x3x17 convolutions with stride [1 1] and padding [1 1 1 1]
     3   'conv3'   2-D Convolution   17 3x3x18 convolutions with stride [1 1] and padding [1 1 1 1]
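The effect on model size can be checked with simple arithmetic on the filter shapes displayed above. This plain-MATLAB sketch counts only the convolution weights and biases (kernel height × kernel width × input channels × filters, plus one bias per filter); it ignores the batch normalization and fully connected layers, which also shrink. The variable names are hypothetical.

```matlab
% Sketch only: learnable parameters in the three convolution layers,
% using the filter shapes shown in the displays above.
kernelArea = [25 9 9];          % 5x5, 3x3, and 3x3 kernels

inBefore  = [1 20 20];          % input channels before pruning
outBefore = [20 20 20];         % filters before pruning
inAfter   = [1 17 18];          % input channels after pruning
outAfter  = [17 18 17];         % filters after pruning

paramsBefore = kernelArea.*inBefore.*outBefore + outBefore;   % [520 3620 3620]
paramsAfter  = kernelArea.*inAfter.*outAfter + outAfter;      % [442 2772 2771]

filtersRemoved = sum(outBefore - outAfter);                   % 8
totalBefore = sum(paramsBefore);                              % 7760
totalAfter  = sum(paramsAfter);                               % 5985
```

Removing a filter shrinks both its own layer and the input depth of the next convolution layer, which is why conv2 and conv3 lose more parameters than their filter counts alone suggest.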
Supporting Functions
Model Loss Function
The modelLoss function takes a TaylorPrunableNetwork object net and a mini-batch of input data X with corresponding targets T, and returns the loss, the pruning activations in net, and the gradients of the loss with respect to those activations. To compute the gradients automatically, this function uses the dlgradient function.
function [loss, pruningActivations, pruningGradients] = modelLoss(net,X,T)

% Calculate network output for training.
[out, ~, pruningActivations] = forward(net,X);

% Calculate loss.
loss = crossentropy(out,T);

% Compute pruning gradients.
pruningGradients = dlgradient(loss,pruningActivations);

end
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:
1. Preprocess the images using the preprocessMiniBatchPredictors function.
2. Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension.
3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
function [X,T] = preprocessMiniBatch(dataX,dataT)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX);

% Extract label data from cell and concatenate.
T = cat(2,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end
Mini-Batch Predictors Preprocessing Function
The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array, concatenating it into a numeric array, and rescaling the pixel values by dividing by 255. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension.
function X = preprocessMiniBatchPredictors(dataX)

% Concatenate.
X = cat(4,dataX{1:end});

% Normalize the images.
X = X/255;

end
Input Arguments
prunableNet — Network for pruning by using first-order Taylor approximation
TaylorPrunableNetwork object
Network for pruning by using first-order Taylor approximation, specified as a TaylorPrunableNetwork object.
pruningActivations — Activations of the pruning layers
cell array containing dlarray objects
Activations of the pruning layers, specified as a cell array containing dlarray objects. To retrieve these values, call the forward function on the prunable network.
pruningGradients — Gradients of loss with respect to activations
cell array containing dlarray objects
Gradients of loss with respect to pruningActivations, specified as a cell array containing dlarray objects. To calculate pruningGradients, first calculate the loss and then use the dlgradient function.
Output Arguments
prunableNet_new — Updated network for pruning
TaylorPrunableNetwork object
Network object for pruning that has been updated to contain the accumulated Taylor-based importance scores of the prunable filters, returned as a TaylorPrunableNetwork object.
Version History
Introduced in R2022a