Dynamic sequence length for a Transformer-based model: error when exporting from Python to MATLAB

We developed a simple Transformer architecture (see the Python code below). The model, which we built in Python, can handle sequences of different lengths, and I want to use it in MATLAB. I tried exporting it both to ONNX and to the PyTorch .pt format, but in both cases I had to fix the input shape to get the export to work. For the .pt file, I used torch.jit.script() in Python to compile and export the model. However, I think pytorchmex from the Deep Learning Toolbox Converter for PyTorch Models only works with models traced using torch.jit.trace.
I want to find a way to use the model in MATLAB so that it accepts inputs of any sequence length.
Any help would be much appreciated.
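For the ONNX route, the TorchScript-based `torch.onnx.export` accepts a `dynamic_axes` argument that marks selected dimensions as symbolic instead of baking in the example input's size. The sketch below is only an illustration under several assumptions: it assumes the `batch_first=True` layout used in the model below, it uses the hypothetical tensor names `x`, `padding_mask`, and `logits`, and the helper names `build_dynamic_axes` and `export_with_dynamic_length` are made up for this example. It has not been verified against R2024b, so whether MATLAB's importer keeps the sequence dimension dynamic remains an open question.

```python
from typing import Dict, Sequence


def build_dynamic_axes(
    input_names: Sequence[str],
    output_names: Sequence[str],
) -> Dict[str, Dict[int, str]]:
    """Mark dim 0 (batch) and dim 1 (sequence) as symbolic for every tensor.

    Assumes a batch_first layout, i.e. tensors shaped (batch, seq_len, ...),
    matching TransformerEncoderLayer(batch_first=True) in the question.
    """
    axes = {0: "batch", 1: "seq_len"}
    # A fresh dict per tensor, so later edits to one entry don't affect others
    return {name: dict(axes) for name in [*input_names, *output_names]}


def export_with_dynamic_length(model, example_x, example_mask, path="transformer.onnx"):
    """Export via torch.onnx.export, leaving batch and sequence dims dynamic."""
    import torch  # lazy import: build_dynamic_axes is usable without torch

    input_names = ["x", "padding_mask"]  # hypothetical names
    output_names = ["logits"]            # hypothetical name
    model.eval()
    torch.onnx.export(
        model,
        # Example inputs fix the rank, dtypes, and the non-dynamic dims
        # (e.g. input_dim); dims listed in dynamic_axes stay symbolic.
        (example_x, example_mask),
        path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=build_dynamic_axes(input_names, output_names),
    )
    return path
```

In MATLAB, `importNetworkFromONNX` (Deep Learning Toolbox Converter for ONNX Model Format) is the importer to try on the resulting file; if it still fixes the sequence length on import, dynamic axes on the PyTorch side will not be enough by themselves.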
# Python Code
import torch.nn as nn


# Model class to export
class TransformerModel(nn.Module):
    def __init__(
        self,
        input_dim,
        model_dim,
        n_classes,
        num_heads,
        num_layers,
    ):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        # Embedding layer: project each input time step to model_dim features
        self.embedding = nn.Linear(input_dim, model_dim)
        # Transformer encoder; inputs are shaped (batch, seq_len, features)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        # Output layer: class scores for every time step
        self.output_layer = nn.Linear(model_dim, n_classes)

    def forward(self, x, padding_mask):
        # padding_mask is True for valid time steps; src_key_padding_mask
        # expects True for positions to ignore, so the mask is inverted
        padding_mask = ~padding_mask
        x = self.embedding(x)
        # Transformer encoder
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        # Model prediction
        output = self.output_layer(x)
        return output
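If the network imported into MATLAB ends up with a fixed sequence length anyway, one possible workaround (an assumption, not a confirmed fix for the MATLAB side) is to export at a chosen maximum length and pad shorter sequences, relying on the padding mask that `forward` already takes. The hypothetical helper `pad_and_mask` below pads a list of `(seq_len, input_dim)` NumPy arrays into one batch and builds the mask in the convention `forward` implies: `True` for real time steps, which `forward` then inverts because `src_key_padding_mask` treats `True` as "ignore this position".

```python
import numpy as np


def pad_and_mask(sequences, max_len=None, pad_value=0.0):
    """Pad a list of (seq_len_i, input_dim) arrays into one batch.

    Returns
      x:    float32 array, shape (batch, max_len, input_dim)
      mask: bool array, shape (batch, max_len); True marks real time steps.
    """
    lengths = [s.shape[0] for s in sequences]
    if max_len is None:
        max_len = max(lengths)
    if max(lengths) > max_len:
        raise ValueError("a sequence is longer than max_len")
    input_dim = sequences[0].shape[1]
    # Start from an all-padding batch and an all-False mask
    x = np.full((len(sequences), max_len, input_dim), pad_value, dtype=np.float32)
    mask = np.zeros((len(sequences), max_len), dtype=bool)
    for i, (seq, n) in enumerate(zip(sequences, lengths)):
        x[i, :n] = seq      # copy the real time steps
        mask[i, :n] = True  # mark them as valid
    return x, mask
```

For the PyTorch model, the arrays can be converted with `torch.from_numpy(x)` and `torch.from_numpy(mask)`. Predictions at padded positions carry no meaning and should be discarded using the same mask.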


Category: Deep Learning with GPU Coder

Release: R2024b
