Using FFT to replace feature learning in a CNN

Juuso Korhonen on 19 Jan 2021
Commented: fawad ahmad on 3 Aug 2021
Hello,
I read this interesting article: https://www.groundai.com/project/reducing-deep-network-complexity-with-fourier-transform-methods/1 , in which the authors got very good results by replacing the feature learning stage of a CNN with a plain FFT. I would like to try this in MATLAB, because it could relax the requirement on the amount of training data (I am currently working with medical data, where sample sizes are often small). However, I can't get it to work: my accuracy stays at 10% on the MNIST data, which is chance level for ten classes, so the network is basically not learning anything. There must be a major bug, but I can't find it. I suspect my implementation of the preprocessForTraining function, which is applied as the transform function for the imageDatastore. It takes the FFT of each image and then flattens the result into a 1-D vector that is fed to the featureInputLayer of my simple neural network. (However, I think the transformation itself is correct, since I can read an image from dsTrain and transform it back to the original image.)
% data read
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% batch the data, so it can do batch normalization in training
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;
% split to training and validation data
numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
% define a transform that is applied every time data is read
% our transform of choice is the preprocessForTraining function (separate
% file), which converts to grayscale, resizes, applies the FFT and flattens
% the result into a 1-D vector
dsTrain = transform(imdsTrain, @preprocessForTraining,'IncludeInfo',true);
dsValidation = transform(imdsValidation, @preprocessForTraining,'IncludeInfo',true);
% Network structure (basic MLP)
% input size is twice the pixel amount due to both real and imaginary part of
% fft
% one hidden layer with half the input size as the number of nodes
% relus as activation functions
layers = [
    featureInputLayer(28*28*2)
    fullyConnectedLayer(28*28)
    reluLayer
    fullyConnectedLayer(10)
    reluLayer;
    softmaxLayer
    classificationLayer];
% training options
options = trainingOptions('adam', ...
    'Plots','training-progress', ...
    'MiniBatchSize',miniBatchSize);
% training
net = trainNetwork(dsTrain,layers,options);
function [dataOut,info] = preprocessForTraining(data,info)
    numRows = size(data,1);
    dataOut = cell(numRows,2);
    targetSize = [28,28];
    % since ReadSize is expected to be >1, data comes in cell form containing
    % multiple images
    for idx = 1:numRows
        % get the image out of the datacell
        img = data{idx,1};
        % if rgb image, turn to grayscale
        if size(img, 3) == 3
            img = rgb2gray(img);
        end
        % resize and fft
        fft_img = fftshift(fft2(imresize(img, targetSize)));
        real_part = real(fft_img);
        imag_part = imag(fft_img);
        % flatten to vector
        imgOut = [real_part(:); imag_part(:)];
        % Return the label from info struct as the
        % second column in dataOut.
        dataOut(idx,:) = {imgOut,info.Label(idx)};
    end
end
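The round-trip check mentioned in the question (transforming a flattened FFT vector back into the original image) could look roughly like the sketch below. It is only an illustration: the input is a synthetic 28-by-28 matrix standing in for a digit image, and the variable names (`vec`, `spec`, `imgBack`, `maxErr`) are hypothetical, not taken from the original code.

```matlab
% Hypothetical stand-in for one 28x28 grayscale digit image.
img = double(magic(28));

% Forward transform, mirroring the steps in preprocessForTraining.
fftImg = fftshift(fft2(img));
vec = [real(fftImg(:)); imag(fftImg(:))];   % 1568-by-1 feature vector

% Inverse: split the two halves, rebuild the complex spectrum, undo the shift.
n = numel(vec) / 2;
spec = reshape(complex(vec(1:n), vec(n+1:end)), 28, 28);
imgBack = real(ifft2(ifftshift(spec)));

% The reconstruction should match the input up to floating-point error.
maxErr = max(abs(imgBack(:) - img(:)));
fprintf('Feature length: %d, max round-trip error: %g\n', numel(vec), maxErr);
```

A small reconstruction error here only confirms that the flattening is lossless. It says nothing about whether the network layers downstream are set up correctly.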
1 Comment
fawad ahmad on 3 Aug 2021
Brother, have you found a solution? Can you please share the code?


Accepted Answer

Hrishikesh Borate on 2 Feb 2021
Hi,
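I understand that you are using the FFT for feature extraction in place of the convolutional layers of a CNN, and that the accuracy stays at 10%. The cause is the reluLayer placed between the final fullyConnectedLayer and the softmaxLayer, directly ahead of the classificationLayer. That ReLU clamps every negative class score to zero before the softmax sees it. You can use the following layer definition to improve the training results.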
layers = [
    featureInputLayer(28*28*2)
    fullyConnectedLayer(28*28)
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
For more information, refer to the Define Network Architecture section in this example.
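As a rough illustration of why a ReLU in front of the softmax can hold accuracy at chance level, the sketch below applies a softmax to a set of made-up class scores with and without a ReLU in between. The scores and the helper names (`relu`, `softmaxFn`) are hypothetical, and four classes are used only for brevity. It shows one plausible failure mode, not a trace of the actual training run.

```matlab
% Hypothetical class scores from the last fullyConnectedLayer (all negative).
scores = [-2.0 -0.5 -1.2 -3.0];

% Plain-MATLAB stand-ins for the two layers (no toolbox needed).
relu = @(x) max(x, 0);
softmaxFn = @(x) exp(x - max(x)) ./ sum(exp(x - max(x)));

pPlain   = softmaxFn(scores);        % softmax applied directly to the scores
pClamped = softmaxFn(relu(scores));  % softmax applied after the ReLU

disp(pPlain);    % class 2 gets the highest probability
disp(pClamped);  % every class gets the same probability
```

When all scores are negative, the ReLU outputs zeros, the softmax becomes uniform, and the ReLU passes no gradient back for those units. With ten classes a uniform prediction corresponds to roughly 10% accuracy.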
