# Train Network with LSTM Projected Layer

Train a deep learning network with an LSTM projected layer for sequence-to-label classification.

To compress a deep learning network, you can use projected layers. A projected layer introduces learnable projector matrices $Q$, replaces multiplications of the form $Wx$, where $W$ is a learnable matrix, with the multiplication $WQQ^{\top}x$, and stores $Q$ and $W' = WQ$ instead of storing $W$. Projecting $x$ into a lower-dimensional space using $Q$ typically requires less memory to store the learnable parameters and can achieve similarly strong prediction accuracy.
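To make the storage argument concrete, here is a minimal sketch that uses plain matrices rather than a deep learning layer. The sizes `m`, `n`, and `k` and the random orthonormal `Q` are illustrative assumptions; a projected layer learns its projector rather than drawing it at random.

```matlab
% Illustrative sizes (assumptions): W is m-by-n, and Q projects n dimensions down to k.
m = 400; n = 100; k = 25;

rng(0)                     % make the random matrices reproducible
W = randn(m,n);            % stand-in for a learnable weight matrix
Q = orth(randn(n,k));      % n-by-k matrix with orthonormal columns
x = randn(n,1);            % stand-in for a layer input

% The projected form stores Q and WPrime = W*Q instead of W.
WPrime = W*Q;

% Both expressions compute W*Q*Q'*x; the second uses only the stored matrices.
yReference = W*(Q*(Q'*x));
yProjected = WPrime*(Q'*x);
maxAbsDifference = max(abs(yReference - yProjected))

% Number of stored values without and with projection.
numParamsOriginal  = numel(W)                   % m*n
numParamsProjected = numel(WPrime) + numel(Q)   % m*k + n*k
```

With these sizes, the unprojected matrix holds 40,000 values, whereas the projected pair holds 12,500. The saving grows as `k` shrinks relative to `n`.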

Projecting an LSTM layer, rather than reducing its number of hidden units, reduces the number of learnable parameters while maintaining the output size of the layer and, in turn, the sizes of the downstream layers. This can result in better prediction accuracy.

These charts compare the test accuracy and the number of learnable parameters of the LSTM network and the projected LSTM network that you train in this example.

In this example, you train an LSTM network for sequence classification, then train an equivalent network with an LSTM projected layer. You then compare the test accuracy and the number of learnable parameters for each of the networks.

Load the Japanese Vowels data set described in [1] and [2]. `XTrain` is a cell array containing 270 sequences of varying length with 12 features corresponding to LPC cepstrum coefficients. `TTrain` is a categorical vector of labels 1, 2, ..., 9. The entries in `XTrain` are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step).

`[XTrain,TTrain] = japaneseVowelsTrainData;`
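As a quick sanity check of the layout described above, you can inspect the loaded data. This is a minimal sketch; the variable names `numObservations`, `sequenceLengths`, and `classNames` are introduced here for illustration only.

```matlab
[XTrain,TTrain] = japaneseVowelsTrainData;

numObservations = numel(XTrain)                      % number of sequences
numFeatures = size(XTrain{1},1)                      % rows correspond to features
sequenceLengths = cellfun(@(x) size(x,2),XTrain);    % columns correspond to time steps
[min(sequenceLengths) max(sequenceLengths)]          % shortest and longest sequence
classNames = categories(TTrain)                      % class labels
```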

Visualize the first time series in a plot. Each line corresponds to a feature.

```matlab
figure
plot(XTrain{1}')
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),Location="northeastoutside")
```

### Define Network Architecture

Define the LSTM network architecture.

• Specify a sequence input layer with an input size matching the number of features of the input data.

• Specify an LSTM layer with 100 hidden units that outputs the last element of the sequence.

• Specify a fully connected layer of a size equal to the number of classes, followed by a softmax layer and a classification layer.

```matlab
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
```

### Specify Training Options

Specify the training options.

• Train using the Adam solver.

• Train with a mini-batch size of 27 for 50 epochs.

• Because the mini-batches are small and the sequences are short, the CPU is better suited for training, so train using the CPU.

• Display the training progress in a plot and suppress the verbose output.

```matlab
maxEpochs = 50;
miniBatchSize = 27;

options = trainingOptions("adam", ...
    MiniBatchSize=miniBatchSize, ...
    MaxEpochs=maxEpochs, ...
    ExecutionEnvironment="cpu", ...
    Plots="training-progress", ...
    Verbose=false);
```
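For a sense of how long training runs, the following sketch works out the number of iterations implied by these options. It assumes the 270 training sequences stated earlier; the names `iterationsPerEpoch` and `totalIterations` are illustrative.

```matlab
numObservations = 270;   % number of training sequences in XTrain
miniBatchSize = 27;
maxEpochs = 50;

% The mini-batch size divides the training set evenly, so every epoch uses all sequences.
iterationsPerEpoch = numObservations/miniBatchSize   % 10
totalIterations = iterationsPerEpoch*maxEpochs       % 500
```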

### Train Network

Train the LSTM network with the specified training options.

`net = trainNetwork(XTrain,TTrain,layers,options);`

### Test Network

Calculate the classification accuracy of the predictions on the test data.

```matlab
[XTest,TTest] = japaneseVowelsTestData;
YTest = classify(net,XTest,MiniBatchSize=miniBatchSize);
acc = sum(YTest == TTest)./numel(TTest)
```

```
acc = 0.9297
```

View the number of learnables of the network using the `analyzeNetwork` function.

`analyzeNetwork(net)`

In order to compare the total number of learnable parameters of each network, store the total number of learnable parameters in a variable.

`totalLearnables = 46100;`
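The stored value can be cross-checked by hand. The sketch below assumes the usual LSTM parameterization with four gates, each having input weights, recurrent weights, and a bias, plus a fully connected layer with weights and a bias. The variable names are illustrative.

```matlab
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

% LSTM layer: four gates, each with input weights, recurrent weights, and a bias.
numInputWeights     = 4*numHiddenUnits*inputSize;        % 4800
numRecurrentWeights = 4*numHiddenUnits*numHiddenUnits;   % 40000
numBias             = 4*numHiddenUnits;                  % 400
numLearnablesLSTM = numInputWeights + numRecurrentWeights + numBias   % 45200

% Fully connected layer: weights plus bias.
numLearnablesFC = numClasses*numHiddenUnits + numClasses   % 909

totalExact = numLearnablesLSTM + numLearnablesFC   % 46109
```

The exact count of 46,109 agrees with the stored value of 46100 to three significant figures.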

### Train Projected LSTM Network

Create an identical network with an LSTM projected layer in place of the LSTM layer.

For the LSTM projected layer:

• Specify the same number of hidden units as the LSTM layer.

• Specify an output projector size of 25% of the number of hidden units.

• Specify an input projector size of 75% of the input size.

• Ensure that the output and input projector sizes are positive by taking the maximum of the sizes and 1.

```matlab
outputProjectorSize = max(1,floor(0.25*numHiddenUnits));
inputProjectorSize = max(1,floor(0.75*inputSize));

layersProjected = [ ...
    sequenceInputLayer(inputSize)
    lstmProjectedLayer(numHiddenUnits,outputProjectorSize,inputProjectorSize,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
```

Train the projected LSTM network with the same data and training options.

`netProjected = trainNetwork(XTrain,TTrain,layersProjected,options);`

### Test Projected Network

Calculate the classification accuracy of the predictions on the test data.

```matlab
[XTest,TTest] = japaneseVowelsTestData;
YTest = classify(netProjected,XTest,MiniBatchSize=miniBatchSize);
accProjected = sum(YTest == TTest)./numel(TTest)
```

```
accProjected = 0.8784
```

View the number of learnables of the network using the `analyzeNetwork` function.

`analyzeNetwork(netProjected)`

In order to compare the total number of learnable parameters of each network, store the total number of learnable parameters in a variable.

`totalLearnablesProjected = 17500;`
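The stored value can be cross-checked by hand. The sketch below assumes that the projected layer keeps the four-gate LSTM structure, with input weights sized by the input projector, recurrent weights sized by the output projector, and the two projector matrices stored alongside them. The variable names are illustrative.

```matlab
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;
outputProjectorSize = max(1,floor(0.25*numHiddenUnits));   % 25
inputProjectorSize  = max(1,floor(0.75*inputSize));        % 9

% Gate weights act on the projected input and the projected hidden state.
numInputWeights     = 4*numHiddenUnits*inputProjectorSize;    % 3600
numRecurrentWeights = 4*numHiddenUnits*outputProjectorSize;   % 10000
numBias             = 4*numHiddenUnits;                       % 400

% Projector matrices.
numInputProjector  = inputSize*inputProjectorSize;            % 108
numOutputProjector = numHiddenUnits*outputProjectorSize;      % 2500

numLearnablesProjectedLSTM = numInputWeights + numRecurrentWeights + ...
    numBias + numInputProjector + numOutputProjector          % 16608

% Fully connected layer: weights plus bias, unchanged by the projection.
numLearnablesFC = numClasses*numHiddenUnits + numClasses      % 909

totalExactProjected = numLearnablesProjectedLSTM + numLearnablesFC   % 17517
```

The exact count of 17,517 agrees with the stored value of 17500 to three significant figures. Most of the saving comes from the recurrent weights, which shrink from 40,000 to 10,000 values.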

### Compare Networks

Compare the test accuracy and number of learnables in each network. Depending on the projection sizes, the projected network can have significantly fewer learnable parameters and still maintain strong prediction accuracy.

Create a bar chart showing the test accuracy of each network.

```matlab
figure
bar([acc accProjected])
xticklabels(["Unprojected","Projected"])
xlabel("Network")
ylabel("Test Accuracy")
title("Test Accuracy")
```

Create a bar chart showing the number of learnables of each network.

```matlab
figure
bar([totalLearnables totalLearnablesProjected])
xticklabels(["Unprojected","Projected"])
xlabel("Network")
ylabel("Number of Learnables")
title("Number of Learnables")
```
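To put numbers on the trade-off, this minimal sketch summarizes the comparison using the values reported in this example. The names `accReported` and `learnablesReported` are illustrative, and the accuracies can differ if you rerun training.

```matlab
% Values reported earlier in this example, ordered as [unprojected projected].
accReported = [0.9297 0.8784];
learnablesReported = [46100 17500];

% Fraction of learnable parameters removed by the projection.
learnablesReduction = 1 - learnablesReported(2)/learnablesReported(1)

% Drop in test accuracy, as a fraction of the test observations.
accuracyDifference = accReported(1) - accReported(2)
```

For the reported values, the projected network has about 62% fewer learnables, at a cost of roughly 5 percentage points of test accuracy.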

### Bibliography

1. M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, 1999, pages 1103–1111.

2. UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels