Using a self-attention layer with 2-D images

MAHMOUD EID, 7 November 2023
Answered: Neha, 21 November 2023
Hi,
I am wondering how to use the selfAttentionLayer for image classification with a CNN without having to flatten the data, as is done in this example:
% Load the digit sample dataset
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
% Use 70% of each class for training and hold out 30% for validation
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');
% Define the network architecture
layers = [
    imageInputLayer([28 28 1], 'Name', 'input')
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    flattenLayer('Name', 'flatten')                      % collapse the spatial dimensions into the channel dimension
    selfAttentionLayer(8, 64, 'Name', 'self_attention')  % 8 heads, 64 key channels
    fullyConnectedLayer(10, 'Name', 'fc')                % one output per digit class
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')];
% Set the training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 5, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidation, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
% Train the network
net = trainNetwork(imdsTrain, layers, options);
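For reference, the size of the feature map that reaches the flattenLayer in the code above can be worked out by hand. The sketch below is plain MATLAB arithmetic (no toolbox calls), assuming that 'same' padding leaves the height and width unchanged and that a 2x2 max pooling with stride 2 maps a side of length n to floor((n - 2)/2) + 1. The variable names are illustrative only.

```matlab
% Hypothetical shape check for the conv/pool stack above (plain MATLAB).
% Assumption: 'same' padding keeps H and W unchanged, and each
% maxPooling2dLayer(2, 'Stride', 2) maps a side n to floor((n - 2)/2) + 1.
inputSide   = 28;   % imageInputLayer([28 28 1])
poolSize    = 2;    % maxPooling2dLayer(2, ...)
poolStride  = 2;    % 'Stride', 2
numChannels = 64;   % number of filters in conv2

side = inputSide;
for k = 1:2         % two conv + pool stages
    % the 3x3 convolution with 'same' padding does not change the side length;
    % only the pooling layer shrinks it
    side = floor((side - poolSize) / poolStride) + 1;
end

numPositions      = side^2;                    % spatial positions in the last feature map
flattenedChannels = numPositions * numChannels;

fprintf('Feature map: %d x %d x %d\n', side, side, numChannels);
fprintf('Spatial positions: %d, each with %d channels\n', numPositions, numChannels);
fprintf('Values per image after flattening: %d\n', flattenedChannels);
```

Under these assumptions the last feature map is 7x7x64. Because flattenLayer collapses the spatial dimensions into the channel dimension, each image is then passed on as a single 3136-element vector rather than as 49 positions of 64 channels each.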

Answers (1)

Neha, 21 November 2023
Hi Mahmoud,
I understand that you want to use a self-attention layer for image classification. The self-attention layer, also known as the multi-head self-attention layer, is commonly used in Transformer models such as BERT and vision transformers (ViT). Its primary function is to model the relationships between positions in the input data, so it expects that data to be arranged as a sequence of positions, usually a time series or one-dimensional spatial data. In ViT, for example, the image is first split into patches, and those patches form the sequence. Therefore, it is necessary to use the "flattenLayer" so that the input to the "selfAttentionLayer" is one-dimensional.
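To make the idea of "relationships between positions" concrete for a 2-D feature map, here is a minimal sketch in plain MATLAB (no Deep Learning Toolbox layers, R2016b or later for implicit expansion). It reads the H x W grid out as a sequence of H*W positions, each with C channels, and computes single-head scaled dot-product self-attention across them. The 7x7x64 size, the key size dk = 16, and the deterministic stand-in weights are hypothetical; in a real network the projections are learned. It is not meant to reproduce the internals of selfAttentionLayer.

```matlab
% Single-head scaled dot-product self-attention over the spatial positions
% of a feature map. All sizes and weights below are hypothetical stand-ins.
H = 7; W = 7; C = 64;                      % e.g. the output of the conv/pool stack
dk = 16;                                   % hypothetical query/key size
X = reshape(sin(1:H*W*C), H, W, C);        % deterministic stand-in feature map

% Treat each of the H*W spatial positions as one token with C features.
% Column-major order: row index = i + (j-1)*H, column index = channel.
tokens = reshape(X, H*W, C);               % (H*W) x C

% Hypothetical projection matrices (learned during training in practice)
Wq = reshape(cos(1:C*dk), C, dk) / sqrt(C);
Wk = reshape(sin(2*(1:C*dk)), C, dk) / sqrt(C);
Wv = reshape(cos(3*(1:C*C)), C, C) / sqrt(C);

Q = tokens * Wq;                           % (H*W) x dk queries
K = tokens * Wk;                           % (H*W) x dk keys
V = tokens * Wv;                           % (H*W) x C  values

scores = (Q * K.') / sqrt(dk);             % (H*W) x (H*W): every position vs every position
scores = scores - max(scores, [], 2);      % subtract row max for numerical stability
A = exp(scores);
A = A ./ sum(A, 2);                        % row-wise softmax: attention weights

out = A * V;                               % (H*W) x C attended features
Y = reshape(out, H, W, C);                 % back to the 2-D layout
```

Each row of A holds the weights one position assigns to all 49 positions, which is the position-to-position relationship described above. Reshaping the result back to H x W x C shows that the sequence view is only a reading order of the grid.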
Hope this helps!


Release: R2023b
