VAE import custom data

Glacial Claw, February 8, 2023 (edited February 20, 2023)
I was able to implement the linked VAE example line by line. Now I am trying to load my own custom data into it for training.
I have 2 RAR files containing color images (1024x1024, which will be resized). The first file has many images for training, and the second has a few images for testing the algorithm.
The issue is that the VAE in the link uses a function called "processImagesMNIST". I have looked into that function, but I am not sure how to structure my training and testing data to meet its requirements.
Or is there another way to load my own training data so that the VAE can use it, without this special function?
function X = processImagesMNIST(filename)
% The MNIST processing functions extract the data from the downloaded IDX
% files into MATLAB arrays. [What are IDX files???]
% The processImagesMNIST function performs these operations:
%   - Check that the file can be opened correctly.
%   - Obtain the magic number by reading the first four bytes. The magic
%     number is 2051 for image data and 2049 for label data.
%   - Read the next 3 sets of 4 bytes, which return the number of images,
%     the number of rows, and the number of columns.
%   - Read the image data.
%   - Reshape the array and swap the first two dimensions, because fread
%     fills the array in column-major order while the file stores each
%     image row by row.
%   - Scale the pixel values to the range [0,1] by dividing them all by
%     255, and reshape the 3-D array into a 4-D (H-by-W-by-1-by-N) array.
%   - Close the file.
% [What do I need to do to make my dataset applicable for this function? Or is there another way to implement my own data?]
dataFolder = fullfile(tempdir,'mnist');
gunzip(filename,dataFolder)
[~,name,~] = fileparts(filename);
[fileID,errmsg] = fopen(fullfile(dataFolder,name),'r','b');
if fileID < 0
    error(errmsg);
end
magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2051
    fprintf('\nRead MNIST image data...\n')
end
numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');
X = fread(fileID,inf,'unsigned char');
X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X, [28,28,1,size(X,3)]);
fclose(fileID);
end
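On the "[What are IDX files???]" note: IDX is the simple binary format the MNIST files use. An image file starts with a big-endian `int32` magic number (2051), followed by big-endian `int32` counts of images, rows, and columns, and then the raw `uint8` pixels stored row by row. The sketch below, using only base MATLAB file I/O, writes a tiny hypothetical IDX file (`tiny-images-idx3-ubyte`, two 2-by-2 images with made-up values) to `tempdir` and reads it back with the same `fread`/`reshape`/`permute` steps as `processImagesMNIST`, minus the `gunzip` step and with the image size taken from the header instead of a hard-coded 28-by-28. It is only meant to show why a folder of JPEGs cannot be passed to that function directly.

```matlab
% Sketch: round-trip a tiny IDX image file to show the layout that
% processImagesMNIST expects. File name and pixel values are made up.
imgs  = uint8(cat(3, [0 255; 128 64], [10 20; 30 40]));   % rows x cols x numImages
fname = fullfile(tempdir, 'tiny-images-idx3-ubyte');       % hypothetical file name

% Write the header (big-endian int32 values), then the pixels row by row.
fid = fopen(fname, 'w', 'b');                  % 'b' = big-endian machine format
fwrite(fid, 2051, 'int32');                    % magic number for image data
fwrite(fid, size(imgs, 3), 'int32');           % number of images
fwrite(fid, size(imgs, 1), 'int32');           % number of rows
fwrite(fid, size(imgs, 2), 'int32');           % number of columns
fwrite(fid, permute(imgs, [2 1 3]), 'uint8');  % transpose so each row is written contiguously
fclose(fid);

% Read it back the way processImagesMNIST does (without gunzip).
fid = fopen(fname, 'r', 'b');
magicNum  = fread(fid, 1, 'int32', 0, 'b');
numImages = fread(fid, 1, 'int32', 0, 'b');
numRows   = fread(fid, 1, 'int32', 0, 'b');
numCols   = fread(fid, 1, 'int32', 0, 'b');
X = fread(fid, inf, 'unsigned char');          % column vector of doubles
fclose(fid);

X = reshape(X, numCols, numRows, numImages);   % bytes fill columns first...
X = permute(X, [2 1 3]);                       % ...so swap rows and columns back
X = X ./ 255;                                  % scale to [0,1]
X = reshape(X, [numRows, numCols, 1, numImages]);  % H x W x C x N with C = 1
```

The result has the same H-by-W-by-1-by-N, [0,1] layout as the MNIST data, which is all the rest of the VAE code relies on; any other way of building that array works just as well.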

Answers (1)

Image Analyst, February 8, 2023
I would just extract all your images to regular image files (like PNG images) and then instead of calling
X = processImagesMNIST(filename)
just use imread()
X = imread(filename)
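Since `imread` reads one image at a time, the per-file results still have to be stacked into the same H-by-W-by-C-by-N layout (values in [0,1]) that `processImagesMNIST` returns. A minimal sketch, assuming the Image Processing Toolbox functions `imresize` and `im2single` are available (the thread already uses `imresize`); to stay self-contained it first writes two tiny sample PNGs, one RGB and one grayscale, to a hypothetical temporary folder. For real data, point `imageFolder` at the extracted images and drop the demo-setup lines.

```matlab
% Demo setup (hypothetical data): write one RGB and one grayscale PNG.
imageFolder = fullfile(tempdir, 'vae_demo_images');
if ~isfolder(imageFolder)
    mkdir(imageFolder);
end
imwrite(uint8(randi([0 255], 64, 64, 3)), fullfile(imageFolder, 'a_rgb.png'));
imwrite(uint8(randi([0 255], 64, 64)),    fullfile(imageFolder, 'b_gray.png'));

% Read every PNG in the folder into an H x W x C x N single array in [0,1].
targetSize  = [28 28];   % spatial size expected by the network
numChannels = 3;         % color; grayscale files are replicated to 3 channels
files = dir(fullfile(imageFolder, '*.png'));
X = zeros([targetSize numChannels numel(files)], 'single');   % preallocate

for k = 1:numel(files)
    img = imread(fullfile(files(k).folder, files(k).name));
    if size(img, 3) == 1
        img = repmat(img, [1 1 numChannels]);   % gray -> 3 identical channels
    end
    img = imresize(img, targetSize);            % resizes rows and columns only
    X(:, :, :, k) = im2single(img);             % uint8 0..255 -> single 0..1
end
```

Preallocating `X` avoids growing the array inside the loop, and `im2single` does the same job as the `./255` step in `processImagesMNIST`.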
  18 Comments
Glacial Claw, February 16, 2023
OK, I was able to get it to read the images; the console now shows each image being read.
However, I now get an error saying that the concatenation dimensions are incorrect.
The "processImagesMNIST" function contains these commands (I wonder if this has anything to do with it):
numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');
X = fread(fileID,inf,'unsigned char');
X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X, [28,28,1,size(X,3)]);
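On the concatenation error: `cat(4, A, B)` only works when `A` and `B` have the same size in dimensions 1 to 3. After `imresize(..., [28 28])` the height and width always match, so one likely cause (an assumption, since the exact error message isn't shown) is a channel mismatch, for example a grayscale JPEG (28-by-28-by-1) among RGB JPEGs (28-by-28-by-3). The hard-coded `reshape(X, [28,28,1,size(X,3)])` above is specific to single-channel 28-by-28 MNIST data and does not need to be reproduced for custom images. A small sketch of the failure and one possible fix:

```matlab
rgbImg  = zeros(28, 28, 3, 'uint8');   % e.g. a resized color JPEG
grayImg = zeros(28, 28, 'uint8');      % e.g. a resized grayscale JPEG

try
    cat(4, rgbImg, grayImg);           % size in dimension 3 differs: 3 vs 1
    catFailed = false;
catch ME
    catFailed = true;
    disp(ME.message);                  % MATLAB's dimension-mismatch message
end

% Fix: give every image the same number of channels before concatenating.
if size(grayImg, 3) == 1
    grayImg = repmat(grayImg, [1 1 3]);
end
stacked = cat(4, rgbImg, grayImg);     % 28 x 28 x 3 x 2
```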
Glacial Claw, February 17, 2023 (edited February 20, 2023)
So, I was able to fix the concatenation error by swapping the "train_datastored" and "thisImage" variables. Now I have a new issue because of this:
Is there a way to change the image channel dimension?
This is the entire code I am using. I have redacted the file paths, since they are not needed and don't cause any issues.
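% Read all training JPEGs, resize them to 28x28, and stack them along dimension 4.
filePattern = fullfile('(redacted)', '*.jpg');
theFiles = dir(filePattern);
for k = 1 : length(theFiles)
    baseFileName = theFiles(k).name;
    fullFileName = fullfile(theFiles(k).folder, baseFileName);
    fprintf('Now reading %s\n', fullFileName);
    thisImage = imread(fullFileName);
    if size(thisImage, 3) == 1
        thisImage = repmat(thisImage, [1 1 3]);   % grayscale -> 3 identical channels
    end
    % Resize, then scale to single in [0,1] (processImagesMNIST divides by 255)
    thisImage_resized = im2single(imresize(thisImage, [28 28]));
    if k == 1
        Train_datastored = thisImage_resized;
    else
        Train_datastored = cat(4,thisImage_resized, Train_datastored);
    end
end
XTrain = Train_datastored;   % 28 x 28 x 3 x numTrainImages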
% Same preprocessing for the test images.
filePattern_test = fullfile('(redacted)', '*.jpg');
theFiles_test = dir(filePattern_test);
for k = 1 : length(theFiles_test)
    baseFileName_test = theFiles_test(k).name;
    fullFileName_test = fullfile(theFiles_test(k).folder, baseFileName_test);
    fprintf('Now reading %s\n', fullFileName_test);
    thisImage_test = imread(fullFileName_test);
    if size(thisImage_test, 3) == 1
        thisImage_test = repmat(thisImage_test, [1 1 3]);   % grayscale -> 3 identical channels
    end
    thisImage_test_resized = im2single(imresize(thisImage_test, [28 28]));
    if k == 1
        Test_datastored = thisImage_test_resized;
    else
        Test_datastored = cat(4,thisImage_test_resized, Test_datastored);
    end
end
XTest = Test_datastored;   % 28 x 28 x 3 x numTestImages
numLatentChannels = 16;
imageSize = [28 28 3];   % RGB images; MNIST uses [28 28 1]
layersE = [
    imageInputLayer(imageSize,Normalization="none")
    convolution2dLayer(3,32,Padding="same",Stride=2)
    reluLayer
    convolution2dLayer(3,64,Padding="same",Stride=2)
    reluLayer
    fullyConnectedLayer(2*numLatentChannels)
    samplingLayer];
projectionSize = [7 7 64];
numInputChannels = imageSize(3);   % size(imageSize,1) is always 1 for a row vector
layersD = [
    featureInputLayer(numLatentChannels)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(3,64,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,32,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,numInputChannels,Cropping="same")
    sigmoidLayer];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
numEpochs = 120;
miniBatchSize = 50;
learnRate = 1e-3;
dsTrain = arrayDatastore(XTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB", ...
    PartialMiniBatch="discard");
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
numObservationsTrain = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");
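epoch = 0;
iteration = 0;
% The custom training loop from the VAE example (updating netE and netD each
% iteration) goes here; it is not shown in this listing.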
dsTest = arrayDatastore(XTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% im2single avoids mixing a uint8 array with the single-precision predictions
err = mean((im2single(XTest)-YTest).^2,[1 2 3]);
figure
histogram(err)
numImages = 2;
ZNew = randn(numLatentChannels,numImages);
ZNew = dlarray(ZNew,"CB");
YNew = predict(netD,ZNew);
YNew = extractdata(YNew);
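On the channel-dimension question: `numInputChannels = size(imageSize,1)` evaluates to 1 for any row vector `imageSize`, because `size(v,1)` is the number of rows of `v`, not its third element. That happens to be right for MNIST but not for RGB data. Two options follow: keep color and give the encoder input and the last decoder layer 3 channels, or convert every image to grayscale so the MNIST-shaped network (`[28 28 1]`) can be used unchanged. A sketch of both, assuming `rgb2gray`, `imresize`, and `im2single` are available as `imresize` already is in the code above; the random image is a stand-in for a real one.

```matlab
% Option 1: keep color images.
imageSize = [28 28 3];
wrongChannels    = size(imageSize, 1);   % 1: number of rows of the 1x3 vector
numInputChannels = imageSize(3);         % 3: the actual channel count

% Option 2: convert to grayscale and keep the MNIST-shaped network.
rgb  = uint8(randi([0 255], 64, 64, 3));               % stand-in color image
gray = im2single(imresize(rgb2gray(rgb), [28 28]));    % 28 x 28 single in [0,1]
XGray = cat(4, gray, gray);              % 28 x 28 x 1 x N, matches imageSize = [28 28 1]
```

Option 1 keeps the color information at the cost of a slightly larger first and last layer; option 2 discards color but leaves every layer of the MNIST example as it is.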

Release: R2022b
