Conditional GAN Training Error in the trainGAN Function

Yang Liu, 19 June 2024
Commented: Yang Liu, 28 June 2024
I am trying to get conditional GAN training to work with a 2-D matrix input of size 14x8.
I am adapting the "GenerateSyntheticPumpSignalsUsingCGANExample" example by replacing its vector input with a 2-D matrix input.
The following error message appears:
There seems to be a size mismatch in the modelGradients function, but because this comes from an official example, I am not sure how to revise it. Can someone give me a hint?
The input data is attached as test.mat.
The training script is attached as untitled3.m and is also pasted below.
clear;
%% Load the data
% LSTM_Reform_Data_SeriesData1_20210315_data001_for_GAN;
% load('LoadedData_20210315_data001_for_GAN.mat')
load('test.mat');
% load('test2.mat');
%% Generator Network
numFilters = 4;
numLatentInputs = 120;
projectionSize = [2 1 63];
numClasses = 2;
embeddingDimension = 120;
layersGenerator = [
imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','Input')
projectAndReshapeLayer(projectionSize,numLatentInputs,'ProjReshape');
concatenationLayer(3,2,'Name','Concate1');
transposedConv2dLayer([3 2],8*numFilters,'Stride',1,'Name','TransConv1') % 4*2*32
batchNormalizationLayer('Name','BN1','Epsilon',1e-5)
reluLayer('Name','Relu1')
transposedConv2dLayer([2 2],4*numFilters,'Stride',2,'Name','TransConv2') % 8*4*16
batchNormalizationLayer('Name','BN2','Epsilon',1e-5)
reluLayer('Name','Relu2')
transposedConv2dLayer([2 2],2*numFilters,'Stride',2,'Cropping',[2 1],'Name','TransConv3') % 12*6*8
batchNormalizationLayer('Name','BN3','Epsilon',1e-5)
reluLayer('Name','Relu3')
transposedConv2dLayer([3 3],2*numFilters,'Stride',1,'Name','TransConv4') % 14*8*1
];
lgraphGenerator = layerGraph(layersGenerator);
layers = [
imageInputLayer([1 1],'Name','Labels','Normalization','none')
embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'EmbedReshape1')];
lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'EmbedReshape1','Concate1/in2');
subplot(1,2,1);
plot(lgraphGenerator);
dlnetGenerator = dlnetwork(lgraphGenerator);
%% Discriminator Network
scale = 0.2;
Input_Num_Feature = [14 8 1]; % The input data is [14 8 1]
layersDiscriminator = [
imageInputLayer(Input_Num_Feature,'Normalization','none','Name','Input')
concatenationLayer(3,2,'Name','Concate2')
convolution2dLayer([2 2],4*numFilters,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv1')
leakyReluLayer(scale,'Name','LeakyRelu1')
convolution2dLayer([2 4],2*numFilters,'Stride',2,'DilationFactor',1,'Padding',[2 2],'Name','Conv2')
leakyReluLayer(scale,'Name','LeakyRelu2')
convolution2dLayer([2 2],numFilters,'Stride',2,'DilationFactor',1,'Padding',[0 0],'Name','Conv3')
leakyReluLayer(scale,'Name','LeakyRelu3')
convolution2dLayer([2 1],numFilters/2,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv4')
leakyReluLayer(scale,'Name','LeakyRelu4')
convolution2dLayer([2 2],numFilters/4,'Stride',1,'DilationFactor',1,'Padding',[0 0],'Name','Conv5')
];
lgraphDiscriminator = layerGraph(layersDiscriminator);
layers = [
imageInputLayer([1 1],'Name','Labels','Normalization','none')
embedAndReshapeLayer(Input_Num_Feature,embeddingDimension,numClasses,'EmbedReshape2')];
lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'EmbedReshape2','Concate2/in2');
subplot(1,2,2);
plot(lgraphDiscriminator);
dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
%% Train model
params.numLatentInputs = numLatentInputs;
params.numClasses = numClasses;
params.sizeData = [Input_Num_Feature length(Series_Fused_Label)];
params.numEpochs = 50;
params.miniBatchSize = 512;
% Specify the options for Adam optimizer
params.learnRate = 0.0002;
params.gradientDecayFactor = 0.5;
params.squaredGradientDecayFactor = 0.999;
executionEnvironment = "cpu";
params.executionEnvironment = executionEnvironment;
% for test, 14*8*30779
[dlnetGenerator,dlnetDiscriminator] =...
trainGAN(dlnetGenerator,dlnetDiscriminator,Series_Fused_Expand_Norm_Input,Series_Fused_Label,params);

Accepted Answer

Garmit Pant, 21 June 2024
Hello Yang Liu,
From what I understand, you are following the “Generate Synthetic Signals Using Conditional GAN” MATLAB example to train a conditional GAN that works with 2-D input.
The error occurs because the output size of the generator network does not match the input size of the discriminator network.
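One way to confirm that the discriminator side is consistent is to trace the spatial size through each of its convolution layers by hand. The following is a minimal sketch, not part of any toolbox: convOutSize is a hypothetical anonymous helper that assumes the documented convolution2dLayer conventions (a dilation factor d widens a filter of size f to (f-1)*d+1, and a 'Padding' value [a b] pads a rows at the top and bottom and b columns at the left and right). The layer settings are copied from the script in the question.

```matlab
% Hypothetical helper: output [H W] of a convolution2dLayer for an [H W] input.
% Effective filter size with dilation dil is (filt-1)*dil + 1; pad = [a b]
% adds a rows to top and bottom and b columns to left and right.
convOutSize = @(in, filt, stride, dil, pad) ...
    floor((in + 2.*pad - ((filt - 1).*dil + 1)) ./ stride) + 1;

% Data (14x8x1) and label embedding (14x8x1) are concatenated along the
% channel dimension, so the spatial size entering Conv1 is still 14x8.
sz = [14 8];
sz = convOutSize(sz, [2 2], 1, 2, [0 0]);  % Conv1 -> [12 6]
sz = convOutSize(sz, [2 4], 2, 1, [2 2]);  % Conv2 -> [8 4]
sz = convOutSize(sz, [2 2], 2, 1, [0 0]);  % Conv3 -> [4 2]
sz = convOutSize(sz, [2 1], 1, 2, [0 0]);  % Conv4 -> [2 2]
sz = convOutSize(sz, [2 2], 1, 1, [0 0]);  % Conv5 -> [1 1]
disp(sz)
```

The trace ends at a single 1x1 score per observation, which is what a 14x8x1 input should produce, so the mismatch has to come from the generator output.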
The discriminator network has been adapted correctly for your use case and expects a 14x8x1 input. However, the last transposed convolution layer of the generator ('TransConv4') uses 2*numFilters = 8 filters, so the generator output is 14x8x8 instead of 14x8x1. To fix the network, set the number of filters in that layer to 1:
transposedConv2dLayer([3 3],1,'Stride',1,'Name','TransConv4') % 14*8*1
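The same kind of hand trace can be sketched for the generator. It shows that the spatial size already reaches 14x8 and that only the channel count (the number of filters of the last layer) was wrong. tconvOutSize is a hypothetical anonymous helper. It assumes the transposedConv2dLayer output-size rule stride*(inputSize-1) + filterSize - 2*cropping, where a 'Cropping' value [a b] removes a rows from the top and bottom and b columns from the left and right.

```matlab
% Hypothetical helper: output [H W] of a transposedConv2dLayer for an [H W] input.
tconvOutSize = @(in, filt, stride, crop) stride.*(in - 1) + filt - 2.*crop;

% projectAndReshapeLayer gives 2x1x63 (projectionSize = [2 1 63]); the 2x1
% label embedding is concatenated along channels, so H x W stays 2x1.
sz = [2 1];
sz = tconvOutSize(sz, [3 2], 1, [0 0]);  % TransConv1 -> [4 2]
sz = tconvOutSize(sz, [2 2], 2, [0 0]);  % TransConv2 -> [8 4]
sz = tconvOutSize(sz, [2 2], 2, [2 1]);  % TransConv3 -> [12 6]
sz = tconvOutSize(sz, [3 3], 1, [0 0]);  % TransConv4 -> [14 8]

% The output channel count equals the number of filters in TransConv4.
numFilters = 4;
outBefore = [sz, 2*numFilters];  % original layer: 14x8x8
outAfter  = [sz, 1];             % corrected layer: 14x8x1
disp(outBefore); disp(outAfter);
```

Instead of tracing by hand, analyzeNetwork (Deep Learning Toolbox) can display the activation size of every layer of the assembled network, which makes this kind of mismatch easy to spot before training.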
Additionally, comment out or remove the following lines from trainGAN.m. They plot the power spectra of the example's signals, which is specific to that example, so they will throw an error with your data.
% if mod(ct,50) == 0 || ct == 1
% % Generate signals using held-out generator input
% dlXGeneratedValidation = predict(dlnetGenerator, dlZValidation, dlTValidation);
% dlXGeneratedValidation = squeeze(extractdata(gather(dlXGeneratedValidation)));
%
% % Display spectra of validation signals
% subplot(1,2,1);
% pspectrum(dlXGeneratedValidation);
% set(gca, 'XScale', 'log')
% legend('Healthy', 'Faulty')
% title("Spectra of Generated Signals")
% end
For more details, see the following MathWorks documentation:
  1. The “Input Arguments” section of the transposedConv2dLayer reference page, which describes the layer's parameters: https://www.mathworks.com/help/releases/R2023b/deeplearning/ref/transposedconv2dlayer.html
I hope you find the above explanation and suggestions useful!
Yang Liu, 28 June 2024
Dear Garmit,
Thank you so much for your kind help! This was my oversight. I should have checked the last layer of the generator more carefully.
Yes, the part of trainGAN related to the signal spectrum should be commented out; otherwise it may report other errors. I will revise that part, because I want to observe the generated signals while training the GAN.
Thanks again for your kind help!
Yang Liu


Category: Measurements and Feature Extraction
Release: R2023b
