What is derivative trace in dlgradient function?

Theron FARRELL on 25 Nov 2019
Answered: Gautam Pendse on 14 Jan 2020
Hi there,
I am trying to train a GAN. While exploring MATLAB's official example, I came across the following calls:
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);
And after reading the help of dlgradient(...), I have the following questions:
  1. What is derivative trace in dlgradient function? Consider a two-layered dlnetwork, in which
z=W*input+B;
output = sigmoid(z);
targetOutput = 1 * ones(size(z));
Cost = 0.5*mean((targetOutput-output).^2);
So my guess is that the derivative trace is del(Cost)/del(z) = -(targetOutput-output).*sigmoid(z).*(1-sigmoid(z)), del(Cost)/del(input) = W'*del(Cost)/del(z), etc. Is that correct, or does it indicate something else? Can anyone tell me?
2. If my guess is correct, when I train a GAN and perform dlgradient for the discriminator and the generator in the same dlfeval, will it be the same if I calculate derivatives of the discriminator first? For example
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables,'RetainData',true);
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables);
Because when calculating gradients in the generator, the W's and B's in the discriminator remain unchanged.
3. As I can see in many GAN papers, the key to successfully training a GAN is that the generator and the discriminator are trained separately. That is, the first set of synthetic (fake) images goes through both the generator and the discriminator, and the discriminator is trained on its cost together with the cost from real images, so that the W's and B's in the discriminator get updated. Then the second set of synthetic images goes through both, and the generator is trained on its cost so that ONLY the W's and B's in the generator are updated. In Keras (Python), the parameters of a model can be set as not trainable explicitly. In MATLAB, how can I make sure that this is EXACTLY what happens?
Thanks a lot.

Accepted Answer

Gautam Pendse on 14 Jan 2020
Hi Theron,
Re: 1. What is derivative trace in dlgradient function?
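** The derivative trace is essentially the recorded history of the operations that were executed when computing a given set of values. dlgradient uses this record to compute derivatives by automatic differentiation, so the trace is the record of the computation rather than the derivative values themselves. See this doc page for more info (middle of the page): https://www.mathworks.com/help/deeplearning/ug/include-automatic-differentiation.html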
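To make the distinction concrete, here is a minimal sketch based on the two-layer example from the question. The array sizes, the random values, and the function name costAndGradients are assumptions for illustration, and the cost is taken to be the mean of the squared errors. It also assumes a release in which these operations are supported for dlarray. The hand-derived chain rule is evaluated on plain numeric arrays and compared against what dlgradient returns.

```matlab
% Minimal sketch: sizes and random values are illustrative assumptions.
rng(0);
Wn = randn(3,4);  Bn = randn(3,1);  Xn = randn(4,5);

% dlfeval switches tracing on; dlgradient must be called inside the
% function that dlfeval evaluates.
[cost, gradW, gradB] = dlfeval(@costAndGradients, ...
    dlarray(Wn), dlarray(Bn), dlarray(Xn));

% Hand-derived chain rule on plain numeric arrays, for comparison.
out   = 1 ./ (1 + exp(-(Wn*Xn + Bn)));                  % sigmoid(z)
delta = -(1 - out) .* out .* (1 - out) / numel(out);    % dCost/dz
errW  = max(abs(extractdata(gradW) - delta*Xn'), [], 'all');
errB  = max(abs(extractdata(gradB) - sum(delta,2)), [], 'all');
fprintf('max |gradW - manual| = %g, max |gradB - manual| = %g\n', errW, errB);

function [cost, gradW, gradB] = costAndGradients(W, B, X)
    z = W*X + B;                      % recorded in the trace
    output = sigmoid(z);              % recorded in the trace
    targetOutput = ones(size(z));
    cost = 0.5*mean((targetOutput - output).^2, 'all');
    % dlgradient works backward through the recorded operations from cost.
    [gradW, gradB] = dlgradient(cost, W, B);
end
```

If the two agree to within rounding error, the expressions guessed in the question are the values dlgradient produces, while the trace is the bookkeeping that lets it produce them.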
Re: 2. If my guess is correct, when I train a GAN and perform dlgradient for the discriminator and the generator in the same dlfeval, will it be the same if I calculate derivatives of the discriminator first?
** Yes, switching the order of the two dlgradient calls should give the same gradients.
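A quick way to check this is a toy setup in which small parameter arrays stand in for the two networks. Everything below (thetaG, thetaD, the toy "networks", and the function name toyGradients) is an assumption for illustration and not the code of the MATLAB GAN example; only the loss expressions follow the usual GAN form. The first dlgradient call in each branch sets 'RetainData' to true so that the trace is still available for the second call.

```matlab
% Toy check: parameter arrays stand in for the two networks (assumption).
rng(1);
thetaG = dlarray(randn(2,2));   % stand-in for generator learnables
thetaD = dlarray(randn(1,2));   % stand-in for discriminator learnables
Z      = dlarray(randn(2,8));   % latent inputs
Xreal  = dlarray(randn(2,8));   % "real" data

[gG1, gD1] = dlfeval(@toyGradients, thetaG, thetaD, Z, Xreal, true);
[gG2, gD2] = dlfeval(@toyGradients, thetaG, thetaD, Z, Xreal, false);

diffG = max(abs(extractdata(gG1 - gG2)), [], 'all');
diffD = max(abs(extractdata(gD1 - gD2)), [], 'all');
fprintf('max difference: generator %g, discriminator %g\n', diffG, diffD);

function [gradG, gradD] = toyGradients(thetaG, thetaD, Z, Xreal, generatorFirst)
    Xfake = tanh(thetaG*Z);                 % toy generator
    pReal = sigmoid(thetaD*Xreal);          % toy discriminator on real data
    pFake = sigmoid(thetaD*Xfake);          % toy discriminator on fake data
    lossD = -mean(log(pReal)) - mean(log(1 - pFake));
    lossG = -mean(log(pFake));
    if generatorFirst
        gradG = dlgradient(lossG, thetaG, 'RetainData', true);
        gradD = dlgradient(lossD, thetaD);
    else
        gradD = dlgradient(lossD, thetaD, 'RetainData', true);
        gradG = dlgradient(lossG, thetaG);
    end
end
```

Both calls differentiate the same recorded computation at the same parameter values, and neither call changes any parameter, which is why the order should not matter.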
Re: 3. As I can see in many GAN papers, the key to successfully training a GAN is that the generator and the discriminator are trained separately, that is, the first set of synthetic (fake) images goes through both the generator and the discriminator, and the discriminator is trained on its cost together with the cost from real images so that the W's and B's in the discriminator get updated. Then the second set of synthetic images goes through both, and the generator is trained on its cost so that ONLY the W's and B's in the generator are updated.
** The MATLAB GAN example uses simultaneous gradient descent for optimization. I think your description above refers to alternating gradient descent - another optimization method for GANs. Both methods are described in this paper: https://arxiv.org/abs/1705.10461.
To implement alternating gradient descent, the modelGradients function in MATLAB GAN example can be split into two functions - one computing the loss/gradient for the Discriminator only and the other computing the loss/gradient for the Generator only. Then the following gradient calculation/update sequence can be used:
  1. Compute loss/gradient for the Discriminator
  2. Update the Discriminator
  3. Compute loss/gradient for the Generator (using updated Discriminator)
  4. Update the Generator
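The four steps above could be sketched roughly as follows. This is not the code of the MATLAB GAN example: the function names alternatingStep, discriminatorGradients and generatorGradients and the optimizer structs optG and optD (with fields avgGrad and avgSqGrad) are hypothetical, and the networks and mini-batches dlX and dlZ are assumed to be prepared as in that example. The losses follow the standard GAN form. Note that dlgradient returns gradients only for the learnables passed to it, and adamupdate changes only the learnables it is given, so each update touches exactly one network.

```matlab
% Sketch only: networks and data are assumed to be set up as in the
% MATLAB GAN example; all function names here are hypothetical.
function [dlnetG, dlnetD, optG, optD] = alternatingStep( ...
        dlnetG, dlnetD, dlX, dlZ, optG, optD, iteration)
    % Steps 1-2: discriminator gradients, then update the discriminator only.
    gradD = dlfeval(@discriminatorGradients, dlnetG, dlnetD, dlX, dlZ);
    [dlnetD.Learnables, optD.avgGrad, optD.avgSqGrad] = adamupdate( ...
        dlnetD.Learnables, gradD, optD.avgGrad, optD.avgSqGrad, iteration);

    % Steps 3-4: generator gradients through the updated discriminator,
    % then update the generator only.
    [gradG, stateG] = dlfeval(@generatorGradients, dlnetG, dlnetD, dlZ);
    dlnetG.State = stateG;
    [dlnetG.Learnables, optG.avgGrad, optG.avgSqGrad] = adamupdate( ...
        dlnetG.Learnables, gradG, optG.avgGrad, optG.avgSqGrad, iteration);
end

function gradD = discriminatorGradients(dlnetG, dlnetD, dlX, dlZ)
    dlXfake = forward(dlnetG, dlZ);
    dlYreal = forward(dlnetD, dlX);
    dlYfake = forward(dlnetD, dlXfake);
    lossD = -mean(log(sigmoid(dlYreal))) - mean(log(1 - sigmoid(dlYfake)));
    % Only the discriminator learnables are differentiated here.
    gradD = dlgradient(lossD, dlnetD.Learnables);
end

function [gradG, stateG] = generatorGradients(dlnetG, dlnetD, dlZ)
    [dlXfake, stateG] = forward(dlnetG, dlZ);
    dlYfake = forward(dlnetD, dlXfake);
    lossG = -mean(log(sigmoid(dlYfake)));
    % Only the generator learnables are differentiated here.
    gradG = dlgradient(lossG, dlnetG.Learnables);
end
```

In a training loop, the optimizer structs could be initialized once with optG = struct('avgGrad', [], 'avgSqGrad', []) (and likewise for optD) before calling alternatingStep on each mini-batch.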
Hope that helps,
Gautam

Release

R2019b
