importNetworkFromONNX did not recognize softmax layer

S. Cho, June 20, 2025 (edited June 21, 2025)
Hello,
I am using importNetworkFromONNX to import a neural network model exported from PyTorch.
The PyTorch model includes a softmax layer, as shown below:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TestNetwork(nn.Module):
    def __init__(self, input_dim=4, output_dim=2, hidden_dim=5):
        super(TestNetwork, self).__init__()
        self.fc1_pi = nn.Linear(input_dim, hidden_dim)   # hidden layer shared by both outputs
        self.fc2_pi = nn.Linear(hidden_dim, output_dim)  # 2-unit output feeding the softmax
        self.fc1_v = nn.Linear(hidden_dim, 1)            # 1-unit output

    def forward(self, x):
        x1 = F.relu(self.fc1_pi(x))
        x1 = self.fc2_pi(x1)
        prob = F.softmax(x1, dim=0)  # softmax over dim 0, the batch dimension of x
        x2 = F.relu(self.fc1_pi(x))
        v = self.fc1_v(x2)
        return prob, v

model = TestNetwork()
x = torch.rand(1, 4)  # one sample with 4 features
model.to("cpu")
torch.onnx.export(model, x, 'onnx_model.onnx')
The picture below shows the model's Netron view. (The onnx_model.onnx file is attached as a zip file.)
However, when I imported the ONNX model, MATLAB did not recognize the softmax layer.
I know that I can replace the layer with MATLAB's softmax layer, but I want to know how to import the ONNX model without replacing it.
Below is the code (test_import_onnx.m) that I used to import the ONNX model:
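clear
modelfile = "onnx_model.onnx";
% The ONNX input keeps PyTorch's (batch, feature) order, hence 'BC'
net = importNetworkFromONNX(modelfile, InputDataFormats='BC');
% Initialize with 4 features ("C") and an unspecified batch size ("B")
layout = networkDataLayout([4 NaN],"CB");
net = initialize(net, layout);
% Expand nested layers so that every layer is listed individually
net = expandLayers(net);
net.Layers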
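The script describes the same input in two orders: InputDataFormats='BC' matches the PyTorch input x of shape (1, 4), batch first, while networkDataLayout([4 NaN],"CB") lists the 4 features first. Moving between the two orders is just a transpose, but the channel (feature) axis changes index along the way, so a softmax "over channels" is axis 1 in one order and axis 0 in the other. A minimal NumPy sketch, where `to_cb` is a hypothetical helper and the input values are made up:

```python
import numpy as np

def to_cb(x_bc):
    """Hypothetical helper: reorder a (batch, channel) array to (channel, batch)."""
    return x_bc.T

def softmax(z, axis):
    """Numerically stable softmax of z along the given axis."""
    shifted = z - z.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

# 2 samples, 4 features, in PyTorch / ONNX (batch, channel) order
x_bc = np.array([[0.1, 0.4, -0.3, 0.8],
                 [0.5, -0.2, 0.0, 0.3]])
x_cb = to_cb(x_bc)  # 4 features, 2 samples, in "CB" order
print(x_cb.shape)   # (4, 2)

# The channel axis is index 1 in BC order and index 0 in CB order
assert np.allclose(softmax(x_bc, axis=1).T, softmax(x_cb, axis=0))
```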
The result was:
>> test_import_onnx
ans = 
  9x1 Layer array with layers:

     1   'onnx__Gemm_0'                     Feature Input                 4 features
     2   'onnx__Gemm_0_BatchSizeVerifier'   Verify the fixed batch size   Verify the fixed batch size of 1
     3   'x_fc1_pi_Gemm'                    Fully Connected               5 fully connected layer
     4   'x_Relu'                           ReLU                          ReLU
     5   'x_fc2_pi_Gemm'                    Fully Connected               2 fully connected layer
     6   'SoftmaxLayer1003'                 onnx_model.SoftmaxLayer1003   onnx_model.SoftmaxLayer1003
     7   'x_fc1_v_Gemm'                     Fully Connected               1 fully connected layer
     8   'x11Output'                        Custom output ('CB')          See the OutputInformation property to find the output dimension ordering that is produced by this layer.
     9   'x12Output'                        Custom output ('CB')          See the OutputInformation property to find the output dimension ordering that is produced by this layer.
Because onnx_model.SoftmaxLayer1003 did not behave as a softmax layer, its output was always [1; 1].
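An output of all ones is what a softmax over the batch axis produces when the batch size is 1: each column then holds a single element, and exp(z) / exp(z) = 1 whatever the logit is. A minimal NumPy sketch; the `softmax` helper and the logit values are illustrative, not taken from the original model:

```python
import numpy as np

def softmax(z, axis):
    """Numerically stable softmax of z along the given axis."""
    shifted = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

# Logits for a batch of one sample with two outputs, in (batch, output) order
logits = np.array([[0.3, -1.2]])

# Normalizing across the batch (dim=0): each entry is alone in its column
print(softmax(logits, axis=0))  # [[1. 1.]]

# Normalizing across the outputs (dim=1): a proper probability vector
print(softmax(logits, axis=1))
```

With more than one sample, dim=0 would not give all ones, but each sample's "probabilities" would then depend on the other samples in the batch, which is still not the intended behavior.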
Matt J, June 20, 2025 (edited June 20, 2025)
If you import only layers 1-6 (i.e., when the softmax layer is the final layer), does it work then?


Accepted Answer

S. Cho, June 21, 2025 (edited June 21, 2025)
The PyTorch model had an error: the softmax dim argument should be 1 instead of 0. With dim=0, the softmax normalizes across the batch dimension instead of across the two outputs.
The corrected code is:
import torch.nn as nn
import torch.nn.functional as F

class TestNetwork(nn.Module):
    def __init__(self, input_dim=4, output_dim=2, hidden_dim=5):
        super(TestNetwork, self).__init__()
        self.fc1_pi = nn.Linear(input_dim, hidden_dim)
        self.fc2_pi = nn.Linear(hidden_dim, output_dim)
        self.fc1_v = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x1 = F.relu(self.fc1_pi(x))
        x1 = self.fc2_pi(x1)
        # Softmax over dim 1, the output dimension of a (batch, output_dim) tensor
        prob = F.softmax(x1, dim=1)
        x2 = F.relu(self.fc1_pi(x))
        v = self.fc1_v(x2)
        return prob, v
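To check that dim=1 gives per-sample probabilities regardless of batch size, here is a NumPy sketch of the corrected forward pass. The random weights, the `forward` function, and the batch size of 3 are assumptions for illustration. nn.Linear computes x @ W.T + b, which the sketch mirrors; as in the original model, the v branch reuses fc1_pi rather than having its own first layer.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim = 4, 5, 2

# Random weights in nn.Linear's (out_features, in_features) layout (illustrative only)
W1, b1 = rng.standard_normal((hidden_dim, input_dim)), rng.standard_normal(hidden_dim)   # fc1_pi
W2, b2 = rng.standard_normal((output_dim, hidden_dim)), rng.standard_normal(output_dim)  # fc2_pi
Wv, bv = rng.standard_normal((1, hidden_dim)), rng.standard_normal(1)                     # fc1_v

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z, axis):
    shifted = z - z.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def forward(x):
    """Mirror of the corrected TestNetwork.forward for x of shape (batch, input_dim)."""
    h = relu(x @ W1.T + b1)                # fc1_pi + ReLU (computed once; the original computes it twice)
    prob = softmax(h @ W2.T + b2, axis=1)  # fc2_pi, then softmax over outputs (dim=1)
    v = h @ Wv.T + bv                      # fc1_v
    return prob, v

x = rng.random((3, 4))  # batch of 3 samples
prob, v = forward(x)
print(prob.shape, v.shape)  # (3, 2) (3, 1)
print(prob.sum(axis=1))     # each row sums to 1
```

Because the softmax runs along axis 1, the result for one sample is the same whether it is evaluated alone or inside a larger batch.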
After this change, MATLAB recognized the softmax layer:
>> test_import_onnx
ans = 
  9x1 Layer array with layers:

     1   'onnx__Gemm_0'                     Feature Input                 4 features
     2   'onnx__Gemm_0_BatchSizeVerifier'   Verify the fixed batch size   Verify the fixed batch size of 1
     3   'x_fc1_pi_Gemm'                    Fully Connected               5 fully connected layer
     4   'x_Relu'                           ReLU                          ReLU
     5   'x_fc2_pi_Gemm'                    Fully Connected               2 fully connected layer
     6   'x_Softmax'                        Softmax                       softmax
     7   'x_fc1_v_Gemm'                     Fully Connected               1 fully connected layer
     8   'x10Output'                        Custom output ('CB')          See the OutputInformation property to find the output dimension ordering that is produced by this layer.
     9   'x11Output'                        Custom output ('CB')          See the OutputInformation property to find the output dimension ordering that is produced by this layer.


Category

Pretrained Networks from External Platforms

Release

R2024b
