% Build a labelled image datastore from a folder under the MATLAB
% installation. Subfolders are scanned and each subfolder name becomes the
% label of the images it contains.
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'image2500');
digitData = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

% Random horizontal flips are the only augmentation applied.
augmenter = imageDataAugmenter(RandXReflection=true);

% Resize every image to 128-by-128 on read and apply the augmentation.
% NOTE(review): the output size here must match the spatial size of the
% images produced by the generator and the discriminator input size -
% confirm against projectionSize/filterSize/inputSize.
% The trailing semicolon suppresses the datastore dump to the command window.
augimds = augmentedImageDatastore([128 128], digitData, DataAugmentation=augmenter);
% Generator network: turns a latent vector of length numLatentInputs into a
% 3-channel image through a stack of transposed convolutions.
% numLatentInputs, filterSize and numFilters are expected to be defined
% earlier in the script.

% Size handed to projectAndReshapeLayer.
% NOTE(review): presumably [height width channels] of the first feature
% map - confirm against the projectAndReshapeLayer implementation.
projectionSize = [4 4 512];

% The layers must be collected into an array: as bare statements they are
% only echoed to the command window and layersGenerator is never defined.
% Batch normalization gives the network the state that modelLoss returns as
% stateG; the final tanh bounds the output to [-1,1], the same range that
% preprocessMiniBatch rescales the real images to.
% NOTE(review): the normalization/activation layers are the usual DCGAN
% arrangement - confirm they match the intended architecture.
layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];

netG = dlnetwork(layersGenerator);
% Discriminator network: maps an image of size inputSize to a single
% probability that the image is real.
% inputSize, dropoutProb, filterSize and numFilters are expected to be
% defined earlier in the script.

% The layers must be collected into an array: as bare statements they are
% only echoed to the command window and layersDiscriminator is never defined.
% Each of the four stride-2 "same" convolutions halves the spatial size
% (rounding up), so the feature map entering the last convolution is
% ceil(inputSize(1:2)/16); using that as the filter size collapses it to a
% single value per image. The sigmoid keeps the output in (0,1), which
% ganLoss requires because it takes log(Y) and log(1-Y).
% NOTE(review): the leaky ReLU slope (0.2) and the batch normalization
% placement are the usual DCGAN arrangement - confirm they match the
% intended architecture.
layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(0.2)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(0.2)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(0.2)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(0.2)
    convolution2dLayer(ceil(inputSize(1:2)/16),1)
    sigmoidLayer];

netD = dlnetwork(layersDiscriminator);
% Decay factors passed to adamupdate for both networks.
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

% Generated validation images are refreshed every validationFrequency
% iterations (and on the first iteration).
validationFrequency = 100;

% Mini-batch queue over the augmented datastore. Incomplete final batches
% are discarded so every batch holds exactly miniBatchSize images, matching
% the size of the noise batch drawn in the training loop.
% MiniBatchFormat labels the data as spatial-spatial-channel-batch so that
% it can be passed directly to forward on a dlnetwork.
augimds.MiniBatchSize = miniBatchSize;
mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");

% Fixed latent vectors, one column per validation image, reused for every
% validation display so the images are comparable between iterations.
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");

% Move to the GPU only when one is usable; an unconditional gpuArray call
% errors on machines without a supported GPU.
if canUseGPU
    ZValidation = gpuArray(ZValidation);
end
% Total number of training iterations, used to report progress.
% The image datastore created at the top of the script is named digitData;
% there is no variable called imds.
numObservationsTrain = numel(digitData.Files);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
numIterations = numEpochs*numIterationsPerEpoch;

% Progress monitor tracking both scores, with epoch and iteration shown as
% info fields. The two scores are grouped into a single subplot.
monitor = trainingProgressMonitor( ...
    Metrics=["GeneratorScore","DiscriminatorScore"], ...
    Info=["Epoch","Iteration"], ...
    XLabel="Iteration");
groupSubPlot(monitor,Score=["GeneratorScore","DiscriminatorScore"]);
% Custom GAN training loop. Runs until numEpochs epochs have completed or
% the Stop button of the progress monitor is pressed.

