Detach in the Loss class of pix2pixHD

The Loss class in the optional notebook detaches some tensors. Does anyone know why they do this?

Lines 96-98

# Get necessary outputs for loss/backprop for both generator and discriminator
fake_preds_for_g = discriminator(torch.cat((label_map, boundary_map, x_fake), dim=1))
fake_preds_for_d = discriminator(torch.cat((label_map, boundary_map, x_fake.detach()), dim=1))
real_preds_for_d = discriminator(torch.cat((label_map, boundary_map, x_real.detach()), dim=1))

Why do they detach x_fake and x_real?

Lines 66-74

def fm_loss(self, real_preds, fake_preds):
    '''
    Computes feature matching loss from nested lists of fake and real outputs from discriminator.
    '''
    fm_loss = 0.0
    for real_features, fake_features in zip(real_preds, fake_preds):
        for real_feature, fake_feature in zip(real_features, fake_features):
            fm_loss += F.l1_loss(real_feature.detach(), fake_feature)
    return fm_loss

Why does fm_loss detach real_feature?

Lines 76-86

def vgg_loss(self, x_real, x_fake):
    '''
    Computes perceptual loss with VGG network from real and fake images.
    '''
    vgg_real = self.vgg(x_real)
    vgg_fake = self.vgg(x_fake)

    vgg_loss = 0.0
    for real, fake, weight in zip(vgg_real, vgg_fake, self.vgg_weights):
        vgg_loss += weight * F.l1_loss(real.detach(), fake)
    return vgg_loss

Same question for the VGG loss: why does real need to be detached? The VGG weights are already frozen.

I don't know the details of this optional assignment, but the general point is that computing gradients is expensive, so they try to avoid computing gradients for the portions of the compute graph that aren't needed for whatever training step is being done at that point. The classic and simplest example is training the discriminator: there you don't need gradients for the generator, so you detach the generator's outputs. When you train the generator, on the other hand, you cannot detach the discriminator, because the generator's gradients depend on the discriminator's gradients. In that case you just have to be careful not to apply the discriminator's gradients (i.e. not to step its optimizer) while training the generator. So this is not a correctness issue, just a performance issue. Here's a thread which discusses this point in the context of the simple case I just described.
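
To make that concrete, here is a minimal sketch of where .detach() typically shows up in a GAN training step. This is not the notebook's code; gen, disc, opt_g, opt_d and the tensor shapes are made-up placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

gen = nn.Linear(8, 8)    # stand-in generator
disc = nn.Linear(8, 1)   # stand-in discriminator
opt_g = torch.optim.Adam(gen.parameters())
opt_d = torch.optim.Adam(disc.parameters())

z = torch.randn(4, 8)
x_real = torch.randn(4, 8)
x_fake = gen(z)

# --- Discriminator step ---
# Detach x_fake so backward() stops at the generator's output: no gradients
# are computed for the generator's parameters here.
d_loss = F.binary_cross_entropy_with_logits(disc(x_fake.detach()), torch.zeros(4, 1)) \
       + F.binary_cross_entropy_with_logits(disc(x_real), torch.ones(4, 1))
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# --- Generator step ---
# No detach: gradients must flow back through the discriminator to reach the
# generator. The discriminator still accumulates gradients here, but we only
# call opt_g.step(), so it is not updated.
g_loss = F.binary_cross_entropy_with_logits(disc(x_fake), torch.ones(4, 1))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()

Presumably the same general point applies to fm_loss and vgg_loss: detaching the real-side features treats them as fixed targets for the L1 comparison, so autograd does not need to track or compute gradients for that branch of the graph when the generator is updated.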
