Deep Learning Toolbox: input format mismatch

湃 林, October 6, 2023
I want to build an LSTM network with an attention mechanism.
According to a paper, I need to add the output of the LSTM layer to the processed hidden state.
But when I use additionLayer, I get a message that the input formats don't match!
I found that the LSTM output is in 10(C) x 1(B) x 1(T) format, while the hidden state is in 10(C) x 1(B) format.
I would expect them to be addable, because they are the same size in the first two dimensions and the output has a length of 1 in the third dimension.
What do I need to do to add them? Thanks!
Below is the deep network I designed according to the paper. It is a CNN-LSTM-Attention network used to predict time series.
sequence = sequenceInputLayer([kim,1],"Name",'sequence');
conv1 = convolution1dLayer(3,64,'Name','conv1','Padding','causal');
batchnorm1 = batchNormalizationLayer('Name','batchnorm1');
relu1 = reluLayer('Name','relu1');
conv2 = convolution1dLayer(3,64,'Name','conv2','Padding','causal');
batchnorm2 = batchNormalizationLayer('Name','batchnorm2');
relu2 = reluLayer('Name','relu2');
conv3 = convolution1dLayer(3,64,'Name','conv3','Padding','causal');
batchnorm3 = batchNormalizationLayer('Name','batchnorm3');
add1 = additionLayer(2,"Name",'add1');
relu3 = reluLayer('Name','relu3');
maxpool = maxPooling1dLayer(10,'Name','maxpool');
flatten1 = flattenLayer("Name",'flatten1');
lstm = lstmLayer(10,'Name','lstm','HasStateOutputs',1);
fc1 = fullyConnectedLayer(10,'Name','fc1');
tanh1 = tanhLayer('Name','tanh1');
softmax1 = softmaxLayer('Name','softmax1');
multiplication1 = multiplicationLayer(2,'Name','multiplication1');
fc2 = fullyConnectedLayer(10,'Name','fc2');
tanh2 = tanhLayer('Name','tanh2');
softmax2 = softmaxLayer('Name','softmax2');
multiplication2 = multiplicationLayer(2,'Name','multiplication2');
add2 = additionLayer(3,"Name",'add2');
fc3 = fullyConnectedLayer(10,'Name','fc3');
softmax3 = softmaxLayer('Name','softmax3');
regression = regressionLayer('Name','regression');
lgraph = layerGraph;
lgraph = addLayers(lgraph,sequence);
lgraph = addLayers(lgraph,conv1);
lgraph = addLayers(lgraph,batchnorm1);
lgraph = addLayers(lgraph,relu1);
lgraph = addLayers(lgraph,conv2);
lgraph = addLayers(lgraph,batchnorm2);
lgraph = addLayers(lgraph,relu2);
lgraph = addLayers(lgraph,conv3);
lgraph = addLayers(lgraph,batchnorm3);
lgraph = addLayers(lgraph,add1);
lgraph = addLayers(lgraph,relu3);
lgraph = addLayers(lgraph,maxpool);
lgraph = addLayers(lgraph,flatten1);
lgraph = addLayers(lgraph,lstm);
lgraph = addLayers(lgraph,fc1);
lgraph = addLayers(lgraph,tanh1);
lgraph = addLayers(lgraph,softmax1);
lgraph = addLayers(lgraph,multiplication1);
lgraph = addLayers(lgraph,fc2);
lgraph = addLayers(lgraph,tanh2);
lgraph = addLayers(lgraph,softmax2);
lgraph = addLayers(lgraph,multiplication2);
lgraph = addLayers(lgraph,add2);
lgraph = addLayers(lgraph,fc3);
lgraph = addLayers(lgraph,softmax3);
lgraph = addLayers(lgraph,regression);
lgraph = connectLayers(lgraph,'sequence','conv1');
lgraph = connectLayers(lgraph,'conv1','batchnorm1');
lgraph = connectLayers(lgraph,'batchnorm1','relu1');
lgraph = connectLayers(lgraph,'batchnorm1','add1/in1');
lgraph = connectLayers(lgraph,'relu1','conv2');
lgraph = connectLayers(lgraph,'conv2','batchnorm2');
lgraph = connectLayers(lgraph,'batchnorm2','relu2');
lgraph = connectLayers(lgraph,'relu2','conv3');
lgraph = connectLayers(lgraph,'conv3','batchnorm3');
lgraph = connectLayers(lgraph,'batchnorm3','add1/in2');
lgraph = connectLayers(lgraph,'add1/out','relu3');
lgraph = connectLayers(lgraph,'relu3','maxpool');
lgraph = connectLayers(lgraph,'maxpool','flatten1');
lgraph = connectLayers(lgraph,'flatten1','lstm');
lgraph = connectLayers(lgraph,'lstm/out','add2/in1');
lgraph = connectLayers(lgraph,'lstm/hidden','fc1');
lgraph = connectLayers(lgraph,'fc1','tanh1');
lgraph = connectLayers(lgraph,'tanh1','softmax1');
lgraph = connectLayers(lgraph,'softmax1','multiplication1/in1');
lgraph = connectLayers(lgraph,'lstm/hidden','multiplication1/in2');
lgraph = connectLayers(lgraph,'multiplication1','add2/in2');
lgraph = connectLayers(lgraph,'lstm/cell','fc2');
lgraph = connectLayers(lgraph,'fc2','tanh2');
lgraph = connectLayers(lgraph,'tanh2','softmax2');
lgraph = connectLayers(lgraph,'softmax2','multiplication2/in1');
lgraph = connectLayers(lgraph,'lstm/cell','multiplication2/in2');
lgraph = connectLayers(lgraph,'multiplication2','add2/in3');
lgraph = connectLayers(lgraph,'add2','fc3');
lgraph = connectLayers(lgraph,'fc3','softmax3');
lgraph = connectLayers(lgraph,'softmax3','regression');
plot(lgraph)
analyzeNetwork(lgraph)

Accepted Answer

David Ho, October 9, 2023
Hello 湃林,
It looks like the error comes from the three-input addition layer (add2): one of its inputs, lstm/out, has a time dimension, while the other two inputs do not. That means that if, for example, your sequences have length 100, the layer is trying to add a [10 1] matrix, a [10 1] matrix, and a [10 1 100] array, which is not a supported operation (the addition layer requires its inputs to have matching formats and does not apply implicit expansion).
One way to resolve the issue is to change the output mode of the LSTM layer from "sequence" to "last", so that it outputs only the final time step:
lstm = lstmLayer(10,'Name','lstm','HasStateOutputs',1,'OutputMode','last');
However, I'm not sure if this is the result you need. If the time steps are important and you want the addition to use implicit expansion, you can replace the addition layer with a functionLayer (a multiplication layer can be replaced the same way if its inputs ever have different formats):
add2 = functionLayer(@(x,y,z) x + y + z, 'Name', 'add2', 'Formattable', true);
multiplication1 = functionLayer(@(x,y) x .* y, 'Name', 'multiplication1', 'Formattable', true);
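To see why a formattable functionLayer helps, here is a minimal sketch of the arithmetic it would perform. The sequence length of 100 and the random data are assumptions for illustration only, and the variable names are hypothetical stand-ins for the three inputs of add2. With formatted dlarray inputs, dimensions are matched by label, so the inputs without a "T" dimension are expanded along time.

```matlab
% Stand-ins for the three inputs of add2 (sizes assumed for illustration).
seqLen    = 100;                               % hypothetical sequence length
lstmOut   = dlarray(rand(10,1,seqLen),'CBT');  % like lstm/out in "sequence" mode
attnHid   = dlarray(rand(10,1),'CB');          % like the multiplication1 output
attnCell  = dlarray(rand(10,1),'CB');          % like the multiplication2 output

% Formatted dlarrays are aligned by dimension label; the missing "T"
% dimension of the 'CB' inputs is treated as a singleton and expanded.
total = lstmOut + attnHid + attnCell;
disp(dims(total))   % format of the result
disp(size(total))   % size of the result

% The same operation wrapped as a layer. The number of inputs is taken
% from the number of arguments of the function handle.
add2 = functionLayer(@(x,y,z) x + y + z,'Name','add2','Formattable',true);
disp(add2.NumInputs)
```

Because the layer is named add2 and its inputs default to in1, in2, and in3, the existing connectLayers calls should not need to change, although that is worth confirming with analyzeNetwork.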
Hopefully one of these options will give you the result you're looking for.
As a final note, if you want to implement the self-attention mechanism from the "Attention Is All You Need" paper, you may wish to look at selfAttentionLayer, Deep Learning Toolbox's implementation of the self-attention layer.
This handles the matrix multiplication of the keys, queries and values internally, so you don't need to implement it yourself.
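As a rough illustration of that suggestion, the sketch below places the built-in self-attention layer after an LSTM layer in a sequence-to-sequence regression network. The feature count, number of heads, and number of key channels are placeholder assumptions rather than values from the original post, and this is not the architecture from the paper the question refers to.

```matlab
% Minimal sketch: LSTM followed by the built-in self-attention layer.
% All sizes are placeholder assumptions, not values from the post.
numFeatures    = 1;    % channels per time step (assumed)
numHeads       = 2;    % number of attention heads (assumed)
numKeyChannels = 10;   % chosen to be divisible by numHeads

layers = [
    sequenceInputLayer(numFeatures,'Name','sequence')
    lstmLayer(10,'Name','lstm','OutputMode','sequence')
    selfAttentionLayer(numHeads,numKeyChannels,'Name','attention')
    fullyConnectedLayer(1,'Name','fc')
    regressionLayer('Name','regression')];

% Optional: inspect the activation sizes and formats interactively.
% analyzeNetwork(layers)
```

Since every layer here passes a single sequence to the next, a plain layer array is enough and no connectLayers calls are needed.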
Best regards,
David
Comments: 1
湃 林, October 9, 2023
Thanks so much! I think I will solve this problem soon with your help.

Release

R2023a
