Using self attention layer with 2D images
조회 수: 34 (최근 30일)
이전 댓글 표시
Hi,
I am wondering how to use the selfattention layer in image calssaifcation using CNN without we need to flatten the data as explained in this example:
% load digit dataset
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');
% define network architecture
layers = [
imageInputLayer([28 28 1], 'Name', 'input')
convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
batchNormalizationLayer('Name', 'bn1')
reluLayer('Name', 'relu1')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
batchNormalizationLayer('Name', 'bn2')
reluLayer('Name', 'relu2')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
flattenLayer('Name', 'flatten')
selfAttentionLayer(8, 64, 'Name', 'self_attention')
fullyConnectedLayer(10, 'Name', 'fc')
softmaxLayer('Name', 'softmax')
classificationLayer('Name', 'output')]
% set training options
options = trainingOptions('sgdm', ...
'InitialLearnRate', 0.01, ...
'MaxEpochs', 5, ...
'Shuffle', 'every-epoch', ...
'ValidationData', imdsValidation, ...
'ValidationFrequency', 30, ...
'Verbose', false, ...
'Plots', 'training-progress')
% training the network
net = trainNetwork(imdsTrain, layers, options);
댓글 수: 0
답변 (1개)
Neha
2023년 11월 21일
Hi Mahmoud,
I understand that you want to use self-attention layer in image classification. The self-attention layer, also known as the multi-head self-attention layer, is commonly employed in Transformer models like BERT and vision transformers (ViT). Its primary function is to understand the relationships between positions within the input data. This input data is usually sequential, representing either temporal sequences or 1D spatial information. Therefore it is necessary to use the "flattenLayer" to ensure that the input data to the "selfAttentionLayer" is one directional.
Hope this helps!
댓글 수: 0
참고 항목
카테고리
Help Center 및 File Exchange에서 Image Data Workflows에 대해 자세히 알아보기
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!