Your First GAN assignment: use of retain_graph=True

Hi,

I was able to get full marks for the assignment, but I am still learning PyTorch: I had a question about the use of retain_graph in the discriminator's backward call in the provided code.

disc_loss.backward(retain_graph=True)

I checked the PyTorch documentation and forums to understand the parameter. As far as I understand, it retains the graph used for the backpropagation calculation so that backward can be called again on the full graph or a portion of it. I also saw some examples where this would be needed, but I don't see such a requirement in our assignment, and removing it also seems to work fine for me. Is this understanding correct? If not, please help me understand the utility of the parameter here.
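To make my understanding concrete, here's a toy snippet (not the assignment's code, just a minimal standalone example) showing the one situation I found where retain_graph matters: calling backward a second time on the same graph. PyTorch frees the backward buffers after the first backward() unless you retain them.

```python
import torch

# Toy example: the graph's backward buffers are freed after the
# first backward() unless retain_graph=True is passed.
x = torch.randn(3, requires_grad=True)

y = (x * 2).sum()
y.backward(retain_graph=True)  # graph kept alive
y.backward()                   # second pass works because we retained it

z = (x * 3).sum()
z.backward()                   # graph freed after this call
try:
    z.backward()               # second backward on a freed graph fails
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True: retain_graph is only needed to backprop twice
```

Since the assignment (as far as I can tell) only calls backward once per discriminator forward pass, this is why removing the flag still works for me.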

Thank you very much!

Here’s another thread that’s worth a look, although I don’t think it directly addresses the role of retain_graph. It explains a couple of key points about how the gradients are managed.

Thank you for the link, Paul! Yes, that post is about the need to detach the generator output to avoid calculating gradients that will not be used during the discriminator parameter updates.
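For anyone else reading this thread, here's a tiny sketch of what that detach is doing (hypothetical one-layer "generator" and "discriminator", not the assignment's models): detaching the fake images cuts the autograd graph at that point, so the discriminator's backward pass never produces gradients for the generator.

```python
import torch
import torch.nn as nn

# Hypothetical tiny generator/discriminator to illustrate detach().
gen = nn.Linear(4, 4)   # stand-in "generator"
disc = nn.Linear(4, 1)  # stand-in "discriminator"

noise = torch.randn(2, 4)
fake = gen(noise)

# detach() cuts the graph at `fake`, so backprop stops there.
disc_loss = disc(fake.detach()).mean()
disc_loss.backward()

print(gen.weight.grad is None)   # True: generator got no gradients
print(disc.weight.grad is None)  # False: discriminator did
```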

I also searched the forums for retain_graph to make sure the question hadn't already been answered. I see that the parameter shows up in some upcoming assignments as well, so I may get more insight then, but so far I haven't seen the need for it.

Thank you very much!