retain_graph=True and disc_opt.zero_grad()

Hi All,
In course1 week 2 assignment, when training GAN, we first train discriminator by

  1. disc_opt.zero_grad()
  2. rest process
  3. disc_loss.backward(retain_graph=True)
  4. disc_opt.step()
    then we train generator
  5. gen_opt.zero_grad()
  6. rest process
  7. gen_loss.backward()
  8. gen_opt.step()

My question is: why do we need retain_graph=True for the discriminator? Even if the graph is freed, it will be rebuilt during generator training. Also, shouldn't we call disc_opt.zero_grad() before starting generator training as well?
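For reference, here is a minimal sketch of the two-phase update described in the steps above. The tiny Linear "models", the BCE criterion, and the random data are placeholders, not the assignment's actual architecture:

```python
# Minimal sketch of the discriminator/generator update order above.
# The Linear layers and random tensors are stand-ins for the real models/data.
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)    # stand-in generator
disc = nn.Linear(4, 1)   # stand-in discriminator
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
disc_opt = torch.optim.SGD(disc.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 4)
real = torch.randn(8, 4)

# Steps 1-4: discriminator update (fake image is detached here)
disc_opt.zero_grad()
fake = gen(noise)
disc_loss = (criterion(disc(fake.detach()), torch.zeros(8, 1)) +
             criterion(disc(real), torch.ones(8, 1))) / 2
disc_loss.backward()   # note: no retain_graph=True needed
disc_opt.step()

# Steps 5-8: generator update (no detach on the fake image)
gen_opt.zero_grad()
fake = gen(noise)
gen_loss = criterion(disc(fake), torch.ones(8, 1))
gen_loss.backward()
gen_opt.step()

print(disc_loss.item(), gen_loss.item())
```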


@Ashish_Jha1,
I agree about retain_graph. It doesn't look like there's any real reason to use it here. retain_graph=True means we want to keep the graph that was used internally to calculate the gradients, which we would only need if we wanted to do another backward() through the same graph. It's a rare situation and not one we need here.
As an experiment, you can try removing the “retain_graph=True” and see that everything still works fine. If you did need it, you’d get an error when you called that second backward() that was relying on the graph still being there.
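To make that concrete, here is a small illustration of when retain_graph actually matters: by default the graph is freed after backward(), so a second backward() through the same graph raises a RuntimeError unless the first call retained it.

```python
# Calling backward() twice through one graph fails unless retain_graph=True.
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

y.backward()  # the graph is freed here by default
try:
    y.backward()  # second backward through the freed graph
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True: the graph was already freed

# With retain_graph=True, a second backward works fine:
z = (x * 2).sum()
z.backward(retain_graph=True)
z.backward()
```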

You shouldn’t need the disc_opt.zero_grad() for the generator training, since the generator has its own separate loss and optimizer objects (gen_loss and gen_opt), and they keep track of the gradients that affect the generator separately. TBH, when you raised this question, I wasn’t entirely sure this was true, even though that’s how we’d like it to be implemented, but I did a little experiment to confirm: whether or not I call disc_opt.zero_grad() before generator training, I get the same values for the generator loss after training.

When training the generator, .detach() is not called on the discriminator's input. When we train the generator, we need gradients to flow through the discriminator, because the cost is by definition computed using the output of the discriminator (Detach() used in Assignment 4 - #4 by paulinpaloalto). Now, if we don't call disc_opt.zero_grad() before generator training (after discriminator training), won't the gradients from the previous discriminator step accumulate while training the generator?

@Ashish_Jha1,
Definitely it’s important that you don’t detach() the image passed to the discriminator when training the generator. That value needs to be included because, as you point out, the cost is computed using the output of the discriminator. If you detached, anything in the chain that uses that value would be excluded during backprop - which would leave essentially nothing for backprop to act on, and the generator wouldn’t learn.

But, you’ll notice that disc_opt is initialized with the parameters from the discriminator model and gen_opt is initialized with the parameters from the generator model. When we call gen_opt.step(), it is only updating the generator’s parameters, just as when we call gen_opt.zero_grad(), it is only zeroing out the gradient info for the generator’s parameters. It’s true that the backprop is adding to the gradient values for both the generator and discriminator’s parameters (since we are using both discriminator and generator with no detaching), but we are only actually updating the generator’s parameters, and the gradients for the discriminator’s parameters will be zeroed out when we call disc_opt.zero_grad() before we do our next pass on the discriminator, so we’re good there.
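Here is a small check of that point, using toy Linear layers as stand-ins for the models: after backprop through both networks, the discriminator's parameters do receive gradients, but gen_opt.step() leaves them untouched.

```python
# gen_opt only updates the generator's parameters, even though
# backprop also computed gradients for the discriminator.
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(2, 2)    # stand-in generator
disc = nn.Linear(2, 1)   # stand-in discriminator
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)

fake = gen(torch.randn(4, 2))
gen_loss = disc(fake).mean()   # no detach: gradients flow through disc
gen_loss.backward()

disc_weight_before = disc.weight.clone()
gen_opt.step()                 # only generator parameters move

print(disc.weight.grad is not None)                   # True: disc got gradients
print(torch.equal(disc.weight, disc_weight_before))   # True: but it didn't change
```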

You can try running with and without the disc_opt.zero_grad() in the generator training loop to see that there’s no real noticeable difference.

Thanks @Wendy for the reply. While training the generator, shouldn't the discriminator be in .eval() mode (e.g. disc.eval())? We don't want its batchnorm statistics updated during generator training; we only want its output for the generator.

@Ashish_Jha1, .eval() mode is really only applicable to a couple of types of layers, like dropout and batchnorm, that behave differently during training vs inference. For this assignment, we don’t use any of these special types of layers, so we don’t need to worry about .eval() mode.
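You can see what .eval() actually changes with a quick dropout example: in training mode, dropout zeroes entries at random; in eval mode, it is the identity. A plain linear layer behaves the same in both modes.

```python
# Dropout behaves differently in train vs eval mode; eval() just flips that flag.
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)   # roughly half the entries zeroed, survivors scaled by 2

drop.eval()
eval_out = drop(x)    # identity: dropout is disabled in eval mode

print(torch.equal(eval_out, x))       # True
print((train_out == 0).any().item())  # True (with overwhelming probability)
```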

The key points for this exercise are that when training the generator:

  • We should not detach, because the generator cost relies on the discriminator and we need to chain through the operations in the discriminator in order to calculate the gradients for the generator

  • We only apply the gradients for the generator, so it doesn’t matter that we’ve calculated some gradients for the discriminator. (And we will zero the discriminator gradients before we do our step to train the discriminator.)
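The first bullet can be demonstrated directly: if you detach the fake image before passing it to the discriminator, no gradient ever reaches the generator. Toy Linear layers stand in for the real models.

```python
# Detaching the fake image cuts the generator out of backprop entirely.
import torch
import torch.nn as nn

torch.manual_seed(0)
noise = torch.randn(4, 2)

# Correct: no detach, so gradients flow back into the generator
gen = nn.Linear(2, 2)
disc = nn.Linear(2, 1)
disc(gen(noise)).mean().backward()
print(gen.weight.grad is not None)   # True: generator received gradients

# Wrong for generator training: detach breaks the chain
gen2 = nn.Linear(2, 2)
disc2 = nn.Linear(2, 1)
disc2(gen2(noise).detach()).mean().backward()
print(gen2.weight.grad is None)      # True: nothing reached the generator
```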

I hope this helps. If it’s still fuzzy, the most helpful thing might be for you to try playing around with some different settings to see the effect - such as the outcome after training, including maybe looking at what happens to the gen and disc parameters at various stages.

Thanks @Wendy for the explanation.