DCGAN Training

Hi, I have some thoughts and a question regarding the Deep Convolutional GAN (DCGAN) training procedure provided in the assignment.
Here is the code for reference:

## Update discriminator ##
disc_opt.zero_grad()
fake_noise = get_noise(cur_batch_size, z_dim, device=device)
fake = gen(fake_noise)
disc_fake_pred = disc(fake.detach())
disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_real_pred = disc(real)
disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
disc_loss = (disc_fake_loss + disc_real_loss) / 2

# Keep track of the average discriminator loss
mean_discriminator_loss += disc_loss.item() / display_step
# Update gradients
disc_loss.backward(retain_graph=True)
# Update optimizer
disc_opt.step()

## Update generator ##
gen_opt.zero_grad()
fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
fake_2 = gen(fake_noise_2)
disc_fake_pred = disc(fake_2)
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()
gen_opt.step()

The first thing I’ve noticed is that when we generate fake images for the first time:

fake = gen(fake_noise)

PyTorch builds a computation graph for backpropagation here, which adds overhead. That would be fine if we reused this result later in the generator update, but instead a new fake batch, fake_2, is generated. So, in this case, it is better to wrap the call in a torch.no_grad() context manager. This way there is no need for the .detach() method, because no computation graph is built in the first place:

with torch.no_grad():
    fake = gen(fake_noise)
disc_fake_pred = disc(fake)

Alternatively, we could reuse these examples to update the generator instead of generating new fake images, while keeping the discriminator update unchanged:

## Update generator ##
gen_opt.zero_grad()
disc_fake_pred = disc(fake)
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()
gen_opt.step()

This saves one generator forward pass and should lead to similar results. The use of fake_noise and fake_noise_2 implies separate noise vectors for the discriminator and generator updates. While this works, it might introduce unnecessary variability.

My question is: what is the reason for generating new fake examples?

Another thought is about the retain_graph=True option in the discriminator's backward pass. There is no need for this option, because the generator update performs a fresh forward pass and builds a new computation graph here:

disc_fake_pred = disc(fake_2)

In some cases, such as multitask learning where two losses are computed from the outputs of different layers, it is necessary to retain the graph after backpropagating the first loss, but that is not the case here.
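For instance, here is a minimal sketch of such a case (toy nn.Linear modules, all names and shapes are purely illustrative): two losses share the same trunk of the graph, so the first backward pass must keep that trunk alive for the second one.

import torch
import torch.nn as nn

# Hypothetical two-head setup: a shared trunk feeds two heads,
# and each loss gets its own backward() call over the shared graph.
trunk = nn.Linear(10, 8)
head_a = nn.Linear(8, 1)
head_b = nn.Linear(8, 1)

x = torch.randn(4, 10)
shared = trunk(x)

loss_a = head_a(shared).pow(2).mean()
loss_b = head_b(shared).pow(2).mean()

# The first backward must retain the shared graph for the second backward.
loss_a.backward(retain_graph=True)
loss_b.backward()  # would fail without retain_graph=True above

In the DCGAN loop, by contrast, disc(fake_2) builds its own fresh graph, so nothing from the discriminator's backward pass needs to be retained.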


Hi, Pavel.

It looks like you came to this course already a sophisticated and experienced user of PyTorch. If you’ve looked around at the Deep Learning courses, you’ve probably noticed that all of them except GANs use TensorFlow. So this course is set up to assume that the student is seeing PyTorch for the very first time, as you can tell from the torch tutorial in Week 1. So perhaps they didn’t want to get into the level of detail at which you’ve explained things here and are trying to keep things relatively simple.

I’m pretty sure from your explanation that you already understand the fundamental asymmetry here: the generator’s gradients have to flow back through the discriminator, but the discriminator can be trained without computing the generator’s gradients at all. Here’s another thread that discusses that point, but I’m guessing it’s already clear to you.
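To spell that out with a toy example (nn.Linear stand-ins for gen and disc, shapes purely illustrative): detaching the fake batch keeps the discriminator’s backward pass from ever touching the generator, while the generator’s update has to backpropagate through the discriminator.

import torch
import torch.nn as nn

gen = nn.Linear(16, 32)   # "generator": noise -> fake sample
disc = nn.Linear(32, 1)   # "discriminator": sample -> logit
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 16)
fake = gen(noise)

# Discriminator step: detach() cuts the graph at fake, so backward()
# stops at the discriminator and never computes generator gradients.
disc_loss = criterion(disc(fake.detach()), torch.zeros(8, 1))
disc_loss.backward()
print(gen.weight.grad is None)   # True: the generator was never touched

# Generator step: gradients must flow through the discriminator
# back into the generator, so no detach() here.
gen_loss = criterion(disc(fake), torch.ones(8, 1))
gen_loss.backward()
print(gen.weight.grad is None)   # False: the generator now has gradients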

Here’s a previous thread which discusses the difference between no_grad and “detach”, but there again your points are already at a more sophisticated level than the discussion there.
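For anyone else reading along, here is a quick toy illustration of that difference (again a stand-in nn.Linear, shapes are illustrative): .detach() cuts a tensor off from a graph that has already been built, while torch.no_grad() prevents the graph from being built at all.

import torch
import torch.nn as nn

gen = nn.Linear(16, 32)   # toy stand-in for a generator
noise = torch.randn(8, 16)

# detach(): the forward pass still builds a graph for fake,
# but fake.detach() returns a tensor cut off from it.
fake = gen(noise)
print(fake.requires_grad)            # True  (a graph was built)
print(fake.detach().requires_grad)   # False (cut off from that graph)

# no_grad(): no graph is built at all, so there is nothing to detach.
with torch.no_grad():
    fake_ng = gen(noise)
print(fake_ng.requires_grad)         # False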

Also note that this is a little awkward because you are showing the solution code in your post. Normally I would delete the source code and leave a marker about it, since we are not supposed to share solution code in a public way (it spoils the fun for everyone else and leads to cheating). But without the code, your points are a lot harder for people to understand. So what I did instead is to “Unlist” this thread, which means that no one but you and the mentors can see it. That also sort of defeats the purpose. Sorry, just trying to work within the intent of the rules here …

The other thing I could try would be to forward this thread to the course developers, but I’m not sure they are still listening here.

Regards,
Paul


Hi Paul, thank you very much for your reply. The code I posted was provided by the course developers and doesn’t contain any solutions, so I hope it doesn’t violate the code of conduct. Following the links you kindly provided, I see that other community members have also pointed out these subtleties. The course developers likely made this choice to keep things simple and consistent with the other assignments. I’m still in the middle of the first course, so it’s hard to tell for sure at this point.


Great! Then that’s not a violation. Sorry, I was lazy and didn’t go back and check the actual notebook to see how much of that was code we had to write. I’ll “unhide” this thread then.
