What is the derivative trace in the dlgradient function?

Theron FARRELL on 25 Nov 2019
Answered: Gautam Pendse on 14 Jan 2020
Hi there,
I am trying to train a GAN. While studying MATLAB's official GAN example, I noticed the following:
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);
After reading the documentation for dlgradient, I have the following questions:
  1. What is the derivative trace in the dlgradient function? Consider a two-layer dlnetwork in which
z=W*input+B;
output = sigmoid(z);
targetOutput = 1 * ones(size(z));
Cost = 0.5*mean((targetOutput-output).^2);   % mean squared error
So my guess is that the derivative trace consists of del(Cost)/del(z) = -(targetOutput-output).*sigmoid(z).*(1-sigmoid(z)), del(Cost)/del(input) = W'*del(Cost)/del(z), and so on. Is that correct, or does it mean something else? Could anyone explain?
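For reference, here is the chain rule for this model written out, assuming the cost is meant as the mean squared error over the N elements of z and a single input column x (σ denotes sigmoid, y = output, t = targetOutput). It matches the guess above except for the 1/N factor contributed by the mean:

```latex
\begin{aligned}
z &= Wx + B, \qquad y = \sigma(z), \qquad C = \frac{1}{2N}\sum_{i=1}^{N} (t_i - y_i)^2 \\
\frac{\partial C}{\partial z_i} &= -\frac{1}{N}\,(t_i - y_i)\,y_i\,(1 - y_i) \\
\frac{\partial C}{\partial W} &= \frac{\partial C}{\partial z}\,x^{\top}, \qquad
\frac{\partial C}{\partial B} = \frac{\partial C}{\partial z}, \qquad
\frac{\partial C}{\partial x} = W^{\top}\,\frac{\partial C}{\partial z}
\end{aligned}
```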
2. If my guess is correct, when I train a GAN and call dlgradient for both the discriminator and the generator inside the same dlfeval, will the result be the same if I compute the discriminator's gradients first? For example:
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables,'RetainData',true);
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables);
My reasoning is that, while the generator's gradients are being calculated, the W's and B's in the discriminator remain unchanged.
3. As I understand from many GAN papers, the key to training a GAN successfully is to train the generator and the discriminator separately. First, a set of synthetic (fake) images passes through both the generator and the discriminator, and the discriminator is trained on its cost, together with the cost from real images, so that the W's and B's in the discriminator are updated. Then a second set of synthetic images passes through both networks, and the generator is trained on its cost so that ONLY the W's and B's in the generator are updated. In Keras (Python), a model's parameters can be explicitly set as not trainable. In MATLAB, how can I make sure that this is EXACTLY what happens?
Thanks a lot.

Accepted Answer

Gautam Pendse on 14 Jan 2020
Hi Theron,
Re: 1. What is the derivative trace in the dlgradient function?
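** The derivative trace is essentially the recorded history of the sequence of operations executed on traced dlarray values while computing a given set of values; dlgradient then computes derivatives by working backward through that record (reverse-mode automatic differentiation). See this doc page for more info (middle of the page): https://www.mathworks.com/help/deeplearning/ug/include-automatic-differentiation.html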
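A minimal sketch, assuming a single fully connected layer with a sigmoid output and arbitrary sizes (W, B, x and costFcn are hypothetical names chosen for illustration), that checks what dlgradient returns against the hand-derived chain rule from question 1. Inside dlfeval the operations on the dlarray inputs are recorded in the trace, and dlgradient back-propagates through that record; the 1/N factor comes from the mean:

```matlab
% Sketch: compare dlgradient with a hand-derived chain rule for
%   z = W*x + B,  y = sigmoid(z),  Cost = 0.5*mean((1 - y).^2)
% Sizes (3 outputs, 4 inputs) are arbitrary choices for illustration.
rng(0);
W = dlarray(randn(3, 4));
B = dlarray(randn(3, 1));
x = dlarray(randn(4, 1));

% Cost with targetOutput = ones(size(z))
costFcn = @(W, B, x) 0.5*mean((1 - sigmoid(W*x + B)).^2, 'all');

% dlfeval traces the operations on W, B and x; dlgradient back-propagates
% through that trace. A cell array input gives a cell array of gradients.
grads = dlfeval(@(W, B, x) dlgradient(costFcn(W, B, x), {W, B, x}), W, B, x);

% Hand-derived gradients on plain numeric arrays
Wn = extractdata(W);  Bn = extractdata(B);  xn = extractdata(x);
y = 1 ./ (1 + exp(-(Wn*xn + Bn)));
N = numel(y);
dCdz = -(1 - y) .* y .* (1 - y) / N;   % del(Cost)/del(z)
dCdW = dCdz * xn';                     % del(Cost)/del(W)
dCdB = dCdz;                           % del(Cost)/del(B)
dCdx = Wn' * dCdz;                     % del(Cost)/del(input) = W'*del(Cost)/del(z)

errW = max(abs(extractdata(grads{1}) - dCdW), [], 'all');
errB = max(abs(extractdata(grads{2}) - dCdB), [], 'all');
errx = max(abs(extractdata(grads{3}) - dCdx), [], 'all');
fprintf('max abs differences: W %g, B %g, input %g\n', errW, errB, errx);
```

The differences should be at round-off level, consistent with the trace holding the forward operations and dlgradient applying the chain rule through them.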
Re: 2. If my guess is correct, when I train a GAN and call dlgradient for both the discriminator and the generator inside the same dlfeval, will the result be the same if I compute the discriminator's gradients first?
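** Yes, switching the order of the two dlgradient calls should give the same gradients, as long as the call that runs first keeps 'RetainData',true (as in your snippet) so that the trace is still available for the second call.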
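A minimal sketch of that check, assuming a hypothetical 1-D toy GAN (linear generator xFake = a*z + b, logistic discriminator D(x) = sigmoid(w*x + c)) in place of the dlnetwork objects. For brevity each loss is recomputed inside the anonymous functions instead of being shared from one forward pass as in the example, and 'RetainData',true is set on every dlgradient call so that no trace is discarded before both gradients are computed:

```matlab
% Sketch: swapping the order of two dlgradient calls in one dlfeval.
% Hypothetical 1-D toy GAN parameters:
%   generator     thG = [a; b],  xFake = a*z + b
%   discriminator thD = [w; c],  D(x)  = sigmoid(w*x + c)
rng(0);
thG = dlarray([0.5; 1]);
thD = dlarray([1; -2]);
xReal = dlarray(3 + randn(1, 16));   % toy "real" samples
z     = dlarray(randn(1, 16));       % generator noise

gen   = @(thG, z) thG(1)*z + thG(2);
disc  = @(thD, x) sigmoid(thD(1)*x + thD(2));
lossD = @(thD, thG) -mean(log(disc(thD, xReal)) + log(1 - disc(thD, gen(thG, z))), 'all');
lossG = @(thD, thG) -mean(log(disc(thD, gen(thG, z))), 'all');

% Order 1: discriminator gradients first, returns {gradD, gradG}
dFirst = @(thD, thG) {dlgradient(lossD(thD, thG), thD, 'RetainData', true), ...
                      dlgradient(lossG(thD, thG), thG, 'RetainData', true)};
% Order 2: generator gradients first, returns {gradG, gradD}
gFirst = @(thD, thG) {dlgradient(lossG(thD, thG), thG, 'RetainData', true), ...
                      dlgradient(lossD(thD, thG), thD, 'RetainData', true)};

g1 = dlfeval(dFirst, thD, thG);
g2 = dlfeval(gFirst, thD, thG);

diffD = max(abs(extractdata(g1{1}) - extractdata(g2{2})));
diffG = max(abs(extractdata(g1{2}) - extractdata(g2{1})));
fprintf('max abs difference: discriminator %g, generator %g\n', diffD, diffG);
```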
Re: 3. As I understand from many GAN papers, the key to training a GAN successfully is to train the generator and the discriminator separately. First, a set of synthetic (fake) images passes through both the generator and the discriminator, and the discriminator is trained on its cost, together with the cost from real images, so that the W's and B's in the discriminator are updated. Then a second set of synthetic images passes through both networks, and the generator is trained on its cost so that ONLY the W's and B's in the generator are updated.
** The MATLAB GAN example uses simultaneous gradient descent for optimization. I think your description above refers to alternating gradient descent, another optimization method for GANs. Both methods are described in this paper: https://arxiv.org/abs/1705.10461.
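In update-rule form, writing θ_D and θ_G for the discriminator and generator parameters (the W's and B's), L_D and L_G for their losses, and η for a step size (plain gradient steps for illustration; the MATLAB example itself updates with adamupdate), the two schemes differ only in which discriminator parameters the generator step sees:

```latex
\begin{aligned}
&\text{Simultaneous:} &&
\theta_D^{k+1} = \theta_D^{k} - \eta\,\nabla_{\theta_D} L_D(\theta_D^{k}, \theta_G^{k}), \quad
\theta_G^{k+1} = \theta_G^{k} - \eta\,\nabla_{\theta_G} L_G(\theta_D^{k}, \theta_G^{k}) \\
&\text{Alternating:} &&
\theta_D^{k+1} = \theta_D^{k} - \eta\,\nabla_{\theta_D} L_D(\theta_D^{k}, \theta_G^{k}), \quad
\theta_G^{k+1} = \theta_G^{k} - \eta\,\nabla_{\theta_G} L_G(\theta_D^{k+1}, \theta_G^{k})
\end{aligned}
```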
To implement alternating gradient descent, the modelGradients function in the MATLAB GAN example can be split into two functions: one that computes the loss/gradient for the Discriminator only, and one that computes the loss/gradient for the Generator only. The following gradient calculation/update sequence can then be used:
  1. Compute loss/gradient for the Discriminator
  2. Update the Discriminator
  3. Compute loss/gradient for the Generator (using updated Discriminator)
  4. Update the Generator
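A minimal sketch of that four-step sequence, assuming a hypothetical 1-D toy GAN (linear generator xFake = a*z + b, logistic discriminator D(x) = sigmoid(w*x + c), toy real data drawn from a normal distribution with mean 3) and plain dlarray parameter vectors updated with adamupdate, instead of the dlnetwork objects of the example. In a custom training loop like this, parameters change only when they are passed to an update function such as adamupdate, so updating the generator alone leaves the discriminator's W's and B's untouched without any "not trainable" flag:

```matlab
% Sketch: alternating gradient descent for a hypothetical 1-D toy GAN
% (not the networks from the MATLAB example).
%   generator     thG = [a; b],  xFake = a*z + b
%   discriminator thD = [w; c],  D(x)  = sigmoid(w*x + c)
rng(0);
thG = dlarray([1; 0]);
thD = dlarray([0.1; 0]);
batchSize = 64;
learnRate = 0.01;
numIterations = 500;

gen  = @(thG, z) thG(1)*z + thG(2);
disc = @(thD, x) sigmoid(thD(1)*x + thD(2));

% Discriminator loss, differentiated w.r.t. the discriminator parameters only
gradD = @(thD, thG, xReal, z) dlgradient( ...
    -mean(log(disc(thD, xReal)) + log(1 - disc(thD, gen(thG, z))), 'all'), thD);
% Non-saturating generator loss, differentiated w.r.t. the generator parameters only
gradG = @(thD, thG, z) dlgradient(-mean(log(disc(thD, gen(thG, z))), 'all'), thG);

% Adam moving averages start empty
avgD = []; avgSqD = []; avgG = []; avgSqG = [];

for iteration = 1:numIterations
    % 1. Compute loss/gradient for the Discriminator
    xReal = dlarray(3 + randn(1, batchSize));   % toy "real" data, mean 3
    z = dlarray(randn(1, batchSize));
    gD = dlfeval(gradD, thD, thG, xReal, z);

    % 2. Update the Discriminator (thG is not modified)
    [thD, avgD, avgSqD] = adamupdate(thD, gD, avgD, avgSqD, iteration, learnRate);

    % 3. Compute loss/gradient for the Generator, using the updated Discriminator
    z = dlarray(randn(1, batchSize));
    gG = dlfeval(gradG, thD, thG, z);

    % 4. Update the Generator (thD is not modified)
    [thG, avgG, avgSqG] = adamupdate(thG, gG, avgG, avgSqG, iteration, learnRate);
end

disp(extractdata(thG)')   % toy generator parameters [a b] after training
```

Moving steps 3 and 4 so that they use the discriminator parameters from before step 2 would turn this into the simultaneous scheme used by the example.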
Hope that helps,
Gautam


Release: R2019b