% Loop counters and Adam state (empty on the first adamupdate call).
epoch = 0;
iteration = 0;
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvgD = [];
trailingAvgSqD = [];

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Start every epoch from a freshly shuffled queue; this also resets
    % the queue so that hasdata is true again.
    shuffle(mbq);

    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Real images for this iteration.
        X = next(mbq);

        % Latent vectors for the generator, one column per observation,
        % labelled channel-by-batch like ZValidation.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");
        if canUseGPU
            Z = gpuArray(Z);
        end

        % Losses are not needed here; only gradients, state and scores.
        [~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);

        % Keep the generator state returned by the forward pass.
        netG.State = stateG;

        % Update the discriminator parameters.
        [netD,trailingAvgD,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvgD, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Periodically show images generated from the fixed validation noise.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile the images and rescale to [0,1] so they can be displayed.
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);
            image(I)
            title("Generated Images");
        end

        % Report scores and progress to the monitor.
        recordMetrics(monitor,iteration, ...
            GeneratorScore=scoreG, ...
            DiscriminatorScore=scoreD);
        updateInfo(monitor,Epoch=epoch,Iteration=iteration);
        monitor.Progress = 100*iteration/numIterations;
    end
end
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    modelLoss(netG,netD,X,Z,flipProb)
% modelLoss  Losses, gradients and scores for one GAN training step.
%
% Must be called through dlfeval so that dlgradient can trace the forward
% passes.
%
% Inputs:
%   netG     - generator dlnetwork
%   netD     - discriminator dlnetwork
%   X        - mini-batch of real images
%   Z        - mini-batch of latent vectors fed to the generator
%   flipProb - probability with which the label of a real observation is
%              flipped before the loss is computed
%
% Outputs:
%   lossG, lossD           - generator and discriminator losses (ganLoss)
%   gradientsG, gradientsD - gradients of the losses with respect to the
%                            learnables of netG and netD
%   stateG                 - generator state returned by the forward pass
%   scoreG, scoreD         - generator and discriminator scores

% Discriminator predictions for the real images. This line was missing:
% YReal was used below without ever being computed.
YReal = forward(netD,X);

% Discriminator predictions for the generated images.
[XGenerated,stateG] = forward(netG,Z);
YGenerated = forward(netD,XGenerated);

% Scores are computed before any labels are flipped.
scoreD = (mean(YReal) + mean(1-YGenerated)) / 2;
scoreG = mean(YGenerated);

% Randomly flip the predictions of a fraction of the real observations
% (indexed along the fourth dimension).
numObservations = size(YReal,4);
idx = rand(1,numObservations) < flipProb;
YReal(:,:,:,idx) = 1 - YReal(:,:,:,idx);

[lossG, lossD] = ganLoss(YReal,YGenerated);

% RetainData keeps the trace available for the second dlgradient call.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);
end
function [lossG,lossD] = ganLoss(YReal,YGenerated)
% ganLoss  Generator and discriminator losses from discriminator outputs.
%
% Inputs:
%   YReal      - discriminator outputs for real images
%   YGenerated - discriminator outputs for generated images
% Both inputs must lie strictly between 0 and 1, because the logarithm of
% the value and of one minus the value is taken.
%
% Outputs:
%   lossG - negative mean log of YGenerated
%   lossD - negative mean log of YReal plus negative mean log of
%           (1 - YGenerated)

% Discriminator: penalize low outputs on real and high outputs on generated.
lossD = -mean(log(YReal)) - mean(log(1-YGenerated));

% Generator: penalize low discriminator outputs on generated images.
lossG = -mean(log(YGenerated));
end
function X = preprocessMiniBatch(data)
% preprocessMiniBatch  Stack a cell array of images and rescale to [-1,1].
%
% Input:
%   data - cell array of images, one image per cell, all of the same size
%
% Output:
%   X - the images concatenated along the fourth dimension, with pixel
%       values mapped linearly from [0,255] to [-1,1]

% Concatenate the images into one array. This line was missing: X was
% rescaled without ever being assigned and the data input was ignored.
X = cat(4,data{:});

% Fixed input range, so the mapping does not depend on the batch contents.
X = rescale(X,-1,1,InputMin=0,InputMax=255);
end