Using transformer neural network for classification task

numChannels = inputSize;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
layers = [
sequenceInputLayer(numChannels,Name="input")
positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb");
additionLayer(2, Name="add")
selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal')
selfAttentionLayer(numHeads,numKeyChannels)
indexing1dLayer("last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, "input", "add/in2");
maxEpochs = 100;
miniBatchSize = 32;
learningRate = 0.001;
solver = 'adam';
shuffle = 'every-epoch';
gradientThreshold = 10;
executionEnvironment = "auto"; % chooses local GPU if available, otherwise CPU
options = trainingOptions(solver, ...
'Plots','training-progress', ...
'MaxEpochs', maxEpochs, ...
'MiniBatchSize', miniBatchSize, ...
'Shuffle', shuffle, ...
'InitialLearnRate', learningRate, ...
'GradientThreshold', gradientThreshold, ...
'ExecutionEnvironment', executionEnvironment);
The input size is 12, so there are 12 features.
numClasses is 4, so I am classifying the data into 4 classes.
But I get the following error when I try to run it:
"
Error in test123_20240727 (line 195)
net=trainNetwork(XTrain, YTrain, layers, options);
Caused by:
Layer 'add': Unconnected input. Each layer input must be connected to the output of another layer.
"
Line 195 is "net=trainNetwork(XTrain, YTrain, layers, options);".
Can anyone help me with this?


Hi @haohaoxuexi1,

The layer array connects the output of 'pos-emb' only to the first input of the 'add' layer ('add/in1'), because each layer in an array feeds the next one. The second input ('add/in2') must be connected explicitly with connectLayers, and the resulting layer graph (lgraph), not the layer array (layers), must be passed to trainNetwork. With all layer inputs properly linked, the unconnected input error should be resolved. Here is the updated code.

numChannels = 12;      % 12 features per time step
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads * 32;

layers = [
    sequenceInputLayer(numChannels, 'Name', 'input')
    positionEmbeddingLayer(numChannels, maxPosition, 'Name', 'pos-emb')
    additionLayer(2, 'Name', 'add')
    selfAttentionLayer(numHeads, numKeyChannels, 'AttentionMask', 'causal')
    selfAttentionLayer(numHeads, numKeyChannels)
    indexing1dLayer('last')
    fullyConnectedLayer(4)   % 4 classes
    softmaxLayer
    classificationLayer];

lgraph = layerGraph(layers);

% The array already connects 'pos-emb' to 'add/in1'; connect the raw input
% to 'add/in2' so the position embedding is added to the input features
lgraph = connectLayers(lgraph, 'input', 'add/in2');

maxEpochs = 100;
miniBatchSize = 32;
learningRate = 0.001;
solver = 'adam';
shuffle = 'every-epoch';
gradientThreshold = 10;
executionEnvironment = 'auto'; % chooses local GPU if available, otherwise CPU

options = trainingOptions(solver, ...
    'Plots', 'training-progress', ...
    'MaxEpochs', maxEpochs, ...
    'MiniBatchSize', miniBatchSize, ...
    'Shuffle', shuffle, ...
    'InitialLearnRate', learningRate, ...
    'GradientThreshold', gradientThreshold, ...
    'ExecutionEnvironment', executionEnvironment);

% Assuming XTrain and YTrain are your training data
net = trainNetwork(XTrain, YTrain, lgraph, options); % Use lgraph instead of layers

I hope this helps resolve your problem.

Thank you, Umar.
"net = trainNetwork(XTrain, YTrain, lgraph, options); % Use lgraph instead of layers"
That was the problem; now it works.
Hi haohaoxuexi1,
If your problem is resolved, please accept the answer by clicking it.
@Umar You have posted your solution as a comment rather than an answer. Therefore, it is not possible for the OP to Accept-click it.
@Matt J,
For some reason, when I click "Answer this question", it defaults to a draft.
Hi @haohaoxuexi1,
If you are still having issues with modifying your code, please let us know. We will be happy to help you out.
@Umar Hi Umar, I am good at the moment. I will let you know if I have further questions.


Accepted Answer

Joss Knight, July 29, 2024

You've passed layers instead of lgraph to trainNetwork.
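A minimal sketch of why the layer array alone fails, assuming Deep Learning Toolbox R2023b or later (needed for positionEmbeddingLayer and indexing1dLayer). layerGraph(layers) only chains each layer to the next, so 'add/in2' has no source until connectLayers is called. The names lgraphChain, nChain, and nFull are hypothetical, introduced only for this check, which counts the connections ending at an input port of 'add' in each graph's Connections table.

```matlab
% Sketch: compare the connections of the plain chain and the completed graph.
% Assumes Deep Learning Toolbox R2023b or later; 12 features, 4 classes.
numChannels = 12;
numClasses = 4;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;

layers = [
    sequenceInputLayer(numChannels, Name="input")
    positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb")
    additionLayer(2, Name="add")
    selfAttentionLayer(numHeads, numKeyChannels, AttentionMask="causal")
    selfAttentionLayer(numHeads, numKeyChannels)
    indexing1dLayer("last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

% Graph built from the array alone: each layer feeds only the next one,
% so 'pos-emb' reaches 'add/in1' but nothing reaches 'add/in2'
lgraphChain = layerGraph(layers);

% Completed graph: the raw input also feeds the second input of 'add'
lgraph = connectLayers(lgraphChain, "input", "add/in2");

% Count connections whose destination is an input port of 'add'
nChain = sum(startsWith(lgraphChain.Connections.Destination, "add/"));
nFull  = sum(startsWith(lgraph.Connections.Destination, "add/"));
fprintf("Inputs of 'add' connected: %d (array only), %d (layer graph)\n", nChain, nFull);
```

Only the completed graph has both inputs of 'add' connected, which is why trainNetwork accepts lgraph but rejects layers.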


@Joss Knight, thanks for jumping in. Please advise how to pass lgraph to trainNetwork by providing a code snippet. Again, thanks for your cooperation.
net=trainNetwork(XTrain, YTrain, lgraph, options);
instead of
net=trainNetwork(XTrain, YTrain, layers, options);
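For completeness, a minimal end-to-end sketch with synthetic data, assuming R2023b or later. XTrain is a hypothetical cell array of 12-by-T sequences (T random between 20 and 50, below maxPosition = 256) and YTrain a hypothetical categorical vector with 4 classes; replace both with your own data. Epochs are reduced to 2 and plots are turned off so it runs quickly; it only checks that the layer graph trains, not accuracy. Left padding is an added assumption so that indexing1dLayer("last") sees real data rather than padding.

```matlab
% Sketch: train the transformer layer graph on random sequence data.
% Synthetic data and reduced epochs are placeholders, not real settings.
rng(0)
numChannels = 12;
numClasses = 4;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
numObservations = 64;

% Hypothetical data: each observation is a numChannels-by-T sequence
XTrain = cell(numObservations, 1);
for i = 1:numObservations
    XTrain{i} = randn(numChannels, randi([20 50]));
end
% Labels cycle through 1..numClasses so every class is present
YTrain = categorical(mod((0:numObservations-1)', numClasses) + 1);

layers = [
    sequenceInputLayer(numChannels, Name="input")
    positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb")
    additionLayer(2, Name="add")
    selfAttentionLayer(numHeads, numKeyChannels, AttentionMask="causal")
    selfAttentionLayer(numHeads, numKeyChannels)
    indexing1dLayer("last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, "input", "add/in2");

options = trainingOptions("adam", ...
    MaxEpochs=2, ...
    MiniBatchSize=32, ...
    InitialLearnRate=0.001, ...
    GradientThreshold=10, ...
    Shuffle="every-epoch", ...
    SequencePaddingDirection="left", ... % last time step is real data
    Plots="none", ...
    Verbose=false);

% Pass the layer graph, not the layer array
net = trainNetwork(XTrain, YTrain, lgraph, options);

% Predict with the same padding direction used in training
YPred = classify(net, XTrain, SequencePaddingDirection="left");
fprintf("Predicted %d labels\n", numel(YPred));
```

Because the input is a layer graph, trainNetwork returns a DAGNetwork, and classify returns one categorical label per sequence.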



Release: R2024a

Asked: July 28, 2024

Commented: August 13, 2024
