How to create a transformer network for a sequence-to-sequence classification task?

veritas on 6 Sep 2024
Answered: Prasanna on 9 Sep 2024
I am trying to use MATLAB to classify time series with a transformer network. Below is my code, but running it produces an error that I cannot resolve.
lgraph = [ ...
    sequenceInputLayer(InputSize,Name="input")
    positionEmbeddingLayer(InputSize,maxPosition,Name="pos-emb");
    additionLayer(2, Name="embed_add");
    selfAttentionLayer(numHeads,numKeyChannels)      % self attention
    additionLayer(2,Name="attention_add")            % residual connection around attention
    layerNormalizationLayer(Name="attention_norm")   % layer norm
    fullyConnectedLayer(feedforwardHiddenSize)       % feedforward part 1
    reluLayer                                        % nonlinear activation
    fullyConnectedLayer(attentionHiddenSize)         % feedforward part 2
    additionLayer(2,Name="feedforward_add")          % residual connection around feedforward
    layerNormalizationLayer()                        % layer norm
    % selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal');
    % selfAttentionLayer(numHeads,numKeyChannels);
    indexing1dLayer("last")
    fullyConnectedLayer(NumClass)
    softmaxLayer
    classificationLayer];
% Layers = layerGraph(lgraph);
% Layers = connectLayers(Layers,"input","add/in2");
net = dlnetwork(lgraph,Initialize=false);
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"pos-emb","embed_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
% net = connectLayers(net,"encoder1_out","attention2_add/in2");
% net = connectLayers(net,"attention2_norm","feedforward2_add/in2");
net = initialize(net);

Answers (1)

Prasanna on 9 Sep 2024
Hi veritas,
The error you're encountering is caused by classificationLayer, which is not supported in a dlnetwork object. dlnetwork is designed for custom training loops and does not use output layers such as classificationLayer; instead, you handle the loss calculation separately during training.
Here's how you can modify your setup to avoid using classificationLayer:
  • Remove the classificationLayer from your layer graph definition.
  • With dlnetwork, you typically use a custom training loop where you manually compute the loss and update the model parameters.
  • Use a loss function such as cross-entropy directly in your training loop.
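The steps above can be sketched as follows. This is a minimal, untested sketch that assumes MATLAB R2023b or later (for positionEmbeddingLayer and indexing1dLayer); the hyperparameter values are placeholders, not values from the question. Apart from dropping classificationLayer, it changes two things in the question's code. First, it connects "input" to "embed_add/in2": the original connected "pos-emb" to that port, and because pos-emb already feeds embed_add/in1 through the layer array, the addition appears to sum the position embedding with itself, so the input feature values never reach the attention layer. Second, the second fully connected layer outputs InputSize channels, because the feedforward residual addition needs both inputs to have the same size and selfAttentionLayer's output size defaults to its number of input channels.

```matlab
% Placeholder sizes (assumptions, not from the original post)
InputSize = 12;               % features per time step
maxPosition = 256;            % longest sequence the position embedding supports
numHeads = 4;
numKeyChannels = 64;          % a multiple of numHeads
feedforwardHiddenSize = 128;
NumClass = 5;

layers = [
    sequenceInputLayer(InputSize,Name="input")
    positionEmbeddingLayer(InputSize,maxPosition,Name="pos-emb")
    additionLayer(2,Name="embed_add")              % input features + position embedding
    selfAttentionLayer(numHeads,numKeyChannels)    % output size defaults to the input channel count
    additionLayer(2,Name="attention_add")          % residual connection around attention
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize)     % feedforward part 1
    reluLayer
    fullyConnectedLayer(InputSize)                 % feedforward part 2: back to InputSize channels
    additionLayer(2,Name="feedforward_add")        % residual connection around feedforward
    layerNormalizationLayer
    indexing1dLayer("last")                        % keep only the last time step (sequence-to-label)
    fullyConnectedLayer(NumClass)
    softmaxLayer];                                 % no classificationLayer in a dlnetwork

net = dlnetwork(layers,Initialize=false);
net = connectLayers(net,"input","embed_add/in2");            % add the raw input features
net = connectLayers(net,"embed_add","attention_add/in2");    % skip connection around attention
net = connectLayers(net,"attention_norm","feedforward_add/in2");
net = initialize(net);

% Sanity check with random data: 8 sequences of 50 time steps
X = dlarray(rand(InputSize,8,50,"single"),"CBT");
Y = predict(net,X);           % class scores, one column per sequence
```

For a custom training loop, you would typically compute the loss inside a function evaluated with dlfeval, calling forward, crossentropy, and dlgradient there, and then update the learnable parameters with a solver function such as adamupdate. If you want one label per time step (sequence-to-sequence classification, as in the question title), remove indexing1dLayer so the network keeps the time dimension.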
Alternatively, you can train dlnetwork objects with the trainnet function instead of writing a custom training loop, and set its loss function to "crossentropy". For more information, refer to the MATLAB documentation for the trainnet and crossentropy functions.
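A rough sketch of the trainnet route, assuming R2023b or later. To keep the block self-contained it uses randomly generated toy data and a deliberately small stand-in network (both hypothetical); in practice you would pass your own layer array or dlnetwork, without classificationLayer. The point it illustrates is that the loss is named as an argument of trainnet rather than added to the network as a layer.

```matlab
% Hypothetical toy data: 100 sequences, 30 time steps, 12 features each
numObs = 100;
numTimeSteps = 30;
InputSize = 12;
NumClass = 3;

XTrain = cell(numObs,1);
for i = 1:numObs
    XTrain{i} = rand(numTimeSteps,InputSize,"single");   % time-by-channels matrix per sequence
end
TTrain = categorical(randi(NumClass,numObs,1),1:NumClass); % one label per sequence

% Small stand-in network with no output layer
layers = [
    sequenceInputLayer(InputSize)
    selfAttentionLayer(4,64)
    layerNormalizationLayer
    indexing1dLayer("last")
    fullyConnectedLayer(NumClass)
    softmaxLayer];

options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    MiniBatchSize=32, ...
    Shuffle="every-epoch", ...
    Plots="none", ...
    Verbose=false);

% The loss is passed by name instead of being a layer in the network
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
```

With categorical targets, trainnet handles the one-hot encoding needed for the "crossentropy" loss. If your real sequences have different lengths, the mini-batches are padded; because indexing1dLayer("last") reads the final time step, left padding (SequencePaddingDirection="left" in trainingOptions) may be preferable so that the last step holds real data.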
Hope this helps!

Categories

Find more on Image Data Workflows in Help Center and File Exchange
