- Ensure that the dimensions of WQ, WK, WV, and WO match the input dimensions. WQ, WK, and WV should each be outputSize-by-numChannels (their number of columns must equal the channel size of the inputs they multiply), and WO should be outputSize-by-outputSize.
- Ensure that the outputs of the fc_stft and fc_cwt layers are compatible with the input dimensions your crossAttention function expects. Because both branches end with fully connected layers of output size 100, check the outputSize variable again and confirm that it matches the expected size.
- The outputs of the fc_stft and fc_cwt layers should be connected to the inputs of crossAttention instead of directly to the attention layer.
- Try using MATLAB's pagemtimes function for page-wise matrix multiplication of multidimensional arrays such as the queries, keys, and values in the implementation. Here is the MathWorks documentation link: https://www.mathworks.com/help/matlab/ref/pagemtimes.html
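The crossAttention function from the question can be sketched with pagemtimes as below. This is a minimal sketch, assuming unformatted dlarray inputs in "CBT" layout (X supplies the queries, Y the keys and values) and a Deep Learning Toolbox release in which pagemtimes and attention support dlarray. The size checks reflect the first two points; numChannels = 100 is chosen to match the fc_stft and fc_cwt output size, while outputSize, the batch size, and the sequence lengths are illustrative assumptions.

```matlab
% Sketch: multi-head cross-attention using pagemtimes (Deep Learning Toolbox).
numChannels = 100;        % matches the 100-element fc_stft / fc_cwt outputs
numHeads = 4;
outputSize = 64;          % illustrative; must be divisible by numHeads
numObservations = 8;
numQueryTimeSteps = 12;   % X and Y may have different sequence lengths
numKeyTimeSteps = 20;

% Inputs in C-by-B-by-T layout (unformatted dlarray).
X = dlarray(rand(numChannels, numObservations, numQueryTimeSteps));  % query source
Y = dlarray(rand(numChannels, numObservations, numKeyTimeSteps));    % key/value source

% Projection weights: WQ, WK, WV are outputSize-by-numChannels, WO is outputSize-by-outputSize.
WQ = dlarray(0.01*randn(outputSize, numChannels));
WK = dlarray(0.01*randn(outputSize, numChannels));
WV = dlarray(0.01*randn(outputSize, numChannels));
WO = dlarray(0.01*randn(outputSize, outputSize));

Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
disp(size(Z))   % outputSize-by-numObservations-by-numQueryTimeSteps

function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
    % Check the weight sizes against the inputs before multiplying.
    assert(size(WQ, 2) == size(X, 1), 'WQ must have size(X,1) columns.');
    assert(size(WK, 2) == size(Y, 1) && size(WV, 2) == size(Y, 1), ...
        'WK and WV must have size(Y,1) columns.');
    assert(mod(size(WQ, 1), numHeads) == 0, 'Projection size must be divisible by numHeads.');

    % pagemtimes multiplies the 2-D weight matrix with every C-by-B page of the
    % C-by-B-by-T array; the weights are expanded along the time dimension.
    queries = pagemtimes(WQ, X);
    keys    = pagemtimes(WK, Y);
    values  = pagemtimes(WV, Y);

    % Scaled dot-product attention over the time dimension.
    A = attention(queries, keys, values, numHeads, DataFormat="CBT");

    % Output projection, again applied page by page.
    Z = pagemtimes(WO, A);
end
```

The plain `WQ * X` in the question uses mtimes, which accepts only 2-D operands, so it cannot multiply the 3-D X directly; pagemtimes avoids reshaping the data back and forth.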
How to make cross-attention use attentionLayer?
I want to replace the dual-branch merge section of the model at the following link with cross-attention for fusion, but it isn't working. Is my approach incorrect? I have written an example, but I still don't understand how to embed it into the model from the link.
Network 1 (fails; the loss does not decrease):
initialLayers = [
sequenceInputLayer(1, "MinLength", numSamples, "Name", "input", "Normalization", "zscore", "SplitComplexInputs", true)
convolution1dLayer(7, 2, "stride", 1)
];
stftBranchLayers = [
stftLayer("TransformMode", "squaremag", "Window", hann(64), "OverlapLength", 52, "Name", "stft", "FFTLength", 256, "WeightLearnRateFactor", 0 )
functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="stft_reformat")
convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "stft_conv_1")
layerNormalizationLayer("Name", "stft_layernorm_1")
reluLayer("Name", "stft_relu_1")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_1")
convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "stft_conv_2")
layerNormalizationLayer("Name", "stft_layernorm_2")
reluLayer("Name", "stft_relu_2")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_2")
convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "stft_conv_3")
layerNormalizationLayer("Name", "stft_layernorm_3")
reluLayer("Name", "stft_relu_3")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_3")
flattenLayer("Name", "stft_flatten")
dropoutLayer(0.5)
fullyConnectedLayer(100,"Name","fc_stft")
];
cwtBranchLayers = [
cwtLayer("SignalLength", numSamples, "TransformMode", "squaremag", "Name","cwt", "WeightLearnRateFactor", 0);
functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="cwt_reformat")
convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "cwt_conv_1")
layerNormalizationLayer("Name", "cwt_layernorm_1")
reluLayer("Name", "cwt_relu_1")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_1")
convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "cwt_conv_2")
layerNormalizationLayer("Name", "cwt_layernorm_2")
reluLayer("Name", "cwt_relu_2")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_2")
convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "cwt_conv_3")
layerNormalizationLayer("Name", "cwt_layernorm_3")
reluLayer("Name", "cwt_relu_3")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_3")
flattenLayer("Name", "cwt_flatten")
dropoutLayer(0.5)
fullyConnectedLayer(100,"Name","fc_cwt")
];
finalLayers = [
attentionLayer(4,"Name","attention")
layerNormalizationLayer("Name","layernorm")
fullyConnectedLayer(48,"Name","fc_1")
fullyConnectedLayer(numel(waveformClasses),"Name","fc_2")
softmaxLayer("Name","softmax")
];
dlLayers2 = dlnetwork(initialLayers);
dlLayers2 = addLayers(dlLayers2, stftBranchLayers);
dlLayers2 = addLayers(dlLayers2, cwtBranchLayers);
dlLayers2 = addLayers(dlLayers2, finalLayers);
dlLayers2 = connectLayers(dlLayers2, "conv1d", "stft");
dlLayers2 = connectLayers(dlLayers2, "conv1d", "cwt");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/key");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/value");
dlLayers2 = connectLayers(dlLayers2,"fc_cwt","attention/query");
My example (is it right?):
numChannels = 10;
numObservations = 128;
numTimeSteps = 100;
X = rand(numChannels,numObservations,numTimeSteps);
X = dlarray(X);
Y = rand(numChannels,numObservations,numTimeSteps);
Y = dlarray(Y);
numHeads = 8;
outputSize = numChannels*numHeads;
WQ = rand(outputSize, numChannels, 1, 1);
WK = rand(outputSize, numChannels, 1, 1);
WV = rand(outputSize, numChannels, 1, 1);
WO = rand(outputSize, outputSize, 1, 1);
Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
queries = WQ * X;
keys = WK * Y;
values = WV * Y;
A = attention(queries, keys, values, numHeads, 'DataFormat', 'CBT');
Z = WO * A;
end
Accepted Answer
Sahas
December 18, 2024
Edited: Sahas
December 18, 2024
As I understand it, you would like to replace the dual-branch merge section of the model with cross-attention. I went through your implementation and noticed a few things. The implementation looks structurally correct, but make sure of the points listed above when using cross-attention with the classification workflow from the documentation example.
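For a layer-based network like Network 1, one possible wiring is to give each attentionLayer input its own learnable fullyConnectedLayer projection, so the query comes from one branch and the key and value come from the other. This is a sketch under assumptions, not the documentation example's method: it assumes a release that includes attentionLayer and that attentionLayer, unlike selfAttentionLayer, has no built-in query/key/value projections. The two sequenceInputLayer inputs (seqA, seqB) are hypothetical stand-ins for the fc_cwt and fc_stft outputs, and the names proj_q, proj_k, proj_v, cross_attention, and fusion_norm are hypothetical too. The inputs keep a time dimension ("CBT") because attention weights positions along that dimension.

```matlab
% Sketch: cross-attention fusion of two branches with attentionLayer.
% seqA/seqB stand in for the two branch outputs; all layer names are hypothetical.
numChannelsA = 100;   % e.g. size of the fc_cwt output (query branch)
numChannelsB = 100;   % e.g. size of the fc_stft output (key/value branch)
numHeads = 4;
projSize = 64;        % must be divisible by numHeads

% Query branch: input followed by a learnable query projection.
net = dlnetwork([
    sequenceInputLayer(numChannelsA, Name="seqA")
    fullyConnectedLayer(projSize, Name="proj_q")], Initialize=false);

% Key/value branch with separate learnable projections.
net = addLayers(net, sequenceInputLayer(numChannelsB, Name="seqB"));
net = addLayers(net, fullyConnectedLayer(projSize, Name="proj_k"));
net = addLayers(net, fullyConnectedLayer(projSize, Name="proj_v"));

% Attention block followed by normalization.
net = addLayers(net, [
    attentionLayer(numHeads, Name="cross_attention")
    layerNormalizationLayer(Name="fusion_norm")]);

% Route projections into the named attention inputs.
net = connectLayers(net, "seqB", "proj_k");
net = connectLayers(net, "seqB", "proj_v");
net = connectLayers(net, "proj_q", "cross_attention/query");
net = connectLayers(net, "proj_k", "cross_attention/key");
net = connectLayers(net, "proj_v", "cross_attention/value");
net = initialize(net);

% Forward pass: inputs are passed in the order of net.InputNames (seqA, then seqB).
XA = dlarray(rand(numChannelsA, 2, 7), "CBT");   % 7 query time steps
XB = dlarray(rand(numChannelsB, 2, 5), "CBT");   % 5 key/value time steps
Z = predict(net, XA, XB);
disp(size(Z))   % projSize-by-2-by-7: one output per query time step
```

In Network 1 the same pattern would connect fc_cwt and fc_stft to the projection layers in place of seqA and seqB, with the classification layers (fc_1, fc_2, softmax) following the fusion output.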
Hope this is beneficial!