The situation is asymmetric. When we train the discriminator, we can compute the gradients of its loss without backpropagating through the generator. But when we train the generator, the gradients (by definition) have to flow through the discriminator, since the generator's loss is computed from the discriminator's output, right? So we detach the generator's output when we train the discriminator, but not the other way around. Also note that this is not a “correctness” issue: on any given training cycle we only apply the gradients that are relevant to the network we're actually training, and we zero the gradients before the next cycle. It is only a performance issue: computing gradients has a non-trivial compute cost, so it makes sense to compute them only when you really need the values.
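Here is a minimal sketch of that asymmetry, assuming PyTorch. The tiny `gen` and `disc` networks, the tensor shapes, and the variable names are all made up for illustration and are not taken from any course code. It only computes gradients and inspects them; the optimizer steps are left out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks, just big enough to show the gradient flow.
gen = nn.Linear(4, 2)    # "generator": noise -> fake sample
disc = nn.Linear(2, 1)   # "discriminator": sample -> real/fake logit
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 4)
real = torch.randn(8, 2)

# Discriminator cycle: detach the fake samples so backward() stops at the
# discriminator's input and never walks back into the generator.
fake = gen(noise)
disc_loss = (criterion(disc(fake.detach()), torch.zeros(8, 1))
             + criterion(disc(real), torch.ones(8, 1)))
disc_loss.backward()
gen_grad_after_disc_cycle = gen.weight.grad            # no gradient was computed
disc_got_grad_in_disc_cycle = disc.weight.grad is not None

# Zero everything before the next cycle.
gen.zero_grad()
disc.zero_grad()

# Generator cycle: no detach here, because the generator's loss is computed
# from the discriminator's output, so the gradients must pass through disc.
gen_loss = criterion(disc(gen(noise)), torch.ones(8, 1))
gen_loss.backward()
gen_got_grad_in_gen_cycle = gen.weight.grad is not None
# disc also received gradients as a side effect. They are harmless as long as
# only the generator's optimizer steps here and they are zeroed afterwards.
disc_got_grad_in_gen_cycle = disc.weight.grad is not None

print(gen_grad_after_disc_cycle, disc_got_grad_in_disc_cycle)
print(gen_got_grad_in_gen_cycle, disc_got_grad_in_gen_cycle)
```

If you remove the `.detach()` in the discriminator cycle, the discriminator's gradients should be unchanged; the only difference is that autograd also does the extra work of filling in `gen.weight.grad`, which would then be thrown away.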