Using fft to replace feature learning in CNN

Juuso Korhonen on 19 Jan 2021
Commented: fawad ahmad on 3 Aug 2021
Hello,
I read this interesting article: https://www.groundai.com/project/reducing-deep-network-complexity-with-fourier-transform-methods/1, in which the authors got really good results by replacing the feature learning in a CNN with a basic FFT. I'm very interested in trying this out in MATLAB, because it suggests the approach could relax the requirements on the amount of data (I'm currently working with medical data, where sample sizes are often small).

However, I can't get it to work: my accuracy stays at 10% on the MNIST data, which means the network is basically not learning anything. There must be some major bug, but I can't find it. I suspect it has to do with my implementation of the preprocessForTraining function. It is applied as the transformation function for the imageDatastore: it computes the FFT of each image and then flattens the FFT image into a 1-D vector that is fed to the featureInputLayer of my simple neural network. (That said, I think the transformation itself is correct, since I can read an image from dsTrain and transform it back to the original image.)
% data read
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% read several images per call, so the transform receives them as a
% cell array (the training mini-batch size is set in trainingOptions)
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;
% split into training and validation data
numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
% define a transform that is applied every time data is read;
% the transform is the preprocessForTraining function (separate
% file), which does grayscale conversion, resizing, FFT and flattening
% into a 1-D vector
dsTrain = transform(imdsTrain, @preprocessForTraining,'IncludeInfo',true);
dsValidation = transform(imdsValidation, @preprocessForTraining,'IncludeInfo',true);
% Network structure (basic MLP)
% the input size is twice the number of pixels because the FFT has both
% a real and an imaginary part
% one hidden layer with half the input size as its number of nodes
% ReLUs as activation functions
layers = [
    featureInputLayer(28*28*2)
    fullyConnectedLayer(28*28)
    reluLayer
    fullyConnectedLayer(10)
    reluLayer;
    softmaxLayer
    classificationLayer];
% training options
options = trainingOptions('adam', ...
    'Plots','training-progress', ...
    'MiniBatchSize',miniBatchSize);
% training
net = trainNetwork(dsTrain,layers,options);
function [dataOut,info] = preprocessForTraining(data,info)
    numRows = size(data,1);
    dataOut = cell(numRows,2);
    targetSize = [28,28];
    % since ReadSize is expected to be >1, data comes in cell form containing
    % multiple images
    for idx = 1:numRows
        % get the image out of the data cell
        img = data{idx,1};
        % if RGB image, convert to grayscale
        if size(img, 3) == 3
            img = rgb2gray(img);
        end
        % resize and FFT
        fft_img = fftshift(fft2(imresize(img, targetSize)));
        real_part = real(fft_img);
        imag_part = imag(fft_img);
        % flatten to vector
        imgOut = [real_part(:); imag_part(:)];
        % Return the label from info struct as the
        % second column in dataOut.
        dataOut(idx,:) = {imgOut,info.Label(idx)};
    end
end
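One way to test the claim that the transform is lossless, as a sketch that is not from the original post: the snippet below applies the same fftshift(fft2(...)) transform as preprocessForTraining and then rebuilds the image from the flattened feature vector with ifftshift and ifft2. Assumptions: a random uint8 matrix stands in for a 28x28 digit, and the image is converted to double before fft2; the variable names are illustrative only.

```matlab
% Synthetic 28x28 grayscale image standing in for one digit
% (assumption: any uint8 image of this size behaves the same way)
rng(0);
img = uint8(randi([0 255], 28, 28));

% Forward transform, as in preprocessForTraining (explicit double conversion)
fftImg = fftshift(fft2(double(img)));
featureVec = [real(fftImg(:)); imag(fftImg(:))];   % 1568-by-1 column vector

% Inverse: split the vector into its real and imaginary halves,
% rebuild the complex spectrum, undo the shift and invert the FFT
n = numel(img);
spectrum = reshape(featureVec(1:n) + 1i*featureVec(n+1:end), size(img));
recovered = real(ifft2(ifftshift(spectrum)));

% Largest reconstruction error (floating-point round-off only)
maxErr = max(abs(recovered(:) - double(img(:))));

% For an even size (28), fftshift moves the zero-frequency term to
% index 28/2+1 = 15; this DC coefficient equals the sum of all pixels
dcValue = fftImg(15, 15);
```

The check also shows that the features are unscaled: the DC coefficient equals the pixel sum, which can reach 28*28*255 = 199920 for an all-white image. Rescaling the inputs may be worth trying, but that was not tested in the thread.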
1 Comment
fawad ahmad on 3 Aug 2021
Brother, have you found a solution? Can you please share the code?


Accepted Answer

Hrishikesh Borate on 2 Feb 2021
Hi,
I understand that you are using FFT for feature learning instead of CNN and the accuracy is staying at 10%. This is due to the use of reluLayer before the softmaxLayer in the layer before classificationLayer. You can use the following layer definition, to improve the training results.
layers = [
    featureInputLayer(28*28*2)
    fullyConnectedLayer(28*28)
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
For more information, refer to the Define Network Architecture section in this example.
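A small numeric illustration of the point above, as a sketch under stated assumptions: the class scores in z are hypothetical (not taken from the thread), chosen all negative to show the extreme case, and softmaxFn is a hand-written helper so no toolbox is needed.

```matlab
% Hypothetical scores from fullyConnectedLayer(10) for one sample
z = [-2.1; -0.7; -3.4; -1.2; -0.3; -2.8; -1.9; -0.5; -4.0; -1.1];

% Hand-written, numerically stable softmax (hypothetical helper)
softmaxFn = @(x) exp(x - max(x)) ./ sum(exp(x - max(x)));

% Original network: reluLayer between fullyConnectedLayer(10) and softmax
pWithRelu = softmaxFn(max(z, 0));   % every negative score clipped to 0

% Corrected network: softmax applied directly to the scores
pWithoutRelu = softmaxFn(z);

% ReLU derivative for each score: 0 wherever the score is negative,
% so no gradient reaches those outputs of fullyConnectedLayer(10)
reluGrad = double(z > 0);

% Class predicted by the corrected network
[~, predicted] = max(pWithoutRelu);
```

With the ReLU in place, all ten classes get probability 0.1, which is chance level and matches the 10% accuracy in the question. Without it, the softmax still ranks the classes (here class 5 has the largest score) and gradients flow back into the final fullyConnectedLayer.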
