Why should we detach the discriminator's input?

In the first programming assignment of week 4, in the training part, there is a comment saying "remember to detach the generator", and I detached the fake variable before passing it to the discriminator, and it worked. But it doesn't make sense to me: when we backpropagate to update the generator, the computational graph starts from the loss value, goes back through the discriminator (without updating its parameters), and from the discriminator's input finds its way back to the generator. Here the discriminator's input is the one that had been detached from the computational graph, so how does it work?
Thanks in advance for your comments.

The situation is asymmetric: when we train the discriminator, we can compute the gradients of the loss without going through the generator. But when we compute the gradients for training the generator, they (by definition) go through the discriminator, since the loss is computed from the discriminator's output, right? So when we train the discriminator, we detach the generator, but not the other way around. Also note that this is not a "correctness" issue, because we are careful not to apply any gradients that aren't relevant to the training we're doing on a particular cycle, and we zero them before the next cycle. It is only a performance issue: computing gradients has a non-trivial compute cost, so it makes sense to do it only when you really need the values.
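To see the asymmetry concretely, here is a minimal sketch with toy linear models standing in for the assignment's networks (the shapes and model definitions are my own assumptions, not the assignment's code). Detaching the fake batch in the discriminator step stops any gradient from reaching the generator:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the generator and discriminator (assumed shapes).
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 4)
fake = gen(noise)

# Discriminator step: detach the fakes so no gradient flows into gen.
disc_fake_pred = disc(fake.detach())
disc_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_loss.backward()

print(gen.weight.grad is None)       # True: generator got no gradients
print(disc.weight.grad is not None)  # True: discriminator did
```

Because the graph is cut at `fake.detach()`, autograd never visits the generator at all during this backward pass, which is exactly the compute saving described above.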


Thanks for your explanation, but there is one thing I didn't get. You said that when we train the discriminator we detach the generator, but not the other way around. However, when we want to train the generator, we have to go back through the discriminator, and then we run into the discriminator's input, which had been detached:

  1. we start from gen_loss.backward()
  2. then criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
  3. then we go to disc(fake_image_and_labels)
  4. then torch.cat((fake.float().detach(), image_one_hot_labels.float()), 1)
    these are the operations that matter in the backpropagation process
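For what it's worth, the `.detach()` in that trace belongs to the discriminator's training step. If the concatenation is rebuilt without `.detach()` for the generator step, backpropagation does reach the generator. A minimal sketch with toy linear models and assumed shapes (not the assignment's actual architectures):

```python
import torch
import torch.nn as nn

# Toy models with assumed shapes: images have 3 features, labels 1.
gen = nn.Linear(4, 3)
disc = nn.Linear(4, 1)  # takes image (3) + one-hot label (1)
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 4)
image_one_hot_labels = torch.ones(8, 1)
fake = gen(noise)

# Generator step: the concatenation is built WITHOUT .detach(), so the
# graph runs loss -> disc -> fake -> gen.
fake_image_and_labels = torch.cat(
    (fake.float(), image_one_hot_labels.float()), 1)
disc_fake_pred = disc(fake_image_and_labels)
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()

print(gen.weight.grad is not None)  # True: gradients reached the generator
```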


I would like to add more details to the answer above.

When we train the Discriminator, we don’t want to track operations for the Generator, since we are not going to update it or use its gradients. So we can speed up the training by detaching (from the computational graph) the output of the Generator.

When we train the Generator, we need to calculate the gradients of the Discriminator, but we won’t update it. Note that the update is done by calling optimizer.step() and each model has its own optimizer, whereas backward just calculates gradients without updating.
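A minimal sketch of that point (toy linear models, assumed shapes): `backward()` fills gradients for both models, but only the generator's optimizer applies an update, so the discriminator's weights stay put.

```python
import torch
import torch.nn as nn

gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()

disc_before = disc.weight.detach().clone()

fake = gen(torch.randn(8, 4))
disc_fake_pred = disc(fake)  # no detach: gradients must flow through disc
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()

# backward() filled disc.weight.grad too, but only gen's optimizer steps,
# so the discriminator's weights are untouched. Its stale gradients get
# zeroed by its own optimizer before the next discriminator step.
gen_opt.step()
gen_opt.zero_grad()

print(disc.weight.grad is not None)          # True: grads were computed
print(torch.equal(disc.weight, disc_before)) # True: but disc not updated
```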

Here is an illustration of the generator training step. (Note that we pass the true labels to the discriminator to calculate gradients towards real data.)

You can read this thread to gain more intuition behind this.


To me, this is similar to a situation where you hire a karate coach to train two guys, one on how to defend and the other on how to attack. In order for them both to learn, one needs to change (improve) first, and as a consequence the other guy also needs to update (improve) accordingly.

Other than that, I just can't imagine how the coach could help both guys improve their skills at a single moment of training. Even if the coach could do that (let's say, for example, you have two coaches now instead of one in every training session), then the supposed improvement in attack skill may mismatch the supposed improvement in defense. The coach(es) would then have a hard time knowing whether both guys have actually learned.

To make it clearer, I'd like to summarize it in the following steps:

  1. When the discriminator tries to sharpen its skill at classifying a batch of generated images as real or fake, it simply doesn't care about how to improve at generating images, because image generation falls under the generator's responsibility.
  2. However, when the discriminator tries to do its job, it needs generated images from the generator. So it calls into the generator to generate some images.
  3. In one training batch (imagine it's like a training session), we need to make sure both the discriminator and the generator get improved (so one doesn't outsmart the other). However, we can only improve one after the other; in this case, first the discriminator and then the generator.
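The steps above can be sketched as an alternating per-batch loop. This is a toy sketch with assumed shapes and linear models, not the assignment's code; the point is where the detach sits and which optimizer steps when:

```python
import torch
import torch.nn as nn

gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.01)
disc_opt = torch.optim.SGD(disc.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

for _ in range(3):  # a few toy batches
    real = torch.randn(8, 4)
    noise = torch.randn(8, 4)

    # 1) Discriminator step: detach the fakes so the generator is left out.
    disc_opt.zero_grad()
    fake = gen(noise)
    disc_loss = (criterion(disc(fake.detach()), torch.zeros(8, 1))
                 + criterion(disc(real), torch.ones(8, 1))) / 2
    disc_loss.backward()
    disc_opt.step()

    # 2) Generator step: no detach; gradients flow through disc,
    #    but only gen_opt updates parameters.
    gen_opt.zero_grad()
    gen_loss = criterion(disc(gen(noise)), torch.ones(8, 1))
    gen_loss.backward()
    gen_opt.step()

print(disc_loss.item() >= 0 and gen_loss.item() >= 0)  # True: both losses valid
```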