Tips on Importing Models from TensorFlow, PyTorch, and ONNX
This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.
Import Functions of Deep Learning Toolbox
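This table lists the Deep Learning Toolbox import functions. Use these functions to import networks from TensorFlow, PyTorch, and ONNX.
Import Function | Source Platform |
---|---|
importNetworkFromTensorFlow | TensorFlow |
importNetworkFromPyTorch | PyTorch |
importNetworkFromONNX | ONNX |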
You must have the relevant support package to run these import functions. If the support package is not installed, each function provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.
Autogenerated Custom Layers
The importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions can automatically generate custom layers, or custom layers with placeholder functions, when you import TensorFlow, PyTorch, or ONNX layers that the software cannot convert into equivalent built-in MATLAB functions or layers.
These functions import an external platform layer into MATLAB by trying these steps in order:
1. The function imports the external layer as a built-in MATLAB layer.
2. The function imports the external layer as a built-in MATLAB function (for TensorFlow and PyTorch only).
3. The function imports the external layer as a custom layer.
4. The function imports the external layer as a custom layer with a placeholder function.
For more information about custom layer generation, see the Algorithms section of each function: Algorithms (TensorFlow), Algorithms (PyTorch), and Algorithms (ONNX).
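The fallback order above can be modeled as a short decision function. The sketch below is purely illustrative: the ImportResult enum, the import_layer function, and its boolean arguments are hypothetical stand-ins for checks that the import functions make internally, not part of any real toolbox API. It shows that the built-in-function step applies only to TensorFlow and PyTorch models.

```python
from enum import Enum


class ImportResult(Enum):
    """Possible outcomes of the import fallback chain, in the order they are tried."""
    BUILTIN_LAYER = 1
    BUILTIN_FUNCTION = 2  # TensorFlow and PyTorch only
    CUSTOM_LAYER = 3
    PLACEHOLDER = 4


def import_layer(platform, has_builtin_layer, has_builtin_function, can_generate_custom):
    """Return the first import strategy that succeeds for one external layer.

    The boolean arguments are hypothetical stand-ins for internal checks;
    they are not real toolbox options.
    """
    if has_builtin_layer:
        return ImportResult.BUILTIN_LAYER
    # The built-in-function step is skipped for ONNX models.
    if platform in ("TensorFlow", "PyTorch") and has_builtin_function:
        return ImportResult.BUILTIN_FUNCTION
    if can_generate_custom:
        return ImportResult.CUSTOM_LAYER
    # Last resort: a custom layer containing a placeholder function
    # that you must implement before using the network.
    return ImportResult.PLACEHOLDER


# An ONNX operator with no built-in layer falls through to a custom layer.
print(import_layer("ONNX", False, True, True))
```

Modeling the steps as early returns makes the ordering explicit: a later step is reached only when every earlier step does not apply.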
Input Dimension Ordering
The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares input dimension ordering between platforms for different input types.
Input Type | MATLAB | TensorFlow | PyTorch | ONNX |
---|---|---|---|---|
Features | CN | NC | NC | NC |
2-D image | HWCN | NHWC | NCHW | NCHW |
3-D image | HWDCN | NHWDC | NCDHW | NCHWD |
Vector sequence | CSN | NSC | SNC | NSC |
2-D image sequence | HWCSN | NSHWC | NCSHW | NSCHW |
3-D image sequence | HWDCSN | NSHWDC | NCSDHW | NSCHWD |
Dimension labels in the table:
N — Number of observations
C — Number of features or channels
H — Height of images
W — Width of images
D — Depth of images
S — Sequence length
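When you move data between platforms, you often need the axis permutation that turns one layout from the table into another. The following minimal sketch, assuming each layout is written with the single-letter labels above, computes that permutation as a 0-based tuple, which is the form accepted by the axes argument of NumPy's numpy.transpose. The MATLAB permute function expects 1-based indices, so add 1 to each entry before using it there. The helper names layout_permutation and permute_shape are hypothetical, and the sketch assumes the array is actually stored in the source layout.

```python
def layout_permutation(src, dst):
    """Return the axis order that rearranges data laid out as `src` into `dst`.

    Both arguments are dimension-label strings such as "NCHW" or "HWCN".
    Entry i of the result is the (0-based) source axis that becomes
    destination axis i.
    """
    if sorted(src) != sorted(dst):
        raise ValueError(f"{src!r} and {dst!r} do not contain the same labels")
    return tuple(src.index(label) for label in dst)


def permute_shape(shape, perm):
    """Apply an axis permutation to a shape tuple."""
    return tuple(shape[i] for i in perm)


# A batch of 8 RGB 224-by-224 images in PyTorch (NCHW) order,
# rearranged into the MATLAB (HWCN) order.
perm = layout_permutation("NCHW", "HWCN")
print(perm)                                   # (2, 3, 1, 0)
print(permute_shape((8, 3, 224, 224), perm))  # (224, 224, 3, 8)
```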
Data Formats for Prediction with dlnetwork
The importNetworkFromTensorFlow function imports a TensorFlow network as an initialized dlnetwork object. For an example, see Import TensorFlow Network and Classify Image. If the network does not have a fixed input size, the function imports the model as an uninitialized dlnetwork object without an input layer. For an example of how to initialize this network, see Import and Initialize TensorFlow Network.
The importNetworkFromPyTorch function imports a PyTorch network as an uninitialized or initialized dlnetwork object. If the imported network is uninitialized, do one of the following before you use the network:
- Add an input layer to the imported network and initialize the network by using the addInputLayer function. For an example, see Import Network from PyTorch and Add Input Layer.
- Define a dlarray object with the appropriate data format and use the initialize function to initialize the network. For an example, see Import Network from PyTorch and Initialize.
To import a PyTorch network as an initialized dlnetwork object, use the PyTorchInputSizes name-value argument. For an example, see Import Network from PyTorch using PyTorchInputSizes.
The importNetworkFromONNX function imports an ONNX network as an initialized dlnetwork object. For an example, see Import ONNX Network and Classify Image.
To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network and Classify Image. Use this table to choose the input layer and data format for each input type.
Input Type | Input Layer ** | Input Format * |
---|---|---|
Features | featureInputLayer | CB |
2-D image | imageInputLayer | SSCB |
3-D image | image3dInputLayer | SSSCB |
Vector sequence | sequenceInputLayer | CBT |
2-D image sequence | sequenceInputLayer | SSCBT |
3-D image sequence | sequenceInputLayer | SSSCBT |
* In Deep Learning Toolbox, each dimension label in a data format must be one of these:
S — Spatial
C — Channel
B — Batch observations
T — Time or sequence
U — Unspecified
** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.
For more information on data formats, see dlarray.
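The Input Format column follows from the MATLAB column of the dimension ordering table earlier in this topic: height, width, and depth become S, channels stay C, observations become B, and sequence length becomes T. dlarray then keeps formatted dimensions in S, C, B, T, U order, which is why the vector sequence ordering CSN corresponds to the format CBT. The sketch below encodes this mapping in Python as an illustration only; LABEL_MAP, CANONICAL_ORDER, and to_dlarray_format are hypothetical names, not toolbox functions.

```python
# Hypothetical mapping from the labels of the MATLAB dimension ordering
# (H, W, D, C, N, and S for sequence length) to dlarray format labels.
LABEL_MAP = {"H": "S", "W": "S", "D": "S", "C": "C", "N": "B", "S": "T"}

# Order in which dlarray arranges formatted dimensions.
CANONICAL_ORDER = "SCBTU"


def to_dlarray_format(matlab_ordering):
    """Convert a MATLAB dimension ordering such as "HWCN" to a format such as "SSCB"."""
    labels = [LABEL_MAP[d] for d in matlab_ordering]
    # sorted() is stable, so repeated spatial labels keep their relative order.
    return "".join(sorted(labels, key=CANONICAL_ORDER.index))


for ordering in ["CN", "HWCN", "HWDCN", "CSN", "HWCSN", "HWDCSN"]:
    print(ordering, "->", to_dlarray_format(ordering))
```

Each printed format matches the corresponding row of the Input Format column.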
Input Data Preprocessing
To use a pretrained network for prediction or transfer learning on new images, you must preprocess your images in the same way as the images that were used to train the imported model. The most common preprocessing steps are resizing images, subtracting image average values, and converting the images from BGR format to RGB format.
For more information about preprocessing images for training and prediction, see Preprocess Images for Deep Learning.
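Two of these steps, reordering channels from BGR to RGB and subtracting average values (here, one value per channel), can be sketched without any imaging library. The function below is a minimal illustration that operates on a nested list of pixel tuples; the function name preprocess_pixel_grid and the mean values in the usage line are hypothetical, and a real model needs the exact averages and channel order used during its training. Resizing is omitted.

```python
def preprocess_pixel_grid(image_bgr, channel_means):
    """Convert an H-by-W grid of BGR pixels to mean-subtracted RGB pixels.

    image_bgr: list of rows, each row a list of (B, G, R) tuples.
    channel_means: (R, G, B) averages to subtract; hypothetical values that
    must match whatever the original model was trained with.
    """
    out = []
    for row in image_bgr:
        new_row = []
        for b, g, r in row:
            rgb = (r, g, b)  # reverse the channel order: BGR -> RGB
            new_row.append(tuple(v - m for v, m in zip(rgb, channel_means)))
        out.append(new_row)
    return out


pixels = [[(10, 20, 30), (0, 0, 255)]]
print(preprocess_pixel_grid(pixels, (1, 2, 3)))  # [[(29, 18, 7), (254, -2, -3)]]
```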
See Also
importNetworkFromONNX | importNetworkFromPyTorch | importNetworkFromTensorFlow | dlarray
Related Topics
- Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX
- Pretrained Deep Neural Networks