Can someone explain to me why we take gradients of the critic's scores for the mixed images with respect to these mixed images, rather than with respect to the critic's parameters?
I would also really appreciate an intuitive explanation of the "grad_outputs" parameter of torch.autograd.grad.
Let's start a few steps back. In a GAN, two networks are trained simultaneously: the generator and the critic.
The critic faces a binary classification problem: deciding whether the image produced by the generator is "real" or "fake".
The generator and critic must be trained simultaneously, and the critic helps the generator produce images closer to the real ones. The difference between the critic's scores for generated and real images tells us how far the generated output is from the real data, but this difference alone is not enough: we also need to give the generator the direction of that distance. That is why we use the gradient of these values.
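To make that loop concrete, here is a minimal sketch of one training step in PyTorch. Everything in it is an assumption for illustration: `gen` and `critic` are toy placeholder networks, and the loss shown is the W-loss style score difference, not necessarily the course's exact code.

```python
import torch
from torch import nn

# Hypothetical placeholder networks; any generator/critic architecture fits here.
gen = nn.Sequential(nn.Linear(64, 784), nn.Tanh())   # noise -> fake "image"
critic = nn.Sequential(nn.Linear(784, 1))            # image -> realness score

gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-4)

real = torch.randn(32, 784)   # stand-in for a batch of real images
noise = torch.randn(32, 64)

# Critic step: push real scores up and fake scores down.
fake = gen(noise).detach()    # detach so this step only updates the critic
critic_loss = critic(fake).mean() - critic(real).mean()
critic_opt.zero_grad()
critic_loss.backward()
critic_opt.step()

# Generator step: the critic's gradient tells the generator
# which direction to move its outputs to look more "real".
gen_loss = -critic(gen(noise)).mean()
gen_opt.zero_grad()
gen_loss.backward()
gen_opt.step()
```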
It's nice that you asked this. When I started reading about WGAN, I had the same dilemma: why are we taking the gradient like that, and why over the interpolated mixed image?
So, here is the explanation I arrived at myself.
The main idea is to keep the discriminator from becoming too powerful by constraining the norm of its gradients to at most 1, so that a useful, bounded gradient still flows back to the generator.
The way to do it is to check that the gradient of the critic's scores for images lying between the original and generated ones (taken w.r.t. those images) has norm less than or equal to 1. As you can see, this gradient has nothing to do with the model parameters; it has everything to do with the input of the discriminator, which is the images. Hence, we take the gradients w.r.t. the images.
Now the important part: it is infeasible to take every image in between the original and generated ones and compute its gradient. Think of the original and generated images as two points on a number line; there are infinitely many points between them. That is why we interpolate a single point (image) between the original and the generated image and check its gradient there. If the gradient norm strays from 1, the penalty term pushes it back toward 1.
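Here is a minimal sketch of that interpolation-plus-penalty step, following the standard WGAN-GP recipe (the function and variable names are my own, not necessarily the course's). It also shows where the `grad_outputs` argument from the original question comes in:

```python
import torch

def gradient_penalty(critic, real, fake):
    batch_size = real.size(0)
    # One random mixing coefficient per sample, broadcastable to any image shape.
    eps = torch.rand([batch_size] + [1] * (real.dim() - 1), device=real.device)
    # Interpolated ("mixed") images between real and fake.
    mixed = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)

    scores = critic(mixed)   # shape (batch_size, 1), not a scalar

    # scores is not a scalar, so autograd needs grad_outputs; a tensor of ones
    # means "give me the gradient of each score w.r.t. its own input image".
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
    )[0]

    grads = grads.reshape(batch_size, -1)
    # Penalize the gradient norm for straying from 1 (the 1-Lipschitz target).
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

This penalty is then added to the critic's loss with a weight (lambda = 10 in the original WGAN-GP paper).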
Hey @sahar.drfsh
Thank you for your reply
But my question was rather about the gradient penalty concept. I was wondering why, when calculating this penalty, we compute gradients for the mixed images with respect to the mixed images themselves rather than with respect to the critic's weights.
Hey @sohonjit.ghosh
Thanks for your reply
To my understanding, for W-loss to estimate the distance correctly, we need to make sure that the norm of the critic's gradients is less than or equal to one, which is why we add a gradient penalty term to the critic's loss. What is unclear to me, though, is why we calculate the gradients w.r.t. the interpolated images, given that it is the critic's gradients that we are trying to force to have norm of at most 1.
Wouldn't it make more sense to calculate these gradients w.r.t. the critic's parameters and penalize the model for having these gradients greater than we need them to be?
As I mentioned, and as you also said, we want the discriminator/critic to be 1-Lipschitz continuous, right? Now remember that, just like any other neural network, the critic is a function (a chain of mathematical functions), and the way to make it 1-Lipschitz is simply to keep the norm of its gradients within one.
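To spell that out (this is the standard definition, written here for clarity): a critic $f$ is 1-Lipschitz when

$$|f(x_1) - f(x_2)| \le \lVert x_1 - x_2 \rVert \quad \text{for all inputs } x_1, x_2,$$

and for a differentiable $f$ this is equivalent to

$$\lVert \nabla_x f(x) \rVert \le 1 \quad \text{for all } x.$$

Note that $\nabla_x$ differentiates with respect to the input $x$, not the weights; the weights merely define which function $f$ we are constraining, which is exactly why the penalty is computed w.r.t. the images.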
Now, why is the gradient taken w.r.t. the mixed image?
It's entirely because of the input of the critic.
What are the inputs to the critic? The generated and the original images, right?
But, as I mentioned, computing gradients for all possible input images is infeasible, so instead we score the mixed image and take the gradient of the critic's score w.r.t. that mixed image.
This is mainly because the largest changes in the critic's output occur between the real and generated inputs, and penalizing the gradient norm there so it stays at most 1 gives reasonable assurance that the critic's gradient norm is within 1 everywhere, which in turn means the discriminator is 1-Lipschitz continuous.
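On the "grad_outputs" part of the original question, which wasn't addressed above: torch.autograd.grad can only backpropagate from a scalar, so when `outputs` is a whole batch of scores you must supply `grad_outputs` as the "upstream gradient" to start the chain rule from. Passing a tensor of ones is equivalent to differentiating `scores.sum()`, so each image gets the gradient of its own score. A small standalone sketch with toy shapes and my own variable names:

```python
import torch

x = torch.randn(4, 3, requires_grad=True)   # pretend batch of 4 "images"
scores = (x ** 2).sum(dim=1, keepdim=True)  # pretend critic scores, shape (4, 1)

# grad_outputs seeds the chain rule: ones means each score contributes
# with weight 1, so grads[i] is d(scores[i]) / d(x[i]).
grads = torch.autograd.grad(
    outputs=scores,
    inputs=x,
    grad_outputs=torch.ones_like(scores),
)[0]

# Same result as differentiating the summed scores directly:
x2 = x.detach().requires_grad_(True)
(x2 ** 2).sum().backward()
assert torch.allclose(grads, x2.grad)
```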