C2_W1_Assignment pairwise multivariate distributions

At the end of the assignment, we have the following code:

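import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.distributions import MultivariateNormal

# mu_fake, sigma_fake, mu_real, sigma_real are the means and covariance matrices
# of the fake and real Inception embeddings computed earlier in the assignment.

indices = [2, 4, 5]  # embedding features to compare
#indices = [20, 14, 8] # own trial

# Keep the selected entries of each mean and the matching rows/columns of each covariance,
# then draw 5000 samples from the resulting 3-D Gaussian.
fake_dist = MultivariateNormal(mu_fake[indices], sigma_fake[indices][:, indices])
fake_samples = fake_dist.sample((5000,))

real_dist = MultivariateNormal(mu_real[indices], sigma_real[indices][:, indices])
real_samples = real_dist.sample((5000,))

# One column per selected feature, plus a label column that pairplot uses to color the points
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real], ignore_index=True)
sns.pairplot(data=df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()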

Since the lab doesn't clarify or explain this part much, can somebody please break down and explain that part of the code? I am mainly interested in the following questions:

  1. What are we practically going to accomplish by plotting these pairwise multivariate distributions of the Inception embeddings? I am unable to read and understand the plotted figure.
  2. Does the resulting figure look good, bad, somewhat good, or something else?
  3. What does indices = [2, 4, 5] really indicate?
  4. What should I look for in the resulting figures to be convinced about the quality of the generated (fake) images?
  5. Since the FID is relatively large, could another model be used instead of inception_v3 to get a lower FID? If so, which model?

Cheers,

@mrgransky, good point: the assignment doesn’t give you much information about what you’re seeing. The general idea is that the pairplot lets you compare the distributions of the selected features for the real and fake images. It is mainly a visualization aid that shows how these features compare.

The indices ([2, 4, 5] by default, [20, 14, 8] in your own trial) just pick out which features of the fake and real embeddings you want to look at. We pick only a small number of indices because a few dimensions are much easier to visualize than the whole embedding. In the pairplot, you can see these index numbers along the x and y axes.
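In the code, mu_fake[indices] and sigma_fake[indices][:, indices] keep only the chosen entries of the mean vector and the matching rows and columns of the covariance matrix. For a multivariate Gaussian, that sub-mean and sub-covariance describe the marginal distribution of just those features, which is what gets sampled and plotted. Here is a minimal NumPy sketch of the same indexing, using a small made-up 6-feature covariance in place of the real Inception statistics (mu, sigma, and idx are hypothetical names):

```python
import numpy as np

# Hypothetical statistics for a 6-feature embedding (stand-ins for mu_real / sigma_real)
rng = np.random.default_rng(0)
a = rng.normal(size=(6, 6))
sigma = a @ a.T + 6 * np.eye(6)   # symmetric positive-definite covariance
mu = np.arange(6, dtype=float)

idx = [2, 4, 5]

# Keep only the chosen features: entries of the mean, rows AND columns of the covariance
mu_sub = mu[idx]
sigma_sub = sigma[idx][:, idx]    # same pattern as sigma_fake[indices][:, indices]

# np.ix_ is an equivalent, more explicit way to take the 3x3 sub-block
assert np.allclose(sigma_sub, sigma[np.ix_(idx, idx)])

# Sampling from the 3-D marginal gives one column per selected feature
samples = rng.multivariate_normal(mu_sub, sigma_sub, size=5000)
print(samples.shape)  # (5000, 3)
```

Each of the three columns then becomes one row/column of the pairplot grid.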

Along the diagonal, each cell shows the distribution curves of the real and fake values for one feature. For example, with indices [2, 4, 5], the upper-left cell shows the real and fake distributions of feature 2. Since the goal of a GAN is to produce images that look real, how closely (or not) these two curves match gives you some sense of how well your GAN reproduces that feature.
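If you want a number to go with a diagonal cell, one option (not something the assignment computes) is to use the fact that both curves are 1-D Gaussians here, since the samples come from MultivariateNormal, and compute the Fréchet distance between them. For a single feature, the FID formula collapses to (mean_a - mean_b)^2 + (std_a - std_b)^2. A sketch, where frechet_1d and the example statistics are hypothetical:

```python
import numpy as np

def frechet_1d(mean_a, var_a, mean_b, var_b):
    """Fréchet distance between the 1-D Gaussians N(mean_a, var_a) and N(mean_b, var_b).

    Single-feature special case of the FID formula:
    (mean_a - mean_b)**2 + (std_a - std_b)**2.
    Works elementwise on NumPy arrays, so it can score several features at once.
    """
    std_a, std_b = np.sqrt(var_a), np.sqrt(var_b)
    return (mean_a - mean_b) ** 2 + (std_a - std_b) ** 2

print(frechet_1d(0.0, 1.0, 0.0, 1.0))  # identical curves: 0.0
print(frechet_1d(0.0, 1.0, 0.5, 4.0))  # shifted by 0.5 and twice as wide: 0.25 + 1.0 = 1.25

# Hypothetical stand-ins for mu_real[indices], sigma_real[indices, indices] (the variances),
# and the corresponding fake statistics
real_means = np.array([0.0, 1.0, 2.0])
real_vars = np.array([1.0, 1.0, 1.0])
fake_means = np.array([0.0, 1.5, 2.0])
fake_vars = np.array([1.0, 1.0, 4.0])
per_feature = frechet_1d(real_means, real_vars, fake_means, fake_vars)
print(per_feature)  # one score per diagonal cell; 0 means the two curves coincide
```

A shifted peak shows up in the mean term, and a curve that is too wide or too narrow shows up in the standard-deviation term.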

The off-diagonal cells are scatter plots of the joint distribution of the two features given by that cell’s row and column indices. They give you an idea of the covariance between the two features. If the cloud of points is roughly circular (or stretched only along one axis), the two features show little dependency on each other. If it is a tilted oval, there is more of a dependency, and the narrower the oval, the stronger the correlation.

You may notice that you don’t see many dots for the blue (fake) features in these scatter plots. This is because the orange (real) dots are drawn on top of the blue ones. To see “through” the orange dots to the blue dots underneath, try passing an even smaller alpha in plot_kws (for example, plot_kws={'alpha': 0.05}) to make the dots more transparent.
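The “circular vs. tilted oval” reading of an off-diagonal cell corresponds to the correlation coefficient rho_ij = cov_ij / sqrt(cov_ii * cov_jj). The sketch below converts a covariance matrix to correlations and checks the values against samples drawn from a round and a tilted 2-feature Gaussian; the covariance matrices are made up for illustration and correlation_from_cov is a hypothetical helper:

```python
import numpy as np

def correlation_from_cov(cov):
    """Correlation matrix from a covariance matrix: rho_ij = cov_ij / sqrt(cov_ii * cov_jj)."""
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)

rng = np.random.default_rng(0)

# Hypothetical 2-feature covariance matrices
covs = {
    "round": np.array([[1.0, 0.0], [0.0, 1.0]]),   # circular cloud: uncorrelated
    "tilted": np.array([[1.0, 0.8], [0.8, 1.0]]),  # tilted, narrow oval: strongly correlated
}

empirical = {}
for name, cov in covs.items():
    samples = rng.multivariate_normal([0.0, 0.0], cov, size=5000)
    # Correlation estimated from the point cloud (what the scatter plot shows)
    empirical[name] = np.corrcoef(samples, rowvar=False)[0, 1]
    print(name, correlation_from_cov(cov)[0, 1], round(empirical[name], 2))
```

Applied to sigma_real[indices][:, indices] and sigma_fake[indices][:, indices], the off-diagonal entries of the two correlation matrices summarize how tilted the real and fake clouds are in each off-diagonal cell.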

For a little more info:

The “Frechet Inception Distance (FID)” video for this week gives a little more intuition about the shape of these multivariate plots.

The pairplot documentation gives more details about the pairplot function.
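For reference, FID compares the same two Gaussians whose slices the pairplot draws: ||mu_real - mu_fake||^2 + Tr(sigma_real + sigma_fake - 2 (sigma_real sigma_fake)^(1/2)). Below is a NumPy-only sketch of that formula; it is not the assignment’s own helper (which may use a different matrix square-root routine), and sqrtm_psd and frechet_distance are hypothetical names:

```python
import numpy as np

def sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2):

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))

    Tr((sigma1 sigma2)^(1/2)) is computed from s1h @ sigma2 @ s1h with s1h = sigma1^(1/2):
    that matrix has the same eigenvalues as sigma1 @ sigma2 but is symmetric PSD,
    so the trace of its square root is the sum of the square roots of those eigenvalues.
    """
    s1h = sqrtm_psd(sigma1)
    middle = s1h @ sigma2 @ s1h
    middle = (middle + middle.T) / 2.0  # symmetrize against round-off before eigvalsh
    trace_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)

# Hypothetical 2-feature statistics
print(frechet_distance(np.zeros(2), np.eye(2), np.zeros(2), np.eye(2)))      # identical: approximately 0
print(frechet_distance(np.zeros(2), np.eye(2), np.zeros(2), 4 * np.eye(2)))  # 2 + 8 - 2*4 = 2
```

To score only the plotted features, you could pass the same slices the pairplot code builds, e.g. mu_real[indices].numpy() and sigma_real[indices][:, indices].numpy() (assuming CPU tensors), together with the fake ones; the full FID uses the complete mean vectors and covariance matrices instead.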